Database Refactor

combine service/users data w/ mod unit

split sliding sync related out of service/users

instrument database entry points

remove increment crap from database interface

de-wrap all database get() calls

de-wrap all database insert() calls

de-wrap all database remove() calls

refactor database interface for async streaming

add query key serializer for database

implement Debug for result handle

add query deserializer for database

add deserialization trait for option handle

start a stream utils suite

de-wrap/asyncify/type-query count_one_time_keys()

de-wrap/asyncify users count

add admin query users command suite

de-wrap/asyncify users exists

de-wrap/partially asyncify user filter related

asyncify/de-wrap users device/keys related

asyncify/de-wrap user auth/misc related

asyncify/de-wrap users blurhash

asyncify/de-wrap account_data get; merge Data into Service

partial asyncify/de-wrap uiaa; merge Data into Service

partially asyncify/de-wrap transaction_ids get; merge Data into Service

partially asyncify/de-wrap key_backups; merge Data into Service

asyncify/de-wrap pusher service getters; merge Data into Service

asyncify/de-wrap rooms alias getters/some iterators

asyncify/de-wrap rooms directory getters/iterator

partially asyncify/de-wrap rooms lazy-loading

partially asyncify/de-wrap rooms metadata

asyncify/dewrap rooms outlier

asyncify/dewrap rooms pdu_metadata

dewrap/partially asyncify rooms read receipt

de-wrap rooms search service

de-wrap/partially asyncify rooms user service

partial de-wrap rooms state_compressor

de-wrap rooms state_cache

de-wrap room state et al

de-wrap rooms timeline service

additional users device/keys related

de-wrap/asyncify sender

asyncify services

refactor database to TryFuture/TryStream

refactor services for TryFuture/TryStream

asyncify api handlers

additional asyncification for admin module

abstract stream related; support reverse streams

additional stream conversions

asyncify state-res related

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-08-08 17:18:30 +00:00 committed by strawberry
parent 6001014078
commit 946ca364e0
203 changed files with 12202 additions and 10709 deletions

View file

@@ -46,7 +46,7 @@ bytes.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
const-str.workspace = true
futures-util.workspace = true
futures.workspace = true
hickory-resolver.workspace = true
http.workspace = true
image.workspace = true

View file

@@ -1,152 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use conduit::{Error, Result};
use database::Map;
use ruma::{
api::client::error::ErrorKind,
events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use crate::{globals, Dep};
/// Database handles and service dependencies backing per-user account data.
pub(super) struct Data {
// Value column: full key (room?, user, count, type) -> raw JSON event blob.
roomuserdataid_accountdata: Arc<Map>,
// Pointer column: (room?, user, type) -> current key in the value column.
roomusertype_roomuserdataid: Arc<Map>,
services: Services,
}
/// Service dependencies; `globals` supplies the monotonic event counter
/// used to build unique, chronologically-sorting value keys.
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
/// Builds the store from the two database maps and the `globals`
/// service dependency carried by the service `Args`.
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(),
roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
/// Places one event in the account data of the user and removes the
/// previous entry.
///
/// Key layout (0xFF-separated segments):
/// - value key:   room_id? 0xFF user_id 0xFF count_be 0xFF event_type
/// - pointer key: room_id? 0xFF user_id 0xFF event_type
/// A `room_id` of `None` yields an empty leading segment (global data).
///
/// Returns `ErrorKind::InvalidParam` when `data` lacks the required
/// `type`/`content` fields; database errors are propagated via `?`.
pub(super) fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
// Common prefix: room (empty for global) + user, 0xFF-delimited.
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Value key embeds the next global count (big-endian) so each write
// is unique and entries sort chronologically within the prefix.
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
// Pointer key: maps (room, user, type) to the latest value key.
let mut key = prefix;
key.extend_from_slice(event_type.to_string().as_bytes());
// Validate the payload before touching the database.
if data.get("type").is_none() || data.get("content").is_none() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Account data doesn't have all required fields.",
));
}
self.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
)?;
// Read the previous pointer before overwriting it so the stale value
// entry can be deleted below.
let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
// Remove old entry
if let Some(prev) = prev {
self.roomuserdataid_accountdata.remove(&prev)?;
}
Ok(())
}
/// Searches the account data for a specific kind.
///
/// Two-step lookup: the (room, user, kind) pointer key resolves to the
/// current value key, which holds the raw JSON. A missing pointer or a
/// missing value both surface as `Ok(None)`; a present-but-unparseable
/// value is a `bad_database` error.
pub(super) fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes());
// Chase the pointer, then flatten the nested Result/Option layers
// with transpose so errors short-circuit and absences fall through.
self.roomusertype_roomuserdataid
.get(&key)?
.and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
})
.transpose()?
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose()
}
/// Returns all changes to the account data that happened after `since`.
///
/// `since` is matched against the count embedded in each value key; the
/// scan starts at count `since + 1`. When `room_id` is `None`, blobs are
/// deserialized as global events, otherwise as room events.
pub(super) fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<Vec<AnyRawAccountDataEvent>> {
// NOTE(review): the map below is keyed by the full database key, which
// embeds a unique count, so entries are NOT deduplicated per event
// type — confirm callers expect every change rather than the latest
// entry of each type.
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());
// Walk forward from the first candidate key, stopping once keys no
// longer share this user's prefix. (The second `iter_from` argument
// presumably selects reverse iteration; `false` == forward — TODO
// confirm against the Map API.)
for r in self
.roomuserdataid_accountdata
.iter_from(&first_possible, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
Ok::<_, Error>((
k,
match room_id {
None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(&v)
.map(AnyRawAccountDataEvent::Global)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(&v)
.map(AnyRawAccountDataEvent::Room)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
},
))
}) {
let (kind, data) = r?;
userdata.insert(kind, data);
}
Ok(userdata.into_values().collect())
}
}

View file

@@ -1,52 +1,158 @@
mod data;
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use conduit::Result;
use data::Data;
use conduit::{
implement,
utils::{stream::TryIgnore, ReadyExt},
Err, Error, Result,
};
use database::{Deserialized, Map};
use futures::{StreamExt, TryFutureExt};
use ruma::{
events::{AnyRawAccountDataEvent, RoomAccountDataEventType},
events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use serde_json::value::RawValue;
use crate::{globals, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Data {
roomuserdataid_accountdata: Arc<Map>,
roomusertype_roomuserdataid: Arc<Map>,
}
struct Services {
globals: Dep<globals::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(&args),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
db: Data {
roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(),
roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Places one event in the account data of the user and removes the
/// previous entry.
#[allow(clippy::needless_pass_by_value)]
pub fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
self.db.update(room_id, user_id, &event_type, data)
/// Places one event in the account data of the user and removes the
/// previous entry.
#[allow(clippy::needless_pass_by_value)]
#[implement(Service)]
pub async fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value,
) -> Result<()> {
let event_type = event_type.to_string();
let count = self.services.globals.next_count()?;
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&count.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.as_bytes());
let mut key = prefix;
key.extend_from_slice(event_type.as_bytes());
if data.get("type").is_none() || data.get("content").is_none() {
return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
}
/// Searches the account data for a specific kind.
#[allow(clippy::needless_pass_by_value)]
pub fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
self.db.get(room_id, user_id, &event_type)
self.db.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
);
let prev_key = (room_id, user_id, &event_type);
let prev = self.db.roomusertype_roomuserdataid.qry(&prev_key).await;
self.db
.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid);
// Remove old entry
if let Ok(prev) = prev {
self.db.roomuserdataid_accountdata.remove(&prev);
}
/// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip_all, name = "since", level = "debug")]
pub fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<Vec<AnyRawAccountDataEvent>> {
self.db.changes_since(room_id, user_id, since)
}
Ok(())
}
/// Searches the account data for a specific kind.
#[implement(Service)]
pub async fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Box<RawValue>> {
let key = (room_id, user_id, kind.to_string());
self.db
.roomusertype_roomuserdataid
.qry(&key)
.and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.qry(&roomuserdataid))
.await
.deserialized_json()
}
/// Returns all changes to the account data that happened after `since`.
#[implement(Service)]
pub async fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<Vec<AnyRawAccountDataEvent>> {
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());
self.db
.roomuserdataid_accountdata
.raw_stream_from(&first_possible)
.ignore_err()
.ready_take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
let v = match room_id {
None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(v)
.map(AnyRawAccountDataEvent::Global)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(v)
.map(AnyRawAccountDataEvent::Room)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
};
Ok((k.to_owned(), v))
})
.ignore_err()
.ready_for_each(|(kind, data)| {
userdata.insert(kind, data);
})
.await;
Ok(userdata.into_values().collect())
}

View file

@ -5,7 +5,7 @@ use std::{
};
use conduit::{debug, defer, error, log, Server};
use futures_util::future::{AbortHandle, Abortable};
use futures::future::{AbortHandle, Abortable};
use ruma::events::room::message::RoomMessageEventContent;
use rustyline_async::{Readline, ReadlineError, ReadlineEvent};
use termimad::MadSkin;

View file

@ -30,7 +30,7 @@ use crate::Services;
pub async fn create_admin_room(services: &Services) -> Result<()> {
let room_id = RoomId::new(services.globals.server_name());
let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?;
let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id);
let state_lock = services.rooms.state.mutex.lock(&room_id).await;

View file

@ -17,108 +17,108 @@ use serde_json::value::to_raw_value;
use crate::pdu::PduBuilder;
impl super::Service {
/// Invite the user to the conduit admin room.
///
/// In conduit, this is equivalent to granting admin privileges.
pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
let Some(room_id) = self.get_admin_room()? else {
return Ok(());
};
/// Invite the user to the conduit admin room.
///
/// In conduit, this is equivalent to granting admin privileges.
#[implement(super::Service)]
pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> {
let Ok(room_id) = self.get_admin_room().await else {
return Ok(());
};
let state_lock = self.services.state.mutex.lock(&room_id).await;
let state_lock = self.services.state.mutex.lock(&room_id).await;
// Use the server user to grant the new admin's power level
let server_user = &self.services.globals.server_user;
// Use the server user to grant the new admin's power level
let server_user = &self.services.globals.server_user;
// Invite and join the real user
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Invite,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
server_user,
&room_id,
&state_lock,
)
.await?;
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
user_id,
&room_id,
&state_lock,
)
.await?;
// Invite and join the real user
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Invite,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
server_user,
&room_id,
&state_lock,
)
.await?;
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
user_id,
&room_id,
&state_lock,
)
.await?;
// Set power level
let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]);
// Set power level
let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]);
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
timestamp: None,
},
server_user,
&room_id,
&state_lock,
)
.await?;
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(String::new()),
redacts: None,
timestamp: None,
},
server_user,
&room_id,
&state_lock,
)
.await?;
// Set room tag
let room_tag = &self.services.server.config.admin_room_tag;
if !room_tag.is_empty() {
if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag) {
error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant");
}
// Set room tag
let room_tag = &self.services.server.config.admin_room_tag;
if !room_tag.is_empty() {
if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag).await {
error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant");
}
}
// Send welcome message
self.services.timeline.build_and_append_pdu(
// Send welcome message
self.services.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&RoomMessageEventContent::text_markdown(
@ -135,19 +135,18 @@ impl super::Service {
&state_lock,
).await?;
Ok(())
}
Ok(())
}
#[implement(super::Service)]
fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> {
async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> {
let mut event = self
.services
.account_data
.get(Some(room_id), user_id, RoomAccountDataEventType::Tag)?
.map(|event| serde_json::from_str(event.get()))
.and_then(Result::ok)
.unwrap_or_else(|| TagEvent {
.get(Some(room_id), user_id, RoomAccountDataEventType::Tag)
.await
.and_then(|event| serde_json::from_str(event.get()).map_err(Into::into))
.unwrap_or_else(|_| TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
},
@ -158,12 +157,15 @@ fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<
.tags
.insert(tag.to_owned().into(), TagInfo::new());
self.services.account_data.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&serde_json::to_value(event)?,
)?;
self.services
.account_data
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&serde_json::to_value(event)?,
)
.await?;
Ok(())
}

View file

@ -12,6 +12,7 @@ use std::{
use async_trait::async_trait;
use conduit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server};
pub use create::create_admin_room;
use futures::{FutureExt, TryFutureExt};
use loole::{Receiver, Sender};
use ruma::{
events::{
@ -142,17 +143,18 @@ impl Service {
/// admin room as the admin user.
pub async fn send_text(&self, body: &str) {
self.send_message(RoomMessageEventContent::text_markdown(body))
.await;
.await
.ok();
}
/// Sends a message to the admin room as the admin user (see send_text() for
/// convenience).
pub async fn send_message(&self, message_content: RoomMessageEventContent) {
if let Ok(Some(room_id)) = self.get_admin_room() {
let user_id = &self.services.globals.server_user;
self.respond_to_room(message_content, &room_id, user_id)
.await;
}
pub async fn send_message(&self, message_content: RoomMessageEventContent) -> Result<()> {
let user_id = &self.services.globals.server_user;
let room_id = self.get_admin_room().await?;
self.respond_to_room(message_content, &room_id, user_id)
.boxed()
.await
}
/// Posts a command to the command processor queue and returns. Processing
@ -193,8 +195,12 @@ impl Service {
async fn handle_command(&self, command: CommandInput) {
match self.process_command(command).await {
Ok(Some(output)) | Err(output) => self.handle_response(output).await,
Ok(None) => debug!("Command successful with no response"),
Ok(Some(output)) | Err(output) => self
.handle_response(output)
.boxed()
.await
.unwrap_or_else(default_log),
}
}
@ -218,71 +224,67 @@ impl Service {
}
/// Checks whether a given user is an admin of this server
pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> {
if let Ok(Some(admin_room)) = self.get_admin_room() {
self.services.state_cache.is_joined(user_id, &admin_room)
} else {
Ok(false)
}
pub async fn user_is_admin(&self, user_id: &UserId) -> bool {
let Ok(admin_room) = self.get_admin_room().await else {
return false;
};
self.services
.state_cache
.is_joined(user_id, &admin_room)
.await
}
/// Gets the room ID of the admin room
///
/// Errors are propagated from the database, and will have None if there is
/// no admin room
pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> {
if let Some(room_id) = self
pub async fn get_admin_room(&self) -> Result<OwnedRoomId> {
let room_id = self
.services
.alias
.resolve_local_alias(&self.services.globals.admin_alias)?
{
if self
.services
.state_cache
.is_joined(&self.services.globals.server_user, &room_id)?
{
return Ok(Some(room_id));
}
}
.resolve_local_alias(&self.services.globals.admin_alias)
.await?;
Ok(None)
self.services
.state_cache
.is_joined(&self.services.globals.server_user, &room_id)
.await
.then_some(room_id)
.ok_or_else(|| err!(Request(NotFound("Admin user not joined to admin room"))))
}
async fn handle_response(&self, content: RoomMessageEventContent) {
async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> {
let Some(Relation::Reply {
in_reply_to,
}) = content.relates_to.as_ref()
else {
return;
return Ok(());
};
let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else {
let Ok(pdu) = self.services.timeline.get_pdu(&in_reply_to.event_id).await else {
error!(
event_id = ?in_reply_to.event_id,
"Missing admin command in_reply_to event"
);
return;
return Ok(());
};
let response_sender = if self.is_admin_room(&pdu.room_id) {
let response_sender = if self.is_admin_room(&pdu.room_id).await {
&self.services.globals.server_user
} else {
&pdu.sender
};
self.respond_to_room(content, &pdu.room_id, response_sender)
.await;
.await
}
async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId) {
assert!(
self.user_is_admin(user_id)
.await
.expect("checked user is admin"),
"sender is not admin"
);
async fn respond_to_room(
&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId,
) -> Result<()> {
assert!(self.user_is_admin(user_id).await, "sender is not admin");
let state_lock = self.services.state.mutex.lock(room_id).await;
let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&content).expect("event is valid, we just created it"),
@ -292,6 +294,7 @@ impl Service {
timestamp: None,
};
let state_lock = self.services.state.mutex.lock(room_id).await;
if let Err(e) = self
.services
.timeline
@ -302,6 +305,8 @@ impl Service {
.await
.unwrap_or_else(default_log);
}
Ok(())
}
async fn handle_response_error(
@ -355,12 +360,12 @@ impl Service {
}
// Prevent unescaped !admin from being used outside of the admin room
if is_public_prefix && !self.is_admin_room(&pdu.room_id) {
if is_public_prefix && !self.is_admin_room(&pdu.room_id).await {
return false;
}
// Only senders who are admin can proceed
if !self.user_is_admin(&pdu.sender).await.unwrap_or(false) {
if !self.user_is_admin(&pdu.sender).await {
return false;
}
@ -368,7 +373,7 @@ impl Service {
// the administrator can execute commands as conduit
let emergency_password_set = self.services.globals.emergency_password().is_some();
let from_server = pdu.sender == *server_user && !emergency_password_set;
if from_server && self.is_admin_room(&pdu.room_id) {
if from_server && self.is_admin_room(&pdu.room_id).await {
return false;
}
@ -377,12 +382,11 @@ impl Service {
}
#[must_use]
pub fn is_admin_room(&self, room_id: &RoomId) -> bool {
if let Ok(Some(admin_room_id)) = self.get_admin_room() {
admin_room_id == room_id
} else {
false
}
pub async fn is_admin_room(&self, room_id_: &RoomId) -> bool {
self.get_admin_room()
.map_ok(|room_id| room_id == room_id_)
.await
.unwrap_or(false)
}
/// Sets the self-reference to crate::Services which will provide context to

View file

@ -1,7 +1,8 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use conduit::{err, utils::stream::TryIgnore, Result};
use database::{Database, Deserialized, Map};
use futures::Stream;
use ruma::api::appservice::Registration;
pub struct Data {
@ -19,7 +20,7 @@ impl Data {
pub(super) fn register_appservice(&self, yaml: &Registration) -> Result<String> {
let id = yaml.id.as_str();
self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes());
Ok(id.to_owned())
}
@ -31,24 +32,19 @@ impl Data {
/// * `service_name` - the name you send to register the service previously
pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
.remove(service_name.as_bytes());
Ok(())
}
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
pub async fn get_registration(&self, id: &str) -> Result<Registration> {
self.id_appserviceregistrations
.get(id.as_bytes())?
.map(|bytes| {
serde_yaml::from_slice(&bytes)
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
})
.transpose()
.qry(id)
.await
.deserialized_json()
.map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}")))
}
pub(super) fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id)
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
})))
pub(super) fn iter_ids(&self) -> impl Stream<Item = String> + Send + '_ {
self.id_appserviceregistrations.keys().ignore_err()
}
}

View file

@ -2,9 +2,10 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use async_trait::async_trait;
use conduit::{err, Result};
use data::Data;
use futures_util::Future;
use futures::{Future, StreamExt, TryStreamExt};
use regex::RegexSet;
use ruma::{
api::appservice::{Namespace, Registration},
@ -126,13 +127,22 @@ struct Services {
sending: Dep<sending::Service>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let mut registration_info = BTreeMap::new();
let db = Data::new(args.db);
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
sending: args.depend::<sending::Service>("sending"),
},
registration_info: RwLock::new(BTreeMap::new()),
}))
}
async fn worker(self: Arc<Self>) -> Result<()> {
// Inserting registrations into cache
for appservice in iter_ids(&db)? {
registration_info.insert(
for appservice in iter_ids(&self.db).await? {
self.registration_info.write().await.insert(
appservice.0,
appservice
.1
@ -141,13 +151,7 @@ impl crate::Service for Service {
);
}
Ok(Arc::new(Self {
db,
services: Services {
sending: args.depend::<sending::Service>("sending"),
},
registration_info: RwLock::new(registration_info),
}))
Ok(())
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
@ -155,7 +159,7 @@ impl crate::Service for Service {
impl Service {
#[inline]
pub fn all(&self) -> Result<Vec<(String, Registration)>> { iter_ids(&self.db) }
pub async fn all(&self) -> Result<Vec<(String, Registration)>> { iter_ids(&self.db).await }
/// Registers an appservice and returns the ID to the caller
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
@ -188,7 +192,8 @@ impl Service {
// sending to the URL
self.services
.sending
.cleanup_events(service_name.to_owned())?;
.cleanup_events(service_name.to_owned())
.await;
Ok(())
}
@ -251,15 +256,9 @@ impl Service {
}
}
fn iter_ids(db: &Data) -> Result<Vec<(String, Registration)>> {
db.iter_ids()?
.filter_map(Result::ok)
.map(move |id| {
Ok((
id.clone(),
db.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"),
))
})
.collect()
async fn iter_ids(db: &Data) -> Result<Vec<(String, Registration)>> {
db.iter_ids()
.then(|id| async move { Ok((id.clone(), db.get_registration(&id).await?)) })
.try_collect()
.await
}

View file

@ -33,6 +33,7 @@ impl crate::Service for Service {
async fn worker(self: Arc<Self>) -> Result<()> {
self.set_emergency_access()
.await
.inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?;
Ok(())
@ -44,7 +45,7 @@ impl crate::Service for Service {
impl Service {
/// Sets the emergency password and push rules for the @conduit account in
/// case emergency password is set
fn set_emergency_access(&self) -> Result<bool> {
async fn set_emergency_access(&self) -> Result<bool> {
let conduit_user = &self.services.globals.server_user;
self.services
@ -56,17 +57,20 @@ impl Service {
None => (Ruleset::new(), false),
};
self.services.account_data.update(
None,
conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
self.services
.account_data
.update(
None,
conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)
.await?;
if pwd_set {
warn!(
@ -75,7 +79,7 @@ impl Service {
);
} else {
// logs out any users still in the server service account and removes sessions
self.services.users.deactivate_account(conduit_user)?;
self.services.users.deactivate_account(conduit_user).await?;
}
Ok(pwd_set)

View file

@ -4,8 +4,8 @@ use std::{
};
use conduit::{trace, utils, Error, Result, Server};
use database::{Database, Map};
use futures_util::{stream::FuturesUnordered, StreamExt};
use database::{Database, Deserialized, Map};
use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt};
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
@ -83,7 +83,7 @@ impl Data {
.checked_add(1)
.expect("counter must not overflow u64");
self.global.insert(COUNTER, &counter.to_be_bytes())?;
self.global.insert(COUNTER, &counter.to_be_bytes());
Ok(*counter)
}
@ -102,7 +102,7 @@ impl Data {
fn stored_count(global: &Arc<Map>) -> Result<u64> {
global
.get(COUNTER)?
.get(COUNTER)
.as_deref()
.map_or(Ok(0_u64), utils::u64_from_bytes)
}
@ -133,36 +133,18 @@ impl Data {
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in self
.services
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = self
.services
.short
.get_shortroomid(&room_id)
.ok()
.flatten()
.expect("room exists")
.to_be_bytes()
.to_vec();
let rooms_joined = self.services.state_cache.rooms_joined(user_id);
pin_mut!(rooms_joined);
while let Some(room_id) = rooms_joined.next().await {
let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else {
continue;
};
let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xFF);
// PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(Box::pin(async move {
let _result = self.services.typing.wait_for_update(&room_id).await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
// Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
@ -174,6 +156,19 @@ impl Data {
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
// PDUs
let short_roomid = short_roomid.to_be_bytes().to_vec();
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
let typing_room_id = room_id.to_owned();
let typing_wait_for_update = async move {
self.services.typing.wait_for_update(&typing_room_id).await;
};
futures.push(typing_wait_for_update.boxed());
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
}
let mut globaluserdata_prefix = vec![0xFF];
@ -190,12 +185,14 @@ impl Data {
// One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(async move {
// Server shutdown
let server_shutdown = async move {
while self.services.server.running() {
let _result = self.services.server.signal.subscribe().recv().await;
self.services.server.signal.subscribe().recv().await.ok();
}
}));
};
futures.push(server_shutdown.boxed());
if !self.services.server.running() {
return Ok(());
}
@ -209,10 +206,10 @@ impl Data {
}
pub fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair_bytes = self.global.get(b"keypair").map_or_else(
|_| {
let keypair = utils::generate_keypair();
self.global.insert(b"keypair", &keypair)?;
self.global.insert(b"keypair", &keypair);
Ok::<_, Error>(keypair)
},
|val| Ok(val.to_vec()),
@ -241,7 +238,10 @@ impl Data {
}
#[inline]
pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
pub fn remove_keypair(&self) -> Result<()> {
self.global.remove(b"keypair");
Ok(())
}
/// TODO: the key valid until timestamp (`valid_until_ts`) is only honored
/// in room version > 4
@ -250,15 +250,15 @@ impl Data {
///
/// This doesn't actually check that the keys provided are newer than the
/// old set.
pub fn add_signing_key(
pub async fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
) -> BTreeMap<OwnedServerSigningKeyId, VerifyKey> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let signingkeys = self.server_signingkeys.qry(origin).await;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
.and_then(|keys| serde_json::from_slice(&keys).map_err(Into::into))
.unwrap_or_else(|_| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
@ -275,7 +275,7 @@ impl Data {
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
);
let mut tree = keys.verify_keys;
tree.extend(
@ -284,45 +284,38 @@ impl Data {
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree)
tree
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.signing_keys_for(origin)?
.map_or_else(BTreeMap::new, |keys: ServerSigningKeys| {
pub async fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
self.signing_keys_for(origin).await.map_or_else(
|_| Ok(BTreeMap::new()),
|keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree
});
Ok(signingkeys)
Ok(tree)
},
)
}
pub fn signing_keys_for(&self, origin: &ServerName) -> Result<Option<ServerSigningKeys>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok());
Ok(signingkeys)
pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> {
self.server_signingkeys
.qry(origin)
.await
.deserialized_json()
}
pub fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
pub async fn database_version(&self) -> u64 { self.global.qry("version").await.deserialized().unwrap_or(0) }
#[inline]
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
self.global.insert(b"version", &new_version.to_be_bytes());
Ok(())
}

View file

@ -1,17 +1,15 @@
use std::{
collections::{HashMap, HashSet},
fs::{self},
io::Write,
mem::size_of,
sync::Arc,
use conduit::{
debug_info, debug_warn, error, info,
result::NotFound,
utils::{stream::TryIgnore, IterStream, ReadyExt},
warn, Err, Error, Result,
};
use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Error, Result};
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use ruma::{
events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType},
push::Ruleset,
EventId, OwnedRoomId, RoomId, UserId,
UserId,
};
use crate::{media, Services};
@ -33,12 +31,14 @@ pub(crate) const DATABASE_VERSION: u64 = 13;
pub(crate) const CONDUIT_DATABASE_VERSION: u64 = 16;
pub(crate) async fn migrations(services: &Services) -> Result<()> {
let users_count = services.users.count().await;
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if services.users.count()? > 0 {
if users_count > 0 {
let conduit_user = &services.globals.server_user;
if !services.users.exists(conduit_user)? {
if !services.users.exists(conduit_user).await {
error!("The {} server user does not exist, and the database is not new.", conduit_user);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first.",
@ -46,7 +46,7 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> {
}
}
if services.users.count()? > 0 {
if users_count > 0 {
migrate(services).await
} else {
fresh(services).await
@ -62,9 +62,9 @@ async fn fresh(services: &Services) -> Result<()> {
.db
.bump_database_version(DATABASE_VERSION)?;
db["global"].insert(b"feat_sha256_media", &[])?;
db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?;
db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?;
db["global"].insert(b"feat_sha256_media", &[]);
db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]);
db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]);
// Create the admin room and server user on first run
crate::admin::create_admin_room(services).await?;
@ -82,136 +82,109 @@ async fn migrate(services: &Services) -> Result<()> {
let db = &services.db;
let config = &services.server.config;
if services.globals.db.database_version()? < 1 {
db_lt_1(services).await?;
if services.globals.db.database_version().await < 11 {
return Err!(Database(
"Database schema version {} is no longer supported",
services.globals.db.database_version().await
));
}
if services.globals.db.database_version()? < 2 {
db_lt_2(services).await?;
}
if services.globals.db.database_version()? < 3 {
db_lt_3(services).await?;
}
if services.globals.db.database_version()? < 4 {
db_lt_4(services).await?;
}
if services.globals.db.database_version()? < 5 {
db_lt_5(services).await?;
}
if services.globals.db.database_version()? < 6 {
db_lt_6(services).await?;
}
if services.globals.db.database_version()? < 7 {
db_lt_7(services).await?;
}
if services.globals.db.database_version()? < 8 {
db_lt_8(services).await?;
}
if services.globals.db.database_version()? < 9 {
db_lt_9(services).await?;
}
if services.globals.db.database_version()? < 10 {
db_lt_10(services).await?;
}
if services.globals.db.database_version()? < 11 {
db_lt_11(services).await?;
}
if services.globals.db.database_version()? < 12 {
if services.globals.db.database_version().await < 12 {
db_lt_12(services).await?;
}
// This migration can be reused as-is anytime the server-default rules are
// updated.
if services.globals.db.database_version()? < 13 {
if services.globals.db.database_version().await < 13 {
db_lt_13(services).await?;
}
if db["global"].get(b"feat_sha256_media")?.is_none() {
if db["global"].qry("feat_sha256_media").await.is_not_found() {
media::migrations::migrate_sha256_media(services).await?;
} else if config.media_startup_check {
media::migrations::checkup_sha256_media(services).await?;
}
if db["global"]
.get(b"fix_bad_double_separator_in_state_cache")?
.is_none()
.qry("fix_bad_double_separator_in_state_cache")
.await
.is_not_found()
{
fix_bad_double_separator_in_state_cache(services).await?;
}
if db["global"]
.get(b"retroactively_fix_bad_data_from_roomuserid_joined")?
.is_none()
.qry("retroactively_fix_bad_data_from_roomuserid_joined")
.await
.is_not_found()
{
retroactively_fix_bad_data_from_roomuserid_joined(services).await?;
}
let version_match = services.globals.db.database_version().unwrap() == DATABASE_VERSION
|| services.globals.db.database_version().unwrap() == CONDUIT_DATABASE_VERSION;
let version_match = services.globals.db.database_version().await == DATABASE_VERSION
|| services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION;
assert!(
version_match,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
services.globals.db.database_version().unwrap(),
services.globals.db.database_version().await,
DATABASE_VERSION,
);
{
let patterns = services.globals.forbidden_usernames();
if !patterns.is_empty() {
for user_id in services
services
.users
.iter()
.filter_map(Result::ok)
.filter(|user| !services.users.is_deactivated(user).unwrap_or(true))
.filter(|user| user.server_name() == config.server_name)
{
let matches = patterns.matches(user_id.localpart());
if matches.matched_any() {
warn!(
"User {} matches the following forbidden username patterns: {}",
user_id.to_string(),
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
}
}
}
{
let patterns = services.globals.forbidden_alias_names();
if !patterns.is_empty() {
for address in services.rooms.metadata.iter_ids() {
let room_id = address?;
let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id);
for room_alias_result in room_aliases {
let room_alias = room_alias_result?;
let matches = patterns.matches(room_alias.alias());
.stream()
.filter(|user_id| services.users.is_active_local(user_id))
.ready_for_each(|user_id| {
let matches = patterns.matches(user_id.localpart());
if matches.matched_any() {
warn!(
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
room_alias,
&room_id,
"User {} matches the following forbidden username patterns: {}",
user_id.to_string(),
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
}
})
.await;
}
}
{
let patterns = services.globals.forbidden_alias_names();
if !patterns.is_empty() {
for room_id in services
.rooms
.metadata
.iter_ids()
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await
{
services
.rooms
.alias
.local_aliases_for_room(&room_id)
.ready_for_each(|room_alias| {
let matches = patterns.matches(room_alias.alias());
if matches.matched_any() {
warn!(
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
room_alias,
&room_id,
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
})
.await;
}
}
}
@ -224,424 +197,17 @@ async fn migrate(services: &Services) -> Result<()> {
Ok(())
}
async fn db_lt_1(services: &Services) -> Result<()> {
let db = &services.db;
let roomserverids = &db["roomserverids"];
let serverroomids = &db["serverroomids"];
for (roomserverid, _) in roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id = parts.next().expect("split always returns one element");
let Some(servername) = parts.next() else {
error!("Migration: Invalid roomserverid in db.");
continue;
};
let mut serverroomid = servername.to_vec();
serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id);
serverroomids.insert(&serverroomid, &[])?;
}
services.globals.db.bump_database_version(1)?;
info!("Migration: 0 -> 1 finished");
Ok(())
}
async fn db_lt_2(services: &Services) -> Result<()> {
let db = &services.db;
// We accidentally inserted hashed versions of "" into the db instead of just ""
let userid_password = &db["roomserverids"];
for (userid, password) in userid_password.iter() {
let empty_pass = utils::hash::password("").expect("our own password to be properly hashed");
let password = std::str::from_utf8(&password).expect("password is valid utf-8");
let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok();
if empty_hashed_password {
userid_password.insert(&userid, b"")?;
}
}
services.globals.db.bump_database_version(2)?;
info!("Migration: 1 -> 2 finished");
Ok(())
}
async fn db_lt_3(services: &Services) -> Result<()> {
let db = &services.db;
// Move media to filesystem
let mediaid_file = &db["mediaid_file"];
for (key, content) in mediaid_file.iter() {
if content.is_empty() {
continue;
}
#[allow(deprecated)]
let path = services.media.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
mediaid_file.insert(&key, &[])?;
}
services.globals.db.bump_database_version(3)?;
info!("Migration: 2 -> 3 finished");
Ok(())
}
async fn db_lt_4(services: &Services) -> Result<()> {
let config = &services.server.config;
// Add federated users to services as deactivated
for our_user in services.users.iter() {
let our_user = our_user?;
if services.users.is_deactivated(&our_user)? {
continue;
}
for room in services.rooms.state_cache.rooms_joined(&our_user) {
for user in services.rooms.state_cache.room_members(&room?) {
let user = user?;
if user.server_name() != config.server_name {
info!(?user, "Migration: creating user");
services.users.create(&user, None)?;
}
}
}
}
services.globals.db.bump_database_version(4)?;
info!("Migration: 3 -> 4 finished");
Ok(())
}
async fn db_lt_5(services: &Services) -> Result<()> {
let db = &services.db;
// Upgrade user data store
let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"];
let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"];
for (roomuserdataid, _) in roomuserdataid_accountdata.iter() {
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap();
let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
let mut key = room_id.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id);
key.push(0xFF);
key.extend_from_slice(event_type);
roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
}
services.globals.db.bump_database_version(5)?;
info!("Migration: 4 -> 5 finished");
Ok(())
}
async fn db_lt_6(services: &Services) -> Result<()> {
let db = &services.db;
// Set room member count
let roomid_shortstatehash = &db["roomid_shortstatehash"];
for (roomid, _) in roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services.rooms.state_cache.update_joined_count(room_id)?;
}
services.globals.db.bump_database_version(6)?;
info!("Migration: 5 -> 6 finished");
Ok(())
}
async fn db_lt_7(services: &Services) -> Result<()> {
let db = &services.db;
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let handle_state = |current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
services
.rooms
.state_compressor
.load_shortstatehash_info(last_roomsstatehash)
},
)?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.copied()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.copied()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
services.rooms.state_compressor.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = services.rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
let stateid_shorteventid = &db["stateid_shorteventid"];
let shorteventid_eventid = &db["shorteventid_eventid"];
for (k, seventid) in stateid_shorteventid.iter() {
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
}
current_state = HashSet::new();
current_sstatehash = Some(sstatehash);
let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
}
}
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
}
services.globals.db.bump_database_version(7)?;
info!("Migration: 6 -> 7 finished");
Ok(())
}
async fn db_lt_8(services: &Services) -> Result<()> {
let db = &services.db;
let roomid_shortstatehash = &db["roomid_shortstatehash"];
let roomid_shortroomid = &db["roomid_shortroomid"];
let pduid_pdu = &db["pduid_pdu"];
let eventid_pduid = &db["eventid_pduid"];
// Generate short room ids for all rooms
for (room_id, _) in roomid_shortstatehash.iter() {
let shortroomid = services.globals.next_count()?.to_be_bytes();
roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
// Update pduids db layout
let batch = pduid_pdu
.iter()
.filter_map(|(key, v)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id.to_vec();
new_key.extend_from_slice(count);
Some(database::OwnedKeyVal(new_key, v))
})
.collect::<Vec<_>>();
pduid_pdu.insert_batch(batch.iter().map(database::KeyVal::from))?;
let batch2 = eventid_pduid
.iter()
.filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id.to_vec();
new_value.extend_from_slice(count);
Some(database::OwnedKeyVal(k, new_value))
})
.collect::<Vec<_>>();
eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?;
services.globals.db.bump_database_version(8)?;
info!("Migration: 7 -> 8 finished");
Ok(())
}
async fn db_lt_9(services: &Services) -> Result<()> {
let db = &services.db;
let tokenids = &db["tokenids"];
let roomid_shortroomid = &db["roomid_shortroomid"];
// Update tokenids db layout
let mut iter = tokenids
.iter()
.filter_map(|(key, _)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(4, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap();
let short_room_id = roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id.to_vec();
new_key.extend_from_slice(word);
new_key.push(0xFF);
new_key.extend_from_slice(pdu_id_count);
Some(database::OwnedKeyVal(new_key, Vec::<u8>::new()))
})
.peekable();
while iter.peek().is_some() {
let batch = iter.by_ref().take(1000).collect::<Vec<_>>();
tokenids.insert_batch(batch.iter().map(database::KeyVal::from))?;
debug!("Inserted smaller batch");
}
info!("Deleting starts");
let batch2: Vec<_> = tokenids
.iter()
.filter_map(|(key, _)| {
if key.starts_with(b"!") {
Some(key)
} else {
None
}
})
.collect();
for key in batch2 {
tokenids.remove(&key)?;
}
services.globals.db.bump_database_version(9)?;
info!("Migration: 8 -> 9 finished");
Ok(())
}
async fn db_lt_10(services: &Services) -> Result<()> {
let db = &services.db;
let statekey_shortstatekey = &db["statekey_shortstatekey"];
let shortstatekey_statekey = &db["shortstatekey_statekey"];
// Add other direction for shortstatekeys
for (statekey, shortstatekey) in statekey_shortstatekey.iter() {
shortstatekey_statekey.insert(&shortstatekey, &statekey)?;
}
// Force E2EE device list updates so we can send them over federation
for user_id in services.users.iter().filter_map(Result::ok) {
services.users.mark_device_key_update(&user_id)?;
}
services.globals.db.bump_database_version(10)?;
info!("Migration: 9 -> 10 finished");
Ok(())
}
#[allow(unreachable_code)]
async fn db_lt_11(services: &Services) -> Result<()> {
error!("Dropping a column to clear data is not implemented yet.");
//let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"];
//userdevicesessionid_uiaarequest.clear()?;
services.globals.db.bump_database_version(11)?;
info!("Migration: 10 -> 11 finished");
Ok(())
}
async fn db_lt_12(services: &Services) -> Result<()> {
let config = &services.server.config;
for username in services.users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
for username in &services
.users
.list_local_users()
.map(UserId::to_owned)
.collect::<Vec<_>>()
.await
{
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
@ -652,7 +218,7 @@ async fn db_lt_12(services: &Services) -> Result<()> {
let raw_rules_list = services
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.await
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
@ -694,12 +260,15 @@ async fn db_lt_12(services: &Services) -> Result<()> {
}
}
services.account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
services
.account_data
.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
}
services.globals.db.bump_database_version(12)?;
@ -710,8 +279,14 @@ async fn db_lt_12(services: &Services) -> Result<()> {
async fn db_lt_13(services: &Services) -> Result<()> {
let config = &services.server.config;
for username in services.users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
for username in &services
.users
.list_local_users()
.map(UserId::to_owned)
.collect::<Vec<_>>()
.await
{
let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) {
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
@ -722,7 +297,7 @@ async fn db_lt_13(services: &Services) -> Result<()> {
let raw_rules_list = services
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.await
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
@ -733,12 +308,15 @@ async fn db_lt_13(services: &Services) -> Result<()> {
.global
.update_with_server_default(user_default_rules);
services.account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
services
.account_data
.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
}
services.globals.db.bump_database_version(13)?;
@ -754,32 +332,37 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<
let _cork = db.cork_and_sync();
let mut iter_count: usize = 0;
for (mut key, value) in roomuserid_joined.iter() {
iter_count = iter_count.saturating_add(1);
debug_info!(%iter_count);
let first_sep_index = key
.iter()
.position(|&i| i == 0xFF)
.expect("found 0xFF delim");
roomuserid_joined
.raw_stream()
.ignore_err()
.ready_for_each(|(key, value)| {
let mut key = key.to_vec();
iter_count = iter_count.saturating_add(1);
debug_info!(%iter_count);
let first_sep_index = key
.iter()
.position(|&i| i == 0xFF)
.expect("found 0xFF delim");
if key
.iter()
.get(first_sep_index..=first_sep_index.saturating_add(1))
.copied()
.collect_vec()
== vec![0xFF, 0xFF]
{
debug_warn!("Found bad key: {key:?}");
roomuserid_joined.remove(&key)?;
if key
.iter()
.get(first_sep_index..=first_sep_index.saturating_add(1))
.copied()
.collect_vec()
== vec![0xFF, 0xFF]
{
debug_warn!("Found bad key: {key:?}");
roomuserid_joined.remove(&key);
key.remove(first_sep_index);
debug_warn!("Fixed key: {key:?}");
roomuserid_joined.insert(&key, &value)?;
}
}
key.remove(first_sep_index);
debug_warn!("Fixed key: {key:?}");
roomuserid_joined.insert(&key, value);
}
})
.await;
db.db.cleanup()?;
db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?;
db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[]);
info!("Finished fixing");
Ok(())
@ -795,69 +378,71 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
.rooms
.metadata
.iter_ids()
.filter_map(Result::ok)
.collect_vec();
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
for room_id in room_ids.clone() {
for room_id in &room_ids {
debug_info!("Fixing room {room_id}");
let users_in_room = services
.rooms
.state_cache
.room_members(&room_id)
.filter_map(Result::ok)
.collect_vec();
.room_members(room_id)
.collect::<Vec<_>>()
.await;
let joined_members = users_in_room
.iter()
.stream()
.filter(|user_id| {
services
.rooms
.state_accessor
.get_member(&room_id, user_id)
.unwrap_or(None)
.map_or(false, |membership| membership.membership == MembershipState::Join)
.get_member(room_id, user_id)
.map(|member| member.map_or(false, |member| member.membership == MembershipState::Join))
})
.collect_vec();
.collect::<Vec<_>>()
.await;
let non_joined_members = users_in_room
.iter()
.stream()
.filter(|user_id| {
services
.rooms
.state_accessor
.get_member(&room_id, user_id)
.unwrap_or(None)
.map_or(false, |membership| {
membership.membership == MembershipState::Leave || membership.membership == MembershipState::Ban
})
.get_member(room_id, user_id)
.map(|member| member.map_or(false, |member| member.membership == MembershipState::Join))
})
.collect_vec();
.collect::<Vec<_>>()
.await;
for user_id in joined_members {
debug_info!("User is joined, marking as joined");
services
.rooms
.state_cache
.mark_as_joined(user_id, &room_id)?;
services.rooms.state_cache.mark_as_joined(user_id, room_id);
}
for user_id in non_joined_members {
debug_info!("User is left or banned, marking as left");
services.rooms.state_cache.mark_as_left(user_id, &room_id)?;
services.rooms.state_cache.mark_as_left(user_id, room_id);
}
}
for room_id in room_ids {
for room_id in &room_ids {
debug_info!(
"Updating joined count for room {room_id} to fix servers in room after correcting membership states"
);
services.rooms.state_cache.update_joined_count(&room_id)?;
services
.rooms
.state_cache
.update_joined_count(room_id)
.await;
}
db.db.cleanup()?;
db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?;
db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[]);
info!("Finished fixing");
Ok(())

View file

@ -288,8 +288,8 @@ impl Service {
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let mut keys = self.db.verify_keys_for(origin)?;
pub async fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let mut keys = self.db.verify_keys_for(origin).await?;
if origin == self.server_name() {
keys.insert(
format!("ed25519:{}", self.keypair().version())
@ -304,8 +304,8 @@ impl Service {
Ok(keys)
}
pub fn signing_keys_for(&self, origin: &ServerName) -> Result<Option<ServerSigningKeys>> {
self.db.signing_keys_for(origin)
pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> {
self.db.signing_keys_for(origin).await
}
pub fn well_known_client(&self) -> &Option<Url> { &self.config.well_known.client }

View file

@ -1,346 +0,0 @@
use std::{collections::BTreeMap, sync::Arc};
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{
api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
error::ErrorKind,
},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::{globals, Dep};
pub(super) struct Data {
backupid_algorithm: Arc<Map>,
backupid_etag: Arc<Map>,
backupkeyid_backup: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
/// Binds the three key-backup column maps and the services this data
/// layer depends on (globals, for the monotonic counter).
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
backupid_algorithm: db["backupid_algorithm"].clone(),
backupid_etag: db["backupid_etag"].clone(),
backupkeyid_backup: db["backupkeyid_backup"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
/// Creates a new backup version for `user_id` and returns the version
/// string. The version is the next global counter value; the etag is set
/// from a fresh counter value so clients can detect modification.
/// Raw key layout: `user_id 0xFF version`.
pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = self.services.globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(version)
}
/// Deletes a backup version: its algorithm and etag entries, then every
/// per-session key under the `user_id 0xFF version 0xFF` prefix.
pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?;
// Extend the key into a prefix covering all per-session entries.
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
/// Replaces the algorithm metadata of an existing backup version and
/// bumps its etag. Returns NotFound when the version does not exist.
pub(super) fn update_backup(
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned())
}
/// Returns the user's most recent backup version string, or None when
/// they have no backups. Iterates in reverse (second arg `true`) from the
/// sentinel `user_id 0xFF u64::MAX`, so the first key still inside the
/// user's prefix holds the highest version counter.
pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, _)| {
// The version string is the last 0xFF-separated key segment.
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
})
.transpose()
}
/// Like get_latest_backup_version, but also deserializes and returns the
/// stored algorithm metadata alongside the version string.
pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, value)| {
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
Ok((
version,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
))
})
.transpose()
}
/// Fetches the algorithm metadata for a specific backup version.
pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
})
}
/// Stores (or overwrites) one session's key data under a backup version
/// and bumps the version's etag. Returns NotFound when the version does
/// not exist. Full key: `user 0xFF version 0xFF room 0xFF session`.
pub(super) fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_etag
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
// Extend the etag key into the per-session key.
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes())?;
Ok(())
}
/// Counts the session keys stored under the given backup version.
pub(super) fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
}
/// Returns the backup version's etag (a counter value) as a string.
/// Errors when the etag entry is missing or not a valid u64.
pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes(
&self
.backupid_etag
.get(&key)?
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
)
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
.to_string())
}
/// Collects every backed-up key under a version, grouped by room.
/// Errors if any stored row has a malformed key or value.
pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
// Key segments from the end: session_id, then room_id.
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let room_id = RoomId::parse(
utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((room_id, session_id, key_data))
}) {
let (room_id, session_id, key_data) = result?;
rooms
.entry(room_id)
.or_insert_with(|| RoomKeyBackup {
sessions: BTreeMap::new(),
})
.sessions
.insert(session_id, key_data);
}
Ok(rooms)
}
/// Returns all backed-up keys for one room under a backup version, keyed
/// by session id. Note: `filter_map(Result::ok)` silently skips rows with
/// malformed keys or values rather than erroring.
pub(super) fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
Ok(self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((session_id, key_data))
})
.filter_map(Result::ok)
.collect())
}
/// Fetches the backed-up key data for a single session in a room.
pub(super) fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.get(&key)?
.map(|value| {
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
})
.transpose()
}
/// Removes every session key stored under a backup version (the backup's
/// algorithm/etag entries are left untouched).
pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
/// Removes every session key stored for one room under a backup version.
pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
/// Removes a single session's backed-up key (the full key is used as a
/// scan prefix, so an exact match is removed).
pub(super) fn delete_room_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
}

View file

@ -1,93 +1,319 @@
mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::Result;
use data::Data;
use conduit::{
err, implement, utils,
utils::stream::{ReadyExt, TryIgnore},
Err, Error, Result,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::StreamExt;
use ruma::{
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::{globals, Dep};
pub struct Service {
db: Data,
services: Services,
}
struct Data {
backupid_algorithm: Arc<Map>,
backupid_etag: Arc<Map>,
backupkeyid_backup: Arc<Map>,
}
struct Services {
globals: Dep<globals::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(&args),
db: Data {
backupid_algorithm: args.db["backupid_algorithm"].clone(),
backupid_etag: args.db["backupid_etag"].clone(),
backupkeyid_backup: args.db["backupkeyid_backup"].clone(),
},
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
self.db.create_backup(user_id, backup_metadata)
}
#[implement(Service)]
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = self.services.globals.next_count()?.to_string();
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_backup(user_id, version)
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
pub fn update_backup(
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
self.db.update_backup(user_id, version, backup_metadata)
}
self.db.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
);
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
self.db.get_latest_backup_version(user_id)
}
self.db
.backupid_etag
.insert(&key, &self.services.globals.next_count()?.to_be_bytes());
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
self.db.get_latest_backup(user_id)
}
pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
self.db.get_backup(user_id, version)
}
pub fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
self.db
.add_key(user_id, version, room_id, session_id, key_data)
}
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) }
pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
self.db.get_all(user_id, version)
}
pub fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
self.db.get_room(user_id, version, room_id)
}
pub fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
self.db.get_session(user_id, version, room_id, session_id)
}
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
self.db.delete_all_keys(user_id, version)
}
pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
self.db.delete_room_keys(user_id, version, room_id)
}
pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
self.db
.delete_room_key(user_id, version, room_id, session_id)
}
Ok(version)
}
#[implement(Service)]
/// Deletes a backup version and everything stored under it: the algorithm
/// and etag entries, then every per-session key with the
/// (user_id, version) prefix.
pub async fn delete_backup(&self, user_id: &UserId, version: &str) {
// Raw key layout: user_id 0xFF version
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.db.backupid_algorithm.remove(&key);
self.db.backupid_etag.remove(&key);
// Serialized tuple prefix covering all per-session keys of this version.
let key = (user_id, version, Interfix);
self.db
.backupkeyid_backup
.keys_raw_prefix(&key)
.ignore_err()
.ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key))
.await;
}
#[implement(Service)]
/// Replaces the algorithm metadata of an existing backup version and bumps
/// its etag from the global counter. Errors with NotFound when the version
/// does not exist.
pub async fn update_backup(
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
let key = (user_id, version);
if self.db.backupid_algorithm.qry(&key).await.is_err() {
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
}
// Raw key layout: user_id 0xFF version
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.db
.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes());
self.db
.backupid_etag
.insert(&key, &self.services.globals.next_count()?.to_be_bytes());
Ok(version.to_owned())
}
#[implement(Service)]
/// Returns the user's most recent backup version string. Reverse-scans the
/// backupid_algorithm keys starting from the sentinel
/// `user_id 0xFF u64::MAX`, so the first key still inside the user's prefix
/// holds the highest version counter. Errors with NotFound when the user
/// has no backups.
pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.db
.backupid_algorithm
.rev_raw_keys_from(&last_possible_key)
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.next()
.await
.ok_or_else(|| err!(Request(NotFound("No backup versions found"))))
.and_then(|key| {
// The version string is the final 0xFF-separated key segment.
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
})
}
#[implement(Service)]
/// Returns the user's most recent backup version together with its
/// algorithm metadata. Same reverse scan as get_latest_backup_version, but
/// the stored value is deserialized as well. Errors with NotFound when the
/// user has no backups.
pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw<BackupAlgorithm>)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.db
.backupid_algorithm
.rev_raw_stream_from(&last_possible_key)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.next()
.await
.ok_or_else(|| err!(Request(NotFound("No backup found"))))
.and_then(|(key, val)| {
// The version string is the final 0xFF-separated key segment.
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
let algorithm = serde_json::from_slice(val)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?;
Ok((version, algorithm))
})
}
#[implement(Service)]
/// Retrieves the stored algorithm metadata for one of the user's backup
/// versions; errors when no such backup version exists.
pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<BackupAlgorithm>> {
	let result = self.db.backupid_algorithm.qry(&(user_id, version)).await;
	result.deserialized_json()
}
#[implement(Service)]
/// Stores (or overwrites) one session's key data under a backup version and
/// bumps the version's etag from the global counter. Errors with NotFound
/// when the version does not exist.
pub async fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let key = (user_id, version);
if self.db.backupid_algorithm.qry(&key).await.is_err() {
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
}
// Raw key layout: user_id 0xFF version, later extended with
// 0xFF room_id 0xFF session_id for the per-session entry.
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.db
.backupid_etag
.insert(&key, &self.services.globals.next_count()?.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.db
.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes());
Ok(())
}
#[implement(Service)]
/// Counts the session keys stored under the given backup version.
pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize {
	let keys = self.db.backupkeyid_backup.keys_raw_prefix(&(user_id, version));
	keys.count().await
}
#[implement(Service)]
/// Returns the backup version's etag (a global counter value) rendered as
/// a string.
/// NOTE(review): panics via `expect` when the version has no etag entry or
/// it fails to deserialize as u64 — callers appear expected to have
/// validated the version first; confirm.
pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String {
let key = (user_id, version);
self.db
.backupid_etag
.qry(&key)
.await
.deserialized::<u64>()
.as_ref()
.map(ToString::to_string)
.expect("Backup has no etag.")
}
#[implement(Service)]
/// Collects every backed-up key under a version, grouped by room.
/// NOTE(review): panics via `expect` if a stored KeyBackupData value fails
/// to parse as JSON — corrupt rows abort instead of being skipped; confirm
/// this is intended.
pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap<OwnedRoomId, RoomKeyBackup> {
// Keys deserialize as (user, version, room_id, session_id); the first two
// segments are fixed by the prefix and ignored.
type KeyVal<'a> = ((Ignore, Ignore, &'a RoomId, &'a str), &'a [u8]);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
let default = || RoomKeyBackup {
sessions: BTreeMap::new(),
};
let prefix = (user_id, version, Interfix);
self.db
.backupkeyid_backup
.stream_prefix(&prefix)
.ignore_err()
.ready_for_each(|((_, _, room_id, session_id), value): KeyVal<'_>| {
let key_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON");
rooms
.entry(room_id.into())
.or_insert_with(default)
.sessions
.insert(session_id.into(), key_data);
})
.await;
rooms
}
#[implement(Service)]
/// Returns all backed-up keys for one room under a backup version, keyed
/// by session id.
/// NOTE(review): panics via `expect` if a stored KeyBackupData value fails
/// to parse as JSON; confirm this is intended.
pub async fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> BTreeMap<String, Raw<KeyBackupData>> {
// Only the trailing session_id segment of each key is significant here.
type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), &'a [u8]);
let prefix = (user_id, version, room_id, Interfix);
self.db
.backupkeyid_backup
.stream_prefix(&prefix)
.ignore_err()
.map(|((.., session_id), value): KeyVal<'_>| {
let session_id = session_id.to_owned();
let key_backup_data = serde_json::from_slice(value).expect("Invalid KeyBackupData JSON");
(session_id, key_backup_data)
})
.collect()
.await
}
#[implement(Service)]
/// Fetches the backed-up key data for a single session in a room; errors
/// when no such entry exists.
pub async fn get_session(
	&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Raw<KeyBackupData>> {
	let full_key = (user_id, version, room_id, session_id);
	let result = self.db.backupkeyid_backup.qry(&full_key).await;
	result.deserialized_json()
}
#[implement(Service)]
/// Purges every session key stored under the given backup version; the
/// version's algorithm/etag entries are left untouched.
pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) {
	let map = &self.db.backupkeyid_backup;
	map.keys_raw_prefix(&(user_id, version, Interfix))
		.ignore_err()
		.ready_for_each(|key| map.remove(key))
		.await;
}
#[implement(Service)]
/// Purges every session key stored for one room under the given backup
/// version.
pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) {
	let map = &self.db.backupkeyid_backup;
	map.keys_raw_prefix(&(user_id, version, room_id, Interfix))
		.ignore_err()
		.ready_for_each(|key| map.remove(key))
		.await;
}
#[implement(Service)]
/// Removes a single session's backed-up key for a room under the given
/// backup version (the full key tuple is used as the scan prefix).
pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) {
	let map = &self.db.backupkeyid_backup;
	map.keys_raw_prefix(&(user_id, version, room_id, session_id))
		.ignore_err()
		.ready_for_each(|key| map.remove(key))
		.await;
}

View file

@ -1,7 +1,7 @@
use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration};
use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server};
use futures_util::FutureExt;
use futures::FutureExt;
use tokio::{
sync::{Mutex, MutexGuard},
task::{JoinHandle, JoinSet},

View file

@ -2,10 +2,11 @@ use std::sync::Arc;
use conduit::{
debug, debug_info, trace,
utils::{str_from_bytes, string_from_bytes},
utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt},
Err, Error, Result,
};
use database::{Database, Map};
use futures::StreamExt;
use ruma::{api::client::error::ErrorKind, http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId};
use super::{preview::UrlPreviewData, thumbnail::Dim};
@ -59,7 +60,7 @@ impl Data {
.unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?;
self.mediaid_file.insert(&key, &[]);
if let Some(user) = user {
let mut key: Vec<u8> = Vec::new();
@ -68,13 +69,13 @@ impl Data {
key.extend_from_slice(b"/");
key.extend_from_slice(mxc.media_id.as_bytes());
let user = user.as_bytes().to_vec();
self.mediaid_user.insert(&key, &user)?;
self.mediaid_user.insert(&key, &user);
}
Ok(key)
}
pub(super) fn delete_file_mxc(&self, mxc: &Mxc<'_>) -> Result<()> {
pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) {
debug!("MXC URI: {mxc}");
let mut prefix: Vec<u8> = Vec::new();
@ -85,25 +86,31 @@ impl Data {
prefix.push(0xFF);
trace!("MXC db prefix: {prefix:?}");
for (key, _) in self.mediaid_file.scan_prefix(prefix.clone()) {
debug!("Deleting key: {:?}", key);
self.mediaid_file.remove(&key)?;
}
self.mediaid_file
.raw_keys_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| {
debug!("Deleting key: {:?}", key);
self.mediaid_file.remove(key);
})
.await;
for (key, value) in self.mediaid_user.scan_prefix(prefix.clone()) {
if key.starts_with(&prefix) {
let user = str_from_bytes(&value).unwrap_or_default();
self.mediaid_user
.raw_stream_prefix(&prefix)
.ignore_err()
.ready_for_each(|(key, val)| {
if key.starts_with(&prefix) {
let user = str_from_bytes(val).unwrap_or_default();
debug_info!("Deleting key {key:?} which was uploaded by user {user}");
debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}");
self.mediaid_user.remove(&key)?;
}
}
Ok(())
self.mediaid_user.remove(key);
}
})
.await;
}
/// Searches for all files with the given MXC
pub(super) fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result<Vec<Vec<u8>>> {
pub(super) async fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result<Vec<Vec<u8>>> {
debug!("MXC URI: {mxc}");
let mut prefix: Vec<u8> = Vec::new();
@ -115,9 +122,10 @@ impl Data {
let keys: Vec<Vec<u8>> = self
.mediaid_file
.scan_prefix(prefix)
.map(|(key, _)| key)
.collect();
.keys_prefix_raw(&prefix)
.ignore_err()
.collect()
.await;
if keys.is_empty() {
return Err!(Database("Failed to find any keys in database for `{mxc}`",));
@ -128,7 +136,7 @@ impl Data {
Ok(keys)
}
pub(super) fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Metadata> {
pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Metadata> {
let mut prefix: Vec<u8> = Vec::new();
prefix.extend_from_slice(b"mxc://");
prefix.extend_from_slice(mxc.server_name.as_bytes());
@ -139,10 +147,13 @@ impl Data {
prefix.extend_from_slice(&dim.height.to_be_bytes());
prefix.push(0xFF);
let (key, _) = self
let key = self
.mediaid_file
.scan_prefix(prefix)
.raw_keys_prefix(&prefix)
.ignore_err()
.map(ToOwned::to_owned)
.next()
.await
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xFF);
@ -177,28 +188,31 @@ impl Data {
}
/// Gets all the MXCs associated with a user
pub(super) fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec<OwnedMxcUri> {
let user_id = user_id.as_bytes().to_vec();
pub(super) async fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec<OwnedMxcUri> {
self.mediaid_user
.iter()
.filter_map(|(key, user)| {
if *user == user_id {
let mxc_s = string_from_bytes(&key).ok()?;
Some(OwnedMxcUri::from(mxc_s))
} else {
None
}
})
.stream()
.ignore_err()
.ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into()))
.collect()
.await
}
/// Gets all the media keys in our database (this includes all the metadata
/// associated with it such as width, height, content-type, etc)
pub(crate) fn get_all_media_keys(&self) -> Vec<Vec<u8>> { self.mediaid_file.iter().map(|(key, _)| key).collect() }
pub(crate) async fn get_all_media_keys(&self) -> Vec<Vec<u8>> {
self.mediaid_file
.raw_keys()
.ignore_err()
.map(<[u8]>::to_vec)
.collect()
.await
}
#[inline]
pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> {
self.url_previews.remove(url.as_bytes());
Ok(())
}
pub(super) fn set_url_preview(
&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration,
@ -233,11 +247,13 @@ impl Data {
value.push(0xFF);
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
self.url_previews.insert(url.as_bytes(), &value)
self.url_previews.insert(url.as_bytes(), &value);
Ok(())
}
pub(super) fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
let values = self.url_previews.get(url.as_bytes()).ok()??;
pub(super) async fn get_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
let values = self.url_previews.qry(url).await?;
let mut values = values.split(|&b| b == 0xFF);
@ -291,7 +307,7 @@ impl Data {
x => x,
};
Some(UrlPreviewData {
Ok(UrlPreviewData {
title,
description,
image,

View file

@ -7,7 +7,11 @@ use std::{
time::Instant,
};
use conduit::{debug, debug_info, debug_warn, error, info, warn, Config, Result};
use conduit::{
debug, debug_info, debug_warn, error, info,
utils::{stream::TryIgnore, ReadyExt},
warn, Config, Result,
};
use crate::{globals, Services};
@ -23,12 +27,17 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> {
// Move old media files to new names
let mut changes = Vec::<(PathBuf, PathBuf)>::new();
for (key, _) in mediaid_file.iter() {
let old = services.media.get_media_file_b64(&key);
let new = services.media.get_media_file_sha256(&key);
debug!(?key, ?old, ?new, num = changes.len(), "change");
changes.push((old, new));
}
mediaid_file
.raw_keys()
.ignore_err()
.ready_for_each(|key| {
let old = services.media.get_media_file_b64(key);
let new = services.media.get_media_file_sha256(key);
debug!(?key, ?old, ?new, num = changes.len(), "change");
changes.push((old, new));
})
.await;
// move the file to the new location
for (old_path, path) in changes {
if old_path.exists() {
@ -41,11 +50,11 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> {
// Apply fix from when sha256_media was backward-incompat and bumped the schema
// version from 13 to 14. For users satisfying these conditions we can go back.
if services.globals.db.database_version()? == 14 && globals::migrations::DATABASE_VERSION == 13 {
if services.globals.db.database_version().await == 14 && globals::migrations::DATABASE_VERSION == 13 {
services.globals.db.bump_database_version(13)?;
}
db["global"].insert(b"feat_sha256_media", &[])?;
db["global"].insert(b"feat_sha256_media", &[]);
info!("Finished applying sha256_media");
Ok(())
}
@ -71,7 +80,7 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> {
.filter_map(|ent| ent.map_or(None, |ent| Some(ent.path().into_os_string())))
.collect();
for key in media.db.get_all_media_keys() {
for key in media.db.get_all_media_keys().await {
let new_path = media.get_media_file_sha256(&key).into_os_string();
let old_path = media.get_media_file_b64(&key).into_os_string();
if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await {
@ -112,8 +121,8 @@ async fn handle_media_check(
"Media is missing at all paths. Removing from database..."
);
mediaid_file.remove(key)?;
mediaid_user.remove(key)?;
mediaid_file.remove(key);
mediaid_user.remove(key);
}
if config.media_compat_file_link && !old_exists && new_exists {

View file

@ -97,7 +97,7 @@ impl Service {
/// Deletes a file in the database and from the media directory via an MXC
pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> {
if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc) {
if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc).await {
for key in keys {
trace!(?mxc, "MXC Key: {key:?}");
debug_info!(?mxc, "Deleting from filesystem");
@ -107,7 +107,7 @@ impl Service {
}
debug_info!(?mxc, "Deleting from database");
_ = self.db.delete_file_mxc(mxc);
self.db.delete_file_mxc(mxc).await;
}
Ok(())
@ -120,7 +120,7 @@ impl Service {
///
/// currently, this is only practical for local users
pub async fn delete_from_user(&self, user: &UserId) -> Result<usize> {
let mxcs = self.db.get_all_user_mxcs(user);
let mxcs = self.db.get_all_user_mxcs(user).await;
let mut deletion_count: usize = 0;
for mxc in mxcs {
@ -150,7 +150,7 @@ impl Service {
content_disposition,
content_type,
key,
}) = self.db.search_file_metadata(mxc, &Dim::default())
}) = self.db.search_file_metadata(mxc, &Dim::default()).await
{
let mut content = Vec::new();
let path = self.get_media_file(&key);
@ -170,7 +170,7 @@ impl Service {
/// Gets all the MXC URIs in our media database
pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
let all_keys = self.db.get_all_media_keys();
let all_keys = self.db.get_all_media_keys().await;
let mut mxcs = Vec::with_capacity(all_keys.len());
@ -209,7 +209,7 @@ impl Service {
pub async fn delete_all_remote_media_at_after_time(
&self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool,
) -> Result<usize> {
let all_keys = self.db.get_all_media_keys();
let all_keys = self.db.get_all_media_keys().await;
let mut remote_mxcs = Vec::with_capacity(all_keys.len());
for key in all_keys {
@ -343,9 +343,10 @@ impl Service {
}
#[inline]
pub fn get_metadata(&self, mxc: &Mxc<'_>) -> Option<FileMeta> {
pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option<FileMeta> {
self.db
.search_file_metadata(mxc, &Dim::default())
.await
.map(|metadata| FileMeta {
content_disposition: metadata.content_disposition,
content_type: metadata.content_type,

View file

@ -71,16 +71,16 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
#[implement(Service)]
pub async fn get_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
if let Some(preview) = self.db.get_url_preview(url) {
if let Ok(preview) = self.db.get_url_preview(url).await {
return Ok(preview);
}
// ensure that only one request is made per URL
let _request_lock = self.url_preview_mutex.lock(url).await;
match self.db.get_url_preview(url) {
Some(preview) => Ok(preview),
None => self.request_url_preview(url).await,
match self.db.get_url_preview(url).await {
Ok(preview) => Ok(preview),
Err(_) => self.request_url_preview(url).await,
}
}

View file

@ -54,9 +54,9 @@ impl super::Service {
// 0, 0 because that's the original file
let dim = dim.normalized();
if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim) {
if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim).await {
self.get_thumbnail_saved(metadata).await
} else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()) {
} else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()).await {
self.get_thumbnail_generate(mxc, &dim, metadata).await
} else {
Ok(None)

View file

@ -19,6 +19,7 @@ pub mod resolver;
pub mod rooms;
pub mod sending;
pub mod server_keys;
pub mod sync;
pub mod transaction_ids;
pub mod uiaa;
pub mod updates;

View file

@ -1,7 +1,12 @@
use std::sync::Arc;
use conduit::{debug_warn, utils, Error, Result};
use database::Map;
use conduit::{
debug_warn, utils,
utils::{stream::TryIgnore, ReadyExt},
Result,
};
use database::{Deserialized, Map};
use futures::Stream;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use super::Presence;
@ -31,39 +36,35 @@ impl Data {
}
}
pub fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
pub async fn get_presence(&self, user_id: &UserId) -> Result<(u64, PresenceEvent)> {
let count = self
.userid_presenceid
.qry(user_id)
.await
.deserialized::<u64>()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((
count,
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?,
))
})
.transpose()
} else {
Ok(None)
}
let key = presenceid_key(count, user_id);
let bytes = self.presenceid_presence.qry(&key).await?;
let event = Presence::from_json_bytes(&bytes)?
.to_presence_event(user_id, &self.services.users)
.await;
Ok((count, event))
}
pub(super) fn set_presence(
pub(super) async fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> {
let last_presence = self.get_presence(user_id)?;
let last_presence = self.get_presence(user_id).await;
let state_changed = match last_presence {
None => true,
Some(ref presence) => presence.1.content.presence != *presence_state,
Err(_) => true,
Ok(ref presence) => presence.1.content.presence != *presence_state,
};
let status_msg_changed = match last_presence {
None => true,
Some(ref last_presence) => {
Err(_) => true,
Ok(ref last_presence) => {
let old_msg = last_presence
.1
.content
@ -79,8 +80,8 @@ impl Data {
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
None => 0,
Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
Err(_) => 0,
Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
};
let last_active_ts = match last_active_ago {
@ -90,12 +91,7 @@ impl Data {
// TODO: tighten for state flicker?
if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts {
debug_warn!(
"presence spam {:?} last_active_ts:{:?} < {:?}",
user_id,
last_active_ts,
last_last_active_ts
);
debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",);
return Ok(());
}
@ -115,41 +111,42 @@ impl Data {
let key = presenceid_key(count, user_id);
self.presenceid_presence
.insert(&key, &presence.to_json_bytes()?)?;
.insert(&key, &presence.to_json_bytes()?);
self.userid_presenceid
.insert(user_id.as_bytes(), &count.to_be_bytes())?;
.insert(user_id.as_bytes(), &count.to_be_bytes());
if let Some((last_count, _)) = last_presence {
if let Ok((last_count, _)) = last_presence {
let key = presenceid_key(last_count, user_id);
self.presenceid_presence.remove(&key)?;
self.presenceid_presence.remove(&key);
}
Ok(())
}
pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence.remove(&key)?;
self.userid_presenceid.remove(user_id.as_bytes())?;
}
pub(super) async fn remove_presence(&self, user_id: &UserId) {
let Ok(count) = self
.userid_presenceid
.qry(user_id)
.await
.deserialized::<u64>()
else {
return;
};
Ok(())
let key = presenceid_key(count, user_id);
self.presenceid_presence.remove(&key);
self.userid_presenceid.remove(user_id.as_bytes());
}
pub fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> {
Box::new(
self.presenceid_presence
.iter()
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec<u8>)> {
let (count, user_id) = presenceid_parse(&key)?;
Ok((user_id.to_owned(), count, presence_bytes))
})
.filter(move |(_, count, _)| *count > since),
)
pub fn presence_since(&self, since: u64) -> impl Stream<Item = (OwnedUserId, u64, Vec<u8>)> + Send + '_ {
self.presenceid_presence
.raw_stream()
.ignore_err()
.ready_filter_map(move |(key, presence_bytes)| {
let (count, user_id) = presenceid_parse(key).expect("invalid presenceid_parse");
(count > since).then(|| (user_id.to_owned(), count, presence_bytes.to_vec()))
})
}
}
@ -162,7 +159,7 @@ fn presenceid_key(count: u64, user_id: &UserId) -> Vec<u8> {
fn presenceid_parse(key: &[u8]) -> Result<(u64, &UserId)> {
let (count, user_id) = key.split_at(8);
let user_id = user_id_from_bytes(user_id)?;
let count = utils::u64_from_bytes(count).unwrap();
let count = utils::u64_from_u8(count);
Ok((count, user_id))
}

View file

@ -4,8 +4,8 @@ mod presence;
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use conduit::{checked, debug, error, Error, Result, Server};
use futures_util::{stream::FuturesUnordered, StreamExt};
use conduit::{checked, debug, error, result::LogErr, Error, Result, Server};
use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt};
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use tokio::{sync::Mutex, time::sleep};
@ -58,7 +58,9 @@ impl crate::Service for Service {
loop {
debug_assert!(!receiver.is_closed(), "channel error");
tokio::select! {
Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?,
Some(user_id) = presence_timers.next() => {
self.process_presence_timer(&user_id).await.log_err().ok();
},
event = receiver.recv_async() => match event {
Err(_e) => return Ok(()),
Ok((user_id, timeout)) => {
@ -82,28 +84,27 @@ impl crate::Service for Service {
impl Service {
/// Returns the latest presence event for the given user.
#[inline]
pub fn get_presence(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> {
if let Some((_, presence)) = self.db.get_presence(user_id)? {
Ok(Some(presence))
} else {
Ok(None)
}
pub async fn get_presence(&self, user_id: &UserId) -> Result<PresenceEvent> {
self.db
.get_presence(user_id)
.map_ok(|(_, presence)| presence)
.await
}
/// Pings the presence of the given user in the given room, setting the
/// specified state.
pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> {
pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> {
const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000;
let last_presence = self.db.get_presence(user_id)?;
let last_presence = self.db.get_presence(user_id).await;
let state_changed = match last_presence {
None => true,
Some((_, ref presence)) => presence.content.presence != *new_state,
Err(_) => true,
Ok((_, ref presence)) => presence.content.presence != *new_state,
};
let last_last_active_ago = match last_presence {
None => 0_u64,
Some((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(),
Err(_) => 0_u64,
Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(),
};
if !state_changed && last_last_active_ago < REFRESH_TIMEOUT {
@ -111,17 +112,18 @@ impl Service {
}
let status_msg = match last_presence {
Some((_, ref presence)) => presence.content.status_msg.clone(),
None => Some(String::new()),
Ok((_, ref presence)) => presence.content.status_msg.clone(),
Err(_) => Some(String::new()),
};
let last_active_ago = UInt::new(0);
let currently_active = *new_state == PresenceState::Online;
self.set_presence(user_id, new_state, Some(currently_active), last_active_ago, status_msg)
.await
}
/// Adds a presence event which will be saved until a new event replaces it.
pub fn set_presence(
pub async fn set_presence(
&self, user_id: &UserId, state: &PresenceState, currently_active: Option<bool>, last_active_ago: Option<UInt>,
status_msg: Option<String>,
) -> Result<()> {
@ -131,7 +133,8 @@ impl Service {
};
self.db
.set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?;
.set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)
.await?;
if self.timeout_remote_users || self.services.globals.user_is_local(user_id) {
let timeout = match presence_state {
@ -154,28 +157,33 @@ impl Service {
///
/// TODO: Why is this not used?
#[allow(dead_code)]
pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) }
pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await }
/// Returns the most recent presence updates that happened after the event
/// with id `since`.
#[inline]
pub fn presence_since(&self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + '_> {
pub fn presence_since(&self, since: u64) -> impl Stream<Item = (OwnedUserId, u64, Vec<u8>)> + Send + '_ {
self.db.presence_since(since)
}
pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
#[inline]
pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
let presence = Presence::from_json_bytes(bytes)?;
presence.to_presence_event(user_id, &self.services.users)
let event = presence
.to_presence_event(user_id, &self.services.users)
.await;
Ok(event)
}
fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> {
async fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> {
let mut presence_state = PresenceState::Offline;
let mut last_active_ago = None;
let mut status_msg = None;
let presence_event = self.get_presence(user_id)?;
let presence_event = self.get_presence(user_id).await;
if let Some(presence_event) = presence_event {
if let Ok(presence_event) = presence_event {
presence_state = presence_event.content.presence;
last_active_ago = presence_event.content.last_active_ago;
status_msg = presence_event.content.status_msg;
@ -192,7 +200,8 @@ impl Service {
);
if let Some(new_state) = new_state {
self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?;
self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)
.await?;
}
Ok(())

View file

@ -1,5 +1,3 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use ruma::{
events::presence::{PresenceEvent, PresenceEventContent},
@ -42,7 +40,7 @@ impl Presence {
}
/// Creates a PresenceEvent from available data.
pub(super) fn to_presence_event(&self, user_id: &UserId, users: &Arc<users::Service>) -> Result<PresenceEvent> {
pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent {
let now = utils::millis_since_unix_epoch();
let last_active_ago = if self.currently_active {
None
@ -50,16 +48,16 @@ impl Presence {
Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts)))
};
Ok(PresenceEvent {
PresenceEvent {
sender: user_id.to_owned(),
content: PresenceEventContent {
presence: self.state.clone(),
status_msg: self.status_msg.clone(),
currently_active: Some(self.currently_active),
last_active_ago,
displayname: users.displayname(user_id)?,
avatar_url: users.avatar_url(user_id)?,
displayname: users.displayname(user_id).await.ok(),
avatar_url: users.avatar_url(user_id).await.ok(),
},
})
}
}
}

View file

@ -1,77 +0,0 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use ruma::{
api::client::push::{set_pusher, Pusher},
UserId,
};
/// Storage for per-user push gateway registrations ("pushers").
///
/// Keys in `senderkey_pusher` are `<user_id> 0xFF <pushkey>`; values are the
/// pusher action serialized as JSON.
pub(super) struct Data {
    senderkey_pusher: Arc<Map>,
}

impl Data {
    /// Binds this accessor to the `senderkey_pusher` column of the database.
    pub(super) fn new(db: &Arc<Database>) -> Self {
        Self {
            senderkey_pusher: db["senderkey_pusher"].clone(),
        }
    }

    /// Creates/replaces (`Post`) or deletes (`Delete`) the pusher identified
    /// by `sender` plus the pushkey carried in the action.
    pub(super) fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> {
        match pusher {
            set_pusher::v3::PusherAction::Post(data) => {
                // Key layout: <user_id> 0xFF <pushkey>.
                let mut key = sender.as_bytes().to_vec();
                key.push(0xFF);
                key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
                self.senderkey_pusher
                    .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value"))?;
                Ok(())
            },
            set_pusher::v3::PusherAction::Delete(ids) => {
                let mut key = sender.as_bytes().to_vec();
                key.push(0xFF);
                key.extend_from_slice(ids.pushkey.as_bytes());
                self.senderkey_pusher.remove(&key).map_err(Into::into)
            },
        }
    }

    /// Returns the pusher registered under `sender`/`pushkey`, `Ok(None)` when
    /// absent, or a database error when the stored JSON cannot be parsed.
    pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
        let mut senderkey = sender.as_bytes().to_vec();
        senderkey.push(0xFF);
        senderkey.extend_from_slice(pushkey.as_bytes());
        self.senderkey_pusher
            .get(&senderkey)?
            .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
            .transpose()
    }

    /// Collects all pushers registered by `sender`; fails on the first entry
    /// that does not deserialize.
    pub(super) fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
        let mut prefix = sender.as_bytes().to_vec();
        prefix.push(0xFF);
        self.senderkey_pusher
            .scan_prefix(prefix)
            .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
            .collect()
    }

    /// Iterates the pushkeys registered by `sender`, parsed from the
    /// `<user_id> 0xFF <pushkey>` key layout.
    pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
        let mut prefix = sender.as_bytes().to_vec();
        prefix.push(0xFF);
        Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
            // Split at the first 0xFF separator; everything after it is the pushkey.
            let mut parts = k.splitn(2, |&b| b == 0xFF);
            let _senderkey = parts.next();
            let push_key = parts
                .next()
                .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
            let push_key_string = utils::string_from_bytes(push_key)
                .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
            Ok(push_key_string)
        }))
    }
}

View file

@ -1,9 +1,13 @@
mod data;
use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut;
use conduit::{debug_error, err, trace, utils::string_from_bytes, warn, Err, PduEvent, Result};
use conduit::{
debug_error, err, trace,
utils::{stream::TryIgnore, string_from_bytes},
Err, PduEvent, Result,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{Stream, StreamExt};
use ipaddress::IPAddress;
use ruma::{
api::{
@ -22,12 +26,11 @@ use ruma::{
uint, RoomId, UInt, UserId,
};
use self::data::Data;
use crate::{client, globals, rooms, users, Dep};
pub struct Service {
services: Services,
db: Data,
services: Services,
}
struct Services {
@ -38,9 +41,16 @@ struct Services {
users: Dep<users::Service>,
}
struct Data {
senderkey_pusher: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
senderkey_pusher: args.db["senderkey_pusher"].clone(),
},
services: Services {
globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"),
@ -48,7 +58,6 @@ impl crate::Service for Service {
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(args.db),
}))
}
@ -56,19 +65,52 @@ impl crate::Service for Service {
}
impl Service {
pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> {
self.db.set_pusher(sender, pusher)
pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) {
match pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.db
.senderkey_pusher
.insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value"));
},
set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes());
self.db.senderkey_pusher.remove(&key);
},
}
}
pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
self.db.get_pusher(sender, pushkey)
pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Pusher> {
let senderkey = (sender, pushkey);
self.db
.senderkey_pusher
.qry(&senderkey)
.await
.deserialized_json()
}
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }
pub async fn get_pushers(&self, sender: &UserId) -> Vec<Pusher> {
let prefix = (sender, Interfix);
self.db
.senderkey_pusher
.stream_prefix(&prefix)
.ignore_err()
.map(|(_, val): (Ignore, &[u8])| serde_json::from_slice(val).expect("Invalid Pusher in db."))
.collect()
.await
}
#[must_use]
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + '_> {
self.db.get_pushkeys(sender)
pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream<Item = &str> + Send + 'a {
let prefix = (sender, Interfix);
self.db
.senderkey_pusher
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, pushkey): (Ignore, &str)| pushkey)
}
#[tracing::instrument(skip(self, dest, request))]
@ -161,15 +203,18 @@ impl Service {
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| {
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")
.await
.and_then(|ev| {
serde_json::from_str(ev.content.get())
.map_err(|e| err!(Database("invalid m.room.power_levels event: {e:?}")))
.map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}"))))
})
.transpose()?
.unwrap_or_default();
for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? {
for action in self
.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)
.await?
{
let n = match action {
Action::Notify => true,
Action::SetTweak(tweak) => {
@ -197,7 +242,7 @@ impl Service {
}
#[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")]
pub fn get_actions<'a>(
pub async fn get_actions<'a>(
&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
) -> Result<&'a [Action]> {
@ -207,21 +252,27 @@ impl Service {
notifications: power_levels.notifications.clone(),
};
let room_joined_count = self
.services
.state_cache
.room_joined_count(room_id)
.await
.unwrap_or(1)
.try_into()
.unwrap_or_else(|_| uint!(0));
let user_display_name = self
.services
.users
.displayname(user)
.await
.unwrap_or_else(|_| user.localpart().to_owned());
let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(),
member_count: UInt::try_from(
self.services
.state_cache
.room_joined_count(room_id)?
.unwrap_or(1),
)
.unwrap_or_else(|_| uint!(0)),
member_count: room_joined_count,
user_id: user.to_owned(),
user_display_name: self
.services
.users
.displayname(user)?
.unwrap_or_else(|| user.localpart().to_owned()),
user_display_name,
power_levels: Some(power_levels),
};
@ -278,9 +329,14 @@ impl Service {
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
}
notifi.sender_display_name = self.services.users.displayname(&event.sender)?;
notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok();
notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?;
notifi.room_name = self
.services
.state_accessor
.get_name(&event.room_id)
.await
.ok();
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;

View file

@ -193,7 +193,7 @@ impl super::Service {
.send()
.await;
trace!("response: {:?}", response);
trace!("response: {response:?}");
if let Err(e) = &response {
debug!("error: {e:?}");
return Ok(None);
@ -206,7 +206,7 @@ impl super::Service {
}
let text = response.text().await?;
trace!("response text: {:?}", text);
trace!("response text: {text:?}");
if text.len() >= 12288 {
debug_warn!("response contains junk");
return Ok(None);
@ -225,7 +225,7 @@ impl super::Service {
return Ok(None);
}
debug_info!("{:?} found at {:?}", dest, m_server);
debug_info!("{dest:?} found at {m_server:?}");
Ok(Some(m_server.to_owned()))
}

View file

@ -1,125 +0,0 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId};
use crate::{globals, Dep};
/// Storage for room alias mappings.
///
/// Three column families are kept in sync:
/// - `alias_userid`:  alias localpart -> user id of the alias creator
/// - `alias_roomid`:  alias localpart -> room id the alias points at
/// - `aliasid_alias`: `<room_id> 0xFF <count>` -> full alias, enabling a
///   prefix scan of every alias belonging to one room
pub(super) struct Data {
    alias_userid: Arc<Map>,
    alias_roomid: Arc<Map>,
    aliasid_alias: Arc<Map>,
    services: Services,
}

/// Service dependencies used by this data accessor.
struct Services {
    // Supplies next_count() for minting unique aliasid_alias keys.
    globals: Dep<globals::Service>,
}

impl Data {
    /// Binds the three alias column families and the globals dependency.
    pub(super) fn new(args: &crate::Args<'_>) -> Self {
        let db = &args.db;
        Self {
            alias_userid: db["alias_userid"].clone(),
            alias_roomid: db["alias_roomid"].clone(),
            aliasid_alias: db["aliasid_alias"].clone(),
            services: Services {
                globals: args.depend::<globals::Service>("globals"),
            },
        }
    }

    /// Records `alias` -> `room_id`, attributed to `user_id`.
    pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
        // Comes first as we don't want a stuck alias
        self.alias_userid
            .insert(alias.alias().as_bytes(), user_id.as_bytes())?;
        self.alias_roomid
            .insert(alias.alias().as_bytes(), room_id.as_bytes())?;
        // Reverse mapping keyed by room id plus a fresh global counter, so
        // several aliases for the same room occupy distinct keys.
        let mut aliasid = room_id.as_bytes().to_vec();
        aliasid.push(0xFF);
        aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
        self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
        Ok(())
    }

    /// Deletes `alias` and its reverse-mapping entries; returns NotFound when
    /// the alias is not registered.
    pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
        if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
            // Drop every aliasid_alias entry under this room's prefix.
            let mut prefix = room_id.to_vec();
            prefix.push(0xFF);
            for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
                self.aliasid_alias.remove(&key)?;
            }
            self.alias_roomid.remove(alias.alias().as_bytes())?;
            self.alias_userid.remove(alias.alias().as_bytes())?;
        } else {
            return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist or is invalid."));
        }
        Ok(())
    }

    /// Looks up the room id an alias points at, if any.
    pub(super) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
        self.alias_roomid
            .get(alias.alias().as_bytes())?
            .map(|bytes| {
                RoomId::parse(
                    utils::string_from_bytes(&bytes)
                        .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
                )
                .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
            })
            .transpose()
    }

    /// Looks up which user created an alias, if recorded.
    pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>> {
        self.alias_userid
            .get(alias.alias().as_bytes())?
            .map(|bytes| {
                UserId::parse(
                    utils::string_from_bytes(&bytes)
                        .map_err(|_| Error::bad_database("User ID in alias_userid is invalid unicode."))?,
                )
                .map_err(|_| Error::bad_database("User ID in alias_roomid is invalid."))
            })
            .transpose()
    }

    /// Iterates all aliases registered for `room_id` via the `aliasid_alias`
    /// prefix scan.
    pub(super) fn local_aliases_for_room<'a>(
        &'a self, room_id: &RoomId,
    ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a + Send> {
        let mut prefix = room_id.as_bytes().to_vec();
        prefix.push(0xFF);
        Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
            utils::string_from_bytes(&bytes)
                .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
                .try_into()
                .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
        }))
    }

    /// Iterates every (room id, alias localpart) pair known to this server.
    pub(super) fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
        Box::new(
            self.alias_roomid
                .iter()
                .map(|(room_alias_bytes, room_id_bytes)| {
                    let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
                        .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
                    let room_id = utils::string_from_bytes(&room_id_bytes)
                        .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
                        .try_into()
                        .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
                    Ok((room_id, room_alias_localpart))
                }),
        )
    }
}

View file

@ -1,19 +1,23 @@
mod data;
mod remote;
use std::sync::Arc;
use conduit::{err, Error, Result};
use conduit::{
err,
utils::{stream::TryIgnore, ReadyExt},
Err, Error, Result,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{Stream, StreamExt};
use ruma::{
api::client::error::ErrorKind,
events::{
room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent},
StateEventType,
},
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId,
};
use self::data::Data;
use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep};
pub struct Service {
@ -21,6 +25,12 @@ pub struct Service {
services: Services,
}
struct Data {
alias_userid: Arc<Map>,
alias_roomid: Arc<Map>,
aliasid_alias: Arc<Map>,
}
struct Services {
admin: Dep<admin::Service>,
appservice: Dep<appservice::Service>,
@ -32,7 +42,11 @@ struct Services {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(&args),
db: Data {
alias_userid: args.db["alias_userid"].clone(),
alias_roomid: args.db["alias_roomid"].clone(),
aliasid_alias: args.db["aliasid_alias"].clone(),
},
services: Services {
admin: args.depend::<admin::Service>("admin"),
appservice: args.depend::<appservice::Service>("appservice"),
@ -50,25 +64,52 @@ impl Service {
#[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
Err(Error::BadRequest(
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Only the server user can set this alias",
))
} else {
self.db.set_alias(alias, room_id, user_id)
));
}
// Comes first as we don't want a stuck alias
self.db
.alias_userid
.insert(alias.alias().as_bytes(), user_id.as_bytes());
self.db
.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes());
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
self.db.aliasid_alias.insert(&aliasid, alias.as_bytes());
Ok(())
}
#[tracing::instrument(skip(self))]
pub async fn remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<()> {
if self.user_can_remove_alias(alias, user_id).await? {
self.db.remove_alias(alias)
} else {
Err(Error::BadRequest(
ErrorKind::forbidden(),
"User is not permitted to remove this alias.",
))
if !self.user_can_remove_alias(alias, user_id).await? {
return Err!(Request(Forbidden("User is not permitted to remove this alias.")));
}
let alias = alias.alias();
let Ok(room_id) = self.db.alias_roomid.qry(&alias).await else {
return Err!(Request(NotFound("Alias does not exist or is invalid.")));
};
let prefix = (&room_id, Interfix);
self.db
.aliasid_alias
.keys_prefix(&prefix)
.ignore_err()
.ready_for_each(|key: &[u8]| self.db.aliasid_alias.remove(&key))
.await;
self.db.alias_roomid.remove(alias.as_bytes());
self.db.alias_userid.remove(alias.as_bytes());
Ok(())
}
pub async fn resolve(&self, room: &RoomOrAliasId) -> Result<OwnedRoomId> {
@ -97,9 +138,9 @@ impl Service {
return self.remote_resolve(room_alias, servers).await;
}
let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? {
Some(r) => Some(r),
None => self.resolve_appservice_alias(room_alias).await?,
let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias).await {
Ok(r) => Some(r),
Err(_) => self.resolve_appservice_alias(room_alias).await?,
};
room_id.map_or_else(
@ -109,46 +150,54 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.db.resolve_local_alias(alias)
pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<OwnedRoomId> {
self.db.alias_roomid.qry(alias.alias()).await.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a + Send> {
self.db.local_aliases_for_room(room_id)
pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.aliasid_alias
.stream_prefix(&prefix)
.ignore_err()
.map(|((Ignore, Ignore), alias): ((Ignore, Ignore), &RoomAliasId)| alias)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
self.db.all_local_aliases()
pub fn all_local_aliases<'a>(&'a self) -> impl Stream<Item = (&RoomId, &str)> + Send + 'a {
self.db
.alias_roomid
.stream()
.ignore_err()
.map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart))
}
async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<bool> {
let Some(room_id) = self.resolve_local_alias(alias)? else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found."));
};
let room_id = self
.resolve_local_alias(alias)
.await
.map_err(|_| err!(Request(NotFound("Alias not found."))))?;
let server_user = &self.services.globals.server_user;
// The creator of an alias can remove it
if self
.db
.who_created_alias(alias)?
.is_some_and(|user| user == user_id)
.who_created_alias(alias).await
.is_ok_and(|user| user == user_id)
// Server admins can remove any local alias
|| self.services.admin.user_is_admin(user_id).await?
|| self.services.admin.user_is_admin(user_id).await
// Always allow the server service account to remove the alias, since there may not be an admin room
|| server_user == user_id
{
Ok(true)
// Checking whether the user is able to change canonical aliases of the
// room
} else if let Some(event) =
self.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
} else if let Ok(event) = self
.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")
.await
{
serde_json::from_str(event.content.get())
.map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels"))
@ -157,10 +206,11 @@ impl Service {
})
// If there is no power levels event, only the room creator can change
// canonical aliases
} else if let Some(event) =
self.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
} else if let Ok(event) = self
.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")
.await
{
Ok(event.sender == user_id)
} else {
@ -168,6 +218,10 @@ impl Service {
}
}
async fn who_created_alias(&self, alias: &RoomAliasId) -> Result<OwnedUserId> {
self.db.alias_userid.qry(alias.alias()).await.deserialized()
}
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
use ruma::api::appservice::query::query_room_alias;
@ -185,10 +239,11 @@ impl Service {
.await,
Ok(Some(_opt_result))
) {
return Ok(Some(
self.resolve_local_alias(room_alias)?
.ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?,
));
return self
.resolve_local_alias(room_alias)
.await
.map_err(|_| err!(Request(NotFound("Room does not exist."))))
.map(Some);
}
}

View file

@ -24,7 +24,7 @@ impl Data {
}
}
pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
@ -33,17 +33,14 @@ impl Data {
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>()
});
let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>()
});
if let Some(chain) = chain {
if let Ok(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
@ -66,7 +63,7 @@ impl Data {
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
);
}
// Cache in RAM

View file

@ -5,7 +5,8 @@ use std::{
sync::Arc,
};
use conduit::{debug, error, trace, validated, warn, Err, Result};
use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result};
use futures::{FutureExt, Stream, StreamExt};
use ruma::{EventId, RoomId};
use self::data::Data;
@ -38,7 +39,7 @@ impl crate::Service for Service {
impl Service {
pub async fn event_ids_iter<'a>(
&'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
) -> Result<impl Stream<Item = Arc<EventId>> + Send + 'a> {
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
for starting_event in &starting_events_ {
starting_events.push(starting_event);
@ -48,7 +49,13 @@ impl Service {
.get_auth_chain(room_id, &starting_events)
.await?
.into_iter()
.filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok()))
.stream()
.filter_map(|sid| {
self.services
.short
.get_eventid_from_short(sid)
.map(Result::ok)
}))
}
#[tracing::instrument(skip_all, name = "auth_chain")]
@ -61,7 +68,8 @@ impl Service {
for (i, &short) in self
.services
.short
.multi_get_or_create_shorteventid(starting_events)?
.multi_get_or_create_shorteventid(starting_events)
.await
.iter()
.enumerate()
{
@ -85,7 +93,7 @@ impl Service {
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? {
if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? {
trace!("Found cache entry for whole chunk");
full_auth_chain.extend(cached.iter().copied());
hits = hits.saturating_add(1);
@ -96,12 +104,12 @@ impl Service {
let mut misses2: usize = 0;
let mut chunk_cache = Vec::with_capacity(chunk.len());
for (sevent_id, event_id) in chunk {
if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? {
if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? {
trace!(?event_id, "Found cache entry for event");
chunk_cache.extend(cached.iter().copied());
hits2 = hits2.saturating_add(1);
} else {
let auth_chain = self.get_auth_chain_inner(room_id, event_id)?;
let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
self.cache_auth_chain(vec![sevent_id], &auth_chain)?;
chunk_cache.extend(auth_chain.iter());
misses2 = misses2.saturating_add(1);
@ -143,15 +151,16 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
while let Some(event_id) = todo.pop() {
trace!(?event_id, "processing auth event");
match self.services.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
match self.services.timeline.get_pdu(&event_id).await {
Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"),
Ok(pdu) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(
"auth event {event_id:?} for incorrect room {} which is not {}",
@ -160,7 +169,11 @@ impl Service {
)));
}
for auth_event in &pdu.auth_events {
let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?;
let sauthevent = self
.services
.short
.get_or_create_shorteventid(auth_event)
.await;
if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
@ -168,20 +181,14 @@ impl Service {
}
}
},
Ok(None) => {
warn!(?event_id, "Could not find pdu mentioned in auth events");
},
Err(error) => {
error!(?event_id, ?error, "Could not load event in auth chain");
},
}
}
Ok(found)
}
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
self.db.get_cached_eventid_authchain(key)
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
self.db.get_cached_eventid_authchain(key).await
}
#[tracing::instrument(skip(self), level = "debug")]

View file

@ -1,39 +0,0 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use ruma::{OwnedRoomId, RoomId};
pub(super) struct Data {
publicroomids: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
publicroomids: db["publicroomids"].clone(),
}
}
pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])
}
pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.remove(room_id.as_bytes())
}
pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
}
pub(super) fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
}))
}
}

View file

@ -1,36 +1,44 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use ruma::{OwnedRoomId, RoomId};
use self::data::Data;
use conduit::{implement, utils::stream::TryIgnore, Result};
use database::{Ignore, Map};
use futures::{Stream, StreamExt};
use ruma::RoomId;
pub struct Service {
db: Data,
}
struct Data {
publicroomids: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data {
publicroomids: args.db["publicroomids"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
#[implement(Service)]
pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id.as_bytes(), &[]); }
#[tracing::instrument(skip(self), level = "debug")]
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
#[implement(Service)]
pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); }
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
#[implement(Service)]
pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.qry(room_id).await.is_ok() }
#[tracing::instrument(skip(self), level = "debug")]
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
#[implement(Service)]
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
self.db
.publicroomids
.keys()
.ignore_err()
.map(|(room_id, _): (&RoomId, Ignore)| room_id)
}

File diff suppressed because it is too large Load diff

View file

@ -3,7 +3,9 @@ use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId};
use serde_json::value::RawValue as RawJsonValue;
impl super::Service {
pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
pub async fn parse_incoming_pdu(
&self, pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
debug_warn!("Error parsing incoming event {pdu:#?}");
err!(BadServerResponse("Error parsing incoming event {e:?}"))
@ -14,7 +16,7 @@ impl super::Service {
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(err!(Request(InvalidParam("Invalid room id in pdu"))))?;
let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else {
let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else {
return Err!("Server is not in room {room_id}");
};

View file

@ -1,65 +0,0 @@
use std::sync::Arc;
use conduit::Result;
use database::{Database, Map};
use ruma::{DeviceId, RoomId, UserId};
pub(super) struct Data {
lazyloadedids: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
lazyloadedids: db["lazyloadedids"].clone(),
}
}
pub(super) fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
pub(super) fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())
}
pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;
}
Ok(())
}
}

View file

@ -1,21 +1,26 @@
mod data;
use std::{
collections::{HashMap, HashSet},
fmt::Write,
sync::{Arc, Mutex},
};
use conduit::{PduCount, Result};
use conduit::{
implement,
utils::{stream::TryIgnore, ReadyExt},
PduCount, Result,
};
use database::{Interfix, Map};
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use self::data::Data;
pub struct Service {
pub lazy_load_waiting: Mutex<LazyLoadWaiting>,
lazy_load_waiting: Mutex<LazyLoadWaiting>,
db: Data,
}
struct Data {
lazyloadedids: Arc<Map>,
}
type LazyLoadWaiting = HashMap<LazyLoadWaitingKey, LazyLoadWaitingVal>;
type LazyLoadWaitingKey = (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount);
type LazyLoadWaitingVal = HashSet<OwnedUserId>;
@ -23,8 +28,10 @@ type LazyLoadWaitingVal = HashSet<OwnedUserId>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
lazy_load_waiting: Mutex::new(HashMap::new()),
db: Data::new(args.db),
lazy_load_waiting: LazyLoadWaiting::new().into(),
db: Data {
lazyloadedids: args.db["lazyloadedids"].clone(),
},
}))
}
@ -40,47 +47,60 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
self.db
.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
#[inline]
pub async fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> bool {
let key = (user_id, device_id, room_id, ll_user);
self.db.lazyloadedids.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
self.lazy_load_waiting
.lock()
.expect("locked")
.insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load);
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount,
) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count);
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount,
) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&(
user_id.to_owned(),
device_id.to_owned(),
room_id.to_owned(),
since,
)) {
self.db
.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?;
} else {
// Ignore
}
self.lazy_load_waiting
.lock()
.expect("locked")
.insert(key, lazy_load);
}
Ok(())
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since);
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
self.db.lazy_load_reset(user_id, device_id, room_id)
let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else {
return;
};
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in &user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.db.lazyloadedids.insert(&key, &[]);
}
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) {
let prefix = (user_id, device_id, room_id, Interfix);
self.db
.lazyloadedids
.keys_raw_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.db.lazyloadedids.remove(key))
.await;
}

View file

@ -1,110 +0,0 @@
use std::sync::Arc;
use conduit::{error, utils, Error, Result};
use database::Map;
use ruma::{OwnedRoomId, RoomId};
use crate::{rooms, Dep};
pub(super) struct Data {
disabledroomids: Arc<Map>,
bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
}
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
disabledroomids: db["disabledroomids"].clone(),
bannedroomids: db["bannedroomids"].clone(),
roomid_shortroomid: db["roomid_shortroomid"].clone(),
pduid_pdu: db["pduid_pdu"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}
}
pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match self.services.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};
// Look for PDUs in that room.
Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
pub(super) fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
}))
}
#[inline]
pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
#[inline]
pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
#[inline]
pub(super) fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
}
#[inline]
pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.bannedroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
pub(super) fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.")
})?
.try_into()
.map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids")
})?;
Ok(room_id)
},
))
}
}

View file

@ -1,51 +1,92 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use ruma::{OwnedRoomId, RoomId};
use conduit::{implement, utils::stream::TryIgnore, Result};
use database::Map;
use futures::{Stream, StreamExt};
use ruma::RoomId;
use self::data::Data;
use crate::{rooms, Dep};
pub struct Service {
db: Data,
services: Services,
}
struct Data {
disabledroomids: Arc<Map>,
bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>,
}
struct Services {
short: Dep<rooms::short::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(&args),
db: Data {
disabledroomids: args.db["disabledroomids"].clone(),
bannedroomids: args.db["bannedroomids"].clone(),
roomid_shortroomid: args.db["roomid_shortroomid"].clone(),
pduid_pdu: args.db["pduid_pdu"].clone(),
},
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Checks if a room exists.
#[inline]
pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) }
#[implement(Service)]
pub async fn exists(&self, room_id: &RoomId) -> bool {
let Ok(prefix) = self.services.short.get_shortroomid(room_id).await else {
return false;
};
#[must_use]
pub fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { self.db.iter_ids() }
// Look for PDUs in that room.
self.db
.pduid_pdu
.keys_raw_prefix(&prefix)
.ignore_err()
.next()
.await
.is_some()
}
#[inline]
pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { self.db.is_disabled(room_id) }
#[implement(Service)]
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() }
#[inline]
pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
self.db.disable_room(room_id, disabled)
}
#[inline]
pub fn is_banned(&self, room_id: &RoomId) -> Result<bool> { self.db.is_banned(room_id) }
#[inline]
pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) }
#[inline]
#[must_use]
pub fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
self.db.list_banned_rooms()
#[implement(Service)]
#[inline]
pub fn disable_room(&self, room_id: &RoomId, disabled: bool) {
if disabled {
self.db.disabledroomids.insert(room_id.as_bytes(), &[]);
} else {
self.db.disabledroomids.remove(room_id.as_bytes());
}
}
#[implement(Service)]
#[inline]
pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
if banned {
self.db.bannedroomids.insert(room_id.as_bytes(), &[]);
} else {
self.db.bannedroomids.remove(room_id.as_bytes());
}
}
#[implement(Service)]
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() }
#[implement(Service)]
#[inline]
pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.qry(room_id).await.is_ok() }
#[implement(Service)]
#[inline]
pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.qry(room_id).await.is_ok() }

View file

@ -1,42 +0,0 @@
use std::sync::Arc;
use conduit::{Error, Result};
use database::{Database, Map};
use ruma::{CanonicalJsonObject, EventId};
use crate::PduEvent;
pub(super) struct Data {
eventid_outlierpdu: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
}
}
pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
)
}
}

View file

@ -1,9 +1,7 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use conduit::{implement, Result};
use database::{Deserialized, Map};
use ruma::{CanonicalJsonObject, EventId};
use crate::PduEvent;
@ -12,31 +10,48 @@ pub struct Service {
db: Data,
}
struct Data {
eventid_outlierpdu: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data {
eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Returns the pdu from the outlier tree.
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.db.get_outlier_pdu_json(event_id)
}
/// Returns the pdu from the outlier tree.
///
/// TODO: use this?
#[allow(dead_code)]
pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { self.db.get_outlier_pdu(event_id) }
/// Append the PDU as an outlier.
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.db.add_pdu_outlier(event_id, pdu)
}
/// Returns the pdu from the outlier tree.
#[implement(Service)]
pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
self.db
.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
}
/// Returns the pdu from the outlier tree.
#[implement(Service)]
pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result<PduEvent> {
self.db
.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
}
/// Append the PDU as an outlier.
#[implement(Service)]
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) {
self.db.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
);
}

View file

@ -1,7 +1,13 @@
use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, PduCount, PduEvent, Result};
use conduit::{
result::LogErr,
utils,
utils::{stream::TryIgnore, ReadyExt},
PduCount, PduEvent,
};
use database::Map;
use futures::{Stream, StreamExt};
use ruma::{EventId, RoomId, UserId};
use crate::{rooms, Dep};
@ -17,8 +23,7 @@ struct Services {
timeline: Dep<rooms::timeline::Service>,
}
type PdusIterItem = Result<(PduCount, PduEvent)>;
type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
pub(super) type PdusIterItem = (PduCount, PduEvent);
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
@ -33,19 +38,17 @@ impl Data {
}
}
pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> {
pub(super) fn add_relation(&self, from: u64, to: u64) {
let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?;
Ok(())
self.tofrom_relation.insert(&key, &[]);
}
pub(super) fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<PdusIterator<'a>> {
) -> impl Stream<Item = PdusIterItem> + Send + 'a + '_ {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
@ -55,53 +58,42 @@ impl Data {
};
current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new(
self.tofrom_relation
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
self.tofrom_relation
.rev_raw_keys_from(&current)
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|to_from| utils::u64_from_u8(&to_from[(size_of::<u64>())..]))
.filter_map(move |from| async move {
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((PduCount::Normal(from), pdu))
}),
))
Some((PduCount::Normal(from), pdu))
})
}
pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
self.referencedevents.insert(&key, &[]);
}
Ok(())
}
pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
pub(super) async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
let key = (room_id, event_id);
self.referencedevents.qry(&key).await.is_ok()
}
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[])
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) {
self.softfailedeventids.insert(event_id.as_bytes(), &[]);
}
pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.softfailedeventids.qry(event_id).await.is_ok()
}
}

View file

@ -1,8 +1,8 @@
mod data;
use std::sync::Arc;
use conduit::{PduCount, PduEvent, Result};
use conduit::{utils::stream::IterStream, PduCount, Result};
use futures::StreamExt;
use ruma::{
api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType},
@ -10,7 +10,7 @@ use ruma::{
};
use serde::Deserialize;
use self::data::Data;
use self::data::{Data, PdusIterItem};
use crate::{rooms, Dep};
pub struct Service {
@ -51,21 +51,19 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip(self, from, to), level = "debug")]
pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
pub fn add_relation(&self, from: PduCount, to: PduCount) {
match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
_ => {
// TODO: Relations with backfilled pdus
Ok(())
},
}
}
#[allow(clippy::too_many_arguments)]
pub fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option<TimelineEventType>,
filter_rel_type: &Option<RelationType>, from: &Option<String>, to: &Option<String>, limit: &Option<UInt>,
pub async fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option<TimelineEventType>,
filter_rel_type: Option<RelationType>, from: Option<&String>, to: Option<&String>, limit: Option<UInt>,
recurse: bool, dir: Direction,
) -> Result<get_relating_events::v1::Response> {
let from = match from {
@ -76,7 +74,7 @@ impl Service {
},
};
let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
let to = to.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = limit
@ -92,30 +90,32 @@ impl Service {
1
};
let relations_until = &self.relations_until(sender_user, room_id, target, from, depth)?;
let events: Vec<_> = relations_until // TODO: should be relations_after
.iter()
.filter(|(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
&& if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
{
filter_rel_type
.as_ref()
.map_or(true, |r| &content.relates_to.rel_type == r)
} else {
false
}
})
.take(limit)
.filter(|(_, pdu)| {
self.services
.state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false)
})
.take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to`
.collect();
let relations_until: Vec<PdusIterItem> = self
.relations_until(sender_user, room_id, target, from, depth)
.await?;
// TODO: should be relations_after
let events: Vec<_> = relations_until
.into_iter()
.filter(move |(_, pdu): &PdusIterItem| {
if !filter_event_type.as_ref().map_or(true, |t| pdu.kind == *t) {
return false;
}
let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) else {
return false;
};
filter_rel_type
.as_ref()
.map_or(true, |r| *r == content.relates_to.rel_type)
})
.take(limit)
.take_while(|(k, _)| Some(*k) != to)
.stream()
.filter_map(|item| self.visibility_filter(sender_user, item))
.collect()
.await;
let next_token = events.last().map(|(count, _)| count).copied();
@ -125,9 +125,9 @@ impl Service {
.map(|(_, pdu)| pdu.to_message_like_event())
.collect(),
Direction::Backward => events
.into_iter()
.rev() // relations are always most recent first
.map(|(_, pdu)| pdu.to_message_like_event())
.into_iter()
.rev() // relations are always most recent first
.map(|(_, pdu)| pdu.to_message_like_event())
.collect(),
};
@ -135,68 +135,85 @@ impl Service {
chunk: events_chunk,
next_batch: next_token.map(|t| t.stringify()),
prev_batch: Some(from.stringify()),
recursion_depth: if recurse {
Some(depth.into())
} else {
None
},
recursion_depth: recurse.then_some(depth.into()),
})
}
pub fn relations_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = self.services.short.get_or_create_shortroomid(room_id)?;
#[allow(unknown_lints)]
#[allow(clippy::manual_unwrap_or_default)]
let target = match self.services.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c,
async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option<PdusIterItem> {
let (_, pdu) = &item;
self.services
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.await
.then_some(item)
}
pub async fn relations_until(
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<PdusIterItem>> {
let room_id = self.services.short.get_or_create_shortroomid(room_id).await;
let target = match self.services.timeline.get_pdu_count(target).await {
Ok(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator
};
self.db
let mut pdus: Vec<PdusIterItem> = self
.db
.relations_until(user_id, room_id, target, until)
.map(|mut relations| {
let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect();
let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect();
.collect()
.await;
while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
let mut stack: Vec<_> = pdus.clone().into_iter().map(|pdu| (pdu, 1)).collect();
if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) {
for relation in relations.flatten() {
if stack_pdu.1 < max_depth {
stack.push((relation.clone(), stack_pdu.1.saturating_add(1)));
}
while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
pdus.push(relation);
}
}
let relations: Vec<PdusIterItem> = self
.db
.relations_until(user_id, room_id, target, until)
.collect()
.await;
for relation in relations {
if stack_pdu.1 < max_depth {
stack.push((relation.clone(), stack_pdu.1.saturating_add(1)));
}
pdus.sort_by(|a, b| a.0.cmp(&b.0));
pdus
})
pdus.push(relation);
}
}
pdus.sort_by(|a, b| a.0.cmp(&b.0));
Ok(pdus)
}
#[inline]
#[tracing::instrument(skip_all, level = "debug")]
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
self.db.mark_as_referenced(room_id, event_ids)
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) {
self.db.mark_as_referenced(room_id, event_ids);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
self.db.is_event_referenced(room_id, event_id)
pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
self.db.is_event_referenced(room_id, event_id).await
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) }
pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) }
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { self.db.is_event_soft_failed(event_id) }
pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.db.is_event_soft_failed(event_id).await
}
}

View file

@ -1,10 +1,18 @@
use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, RoomId, UserId};
use conduit::{
utils,
utils::{stream::TryIgnore, ReadyExt},
Error, Result,
};
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
serde::Raw,
CanonicalJsonObject, OwnedUserId, RoomId, UserId,
};
use super::AnySyncEphemeralRoomEventIter;
use crate::{globals, Dep};
pub(super) struct Data {
@ -18,6 +26,8 @@ struct Services {
globals: Dep<globals::Service>,
}
pub(super) type ReceiptItem = (OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>);
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
@ -31,7 +41,9 @@ impl Data {
}
}
pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
type KeyVal<'a> = (&'a RoomId, u64, &'a UserId);
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
@ -39,108 +51,90 @@ impl Data {
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry
if let Some((old, _)) = self
.readreceiptid_readreceipt
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) {
// This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?;
}
self.readreceiptid_readreceipt
.rev_keys_from_raw(&last_possible_key)
.ignore_err()
.ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id)
.ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u)))
.ready_for_each(|old: KeyVal<'_>| {
// This is the old room_latest
self.readreceiptid_readreceipt.del(&old);
})
.await;
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
room_latest_id.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(event).expect("EduEvent::to_string always works"),
)?;
Ok(())
);
}
pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> {
pub(super) fn readreceipts_since<'a>(
&'a self, room_id: &'a RoomId, since: u64,
) -> impl Stream<Item = ReceiptItem> + Send + 'a {
let after_since = since.saturating_add(1); // +1 so we don't send the event at since
let first_possible_edu = (room_id, after_since);
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since
self.readreceiptid_readreceipt
.stream_raw_from(&first_possible_edu)
.ignore_err()
.ready_take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count_offset = prefix.len().saturating_add(size_of::<u64>());
let user_id_offset = count_offset.saturating_add(1);
Box::new(
self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count_offset = prefix.len().saturating_add(size_of::<u64>());
let count = utils::u64_from_bytes(&k[prefix.len()..count_offset])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id_offset = count_offset.saturating_add(1);
let user_id = UserId::parse(
utils::string_from_bytes(&k[user_id_offset..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
)
let count = utils::u64_from_bytes(&k[prefix.len()..count_offset])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id_str = utils::string_from_bytes(&k[user_id_offset..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?;
let user_id = UserId::parse(user_id_str)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
json.remove("room_id");
let mut json = serde_json::from_slice::<CanonicalJsonObject>(v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
Ok((
user_id,
count,
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
))
}),
)
json.remove("room_id");
let event = Raw::from_json(serde_json::value::to_raw_value(&json)?);
Ok((user_id, count, event))
})
.ignore_err()
}
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
.insert(&key, &count.to_be_bytes());
self.roomuserid_lastprivatereadupdate
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())
.insert(&key, &self.services.globals.next_count().unwrap().to_be_bytes());
}
pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
))
})
pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.roomuserid_privateread.qry(&key).await.deserialized()
}
pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastprivatereadupdate
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (room_id, user_id);
self.roomuserid_lastprivatereadupdate
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
}

View file

@ -3,16 +3,17 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::{debug, Result};
use data::Data;
use futures::Stream;
use ruma::{
events::{
receipt::{ReceiptEvent, ReceiptEventContent},
AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent,
SyncEphemeralRoomEvent,
},
serde::Raw,
OwnedUserId, RoomId, UserId,
RoomId, UserId,
};
use self::data::{Data, ReceiptItem};
use crate::{sending, Dep};
pub struct Service {
@ -24,9 +25,6 @@ struct Services {
sending: Dep<sending::Service>,
}
type AnySyncEphemeralRoomEventIter<'a> =
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
@ -42,44 +40,53 @@ impl crate::Service for Service {
impl Service {
/// Replaces the previous read receipt.
pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event)?;
self.services.sending.flush_room(room_id)?;
Ok(())
pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
self.db.readreceipt_update(user_id, room_id, event).await;
self.services
.sending
.flush_room(room_id)
.await
.expect("room flush failed");
}
/// Returns an iterator over the most recent read_receipts in a room that
/// happened after the event with id `since`.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn readreceipts_since<'a>(
&'a self, room_id: &RoomId, since: u64,
) -> impl Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a {
&'a self, room_id: &'a RoomId, since: u64,
) -> impl Stream<Item = ReceiptItem> + Send + 'a {
self.db.readreceipts_since(room_id, since)
}
/// Sets a private read marker at `count`.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
self.db.private_read_set(room_id, user_id, count)
pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) {
self.db.private_read_set(room_id, user_id, count);
}
/// Returns the private read marker.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
self.db.private_read_get(room_id, user_id)
pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
self.db.private_read_get(room_id, user_id).await
}
/// Returns the count of the last typing update in this room.
pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.last_privateread_update(user_id, room_id)
#[inline]
pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.last_privateread_update(user_id, room_id).await
}
}
#[must_use]
pub fn pack_receipts(receipts: AnySyncEphemeralRoomEventIter<'_>) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>> {
pub fn pack_receipts<I>(receipts: I) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>>
where
I: Iterator<Item = ReceiptItem>,
{
let mut json = BTreeMap::new();
for (_user, _count, value) in receipts.flatten() {
for (_, _, value) in receipts {
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get());
if let Ok(value) = receipt {
for (event, receipt) in value.content {

View file

@ -1,13 +1,12 @@
use std::sync::Arc;
use conduit::{utils, Result};
use conduit::utils::{set, stream::TryIgnore, IterStream, ReadyExt};
use database::Map;
use futures::StreamExt;
use ruma::RoomId;
use crate::{rooms, Dep};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
pub(super) struct Data {
tokenids: Arc<Map>,
services: Services,
@ -28,7 +27,7 @@ impl Data {
}
}
pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
let batch = tokenize(message_body)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
@ -39,11 +38,10 @@ impl Data {
})
.collect::<Vec<_>>();
self.tokenids
.insert_batch(batch.iter().map(database::KeyVal::from))
self.tokenids.insert_batch(batch.iter());
}
pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
let batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
@ -53,46 +51,53 @@ impl Data {
});
for token in batch {
self.tokenids.remove(&token)?;
self.tokenids.remove(&token);
}
Ok(())
}
pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
pub(super) async fn search_pdus(
&self, room_id: &RoomId, search_string: &str,
) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists")
.get_shortroomid(room_id)
.await
.ok()?
.to_be_bytes()
.to_vec();
let words: Vec<_> = tokenize(search_string).collect();
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let bufs: Vec<_> = words
.clone()
.into_iter()
.stream()
.then(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
self.tokenids
.rev_raw_keys_from(&last_possible_id) // Newest pdus first
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix2))
.map(move |key| key[prefix3.len()..].to_vec())
.collect::<Vec<_>>()
})
.collect()
.await;
let Some(common_elements) = utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
}) else {
return Ok(None);
};
Ok(Some((Box::new(common_elements), words)))
Some((
set::intersection(bufs.iter().map(|buf| buf.iter()))
.cloned()
.collect(),
words,
))
}
}
@ -100,7 +105,7 @@ impl Data {
///
/// This may be used to tokenize both message bodies (for indexing) or search
/// queries (for querying).
fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ {
fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
body.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)

View file

@ -21,20 +21,21 @@ impl crate::Service for Service {
}
impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
self.db.index_pdu(shortroomid, pdu_id, message_body)
pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
self.db.index_pdu(shortroomid, pdu_id, message_body);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
self.db.deindex_pdu(shortroomid, pdu_id, message_body)
pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
self.db.deindex_pdu(shortroomid, pdu_id, message_body);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn search_pdus<'a>(
&'a self, room_id: &RoomId, search_string: &str,
) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> {
self.db.search_pdus(room_id, search_string)
pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
self.db.search_pdus(room_id, search_string).await
}
}

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use conduit::{utils, warn, Error, Result};
use database::Map;
use conduit::{err, utils, Error, Result};
use database::{Deserialized, Map};
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{globals, Dep};
@ -36,44 +36,46 @@ impl Data {
}
}
pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = self.services.globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
pub(super) async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 {
if let Ok(shorteventid) = self.eventid_shorteventid.qry(event_id).await.deserialized() {
return shorteventid;
}
Ok(short)
let shorteventid = self.services.globals.next_count().unwrap();
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes());
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes());
shorteventid
}
pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
pub(super) async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec<u64> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids
.iter()
.map(|id| id.as_bytes())
.collect::<Vec<&[u8]>>();
for (i, short) in self
.eventid_shorteventid
.multi_get(&keys)?
.multi_get(keys.iter())
.iter()
.enumerate()
{
#[allow(clippy::single_match_else)]
match short {
Some(short) => ret.push(
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
utils::u64_from_bytes(short)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))
.unwrap(),
),
None => {
let short = self.services.globals.next_count()?;
let short = self.services.globals.next_count().unwrap();
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
.insert(keys[i], &short.to_be_bytes());
self.shorteventid_eventid
.insert(&short.to_be_bytes(), keys[i])?;
.insert(&short.to_be_bytes(), keys[i]);
debug_assert!(ret.len() == i, "position of result must match input");
ret.push(short);
@ -81,115 +83,85 @@ impl Data {
}
}
Ok(ret)
ret
}
pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
pub(super) async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let key = (event_type, state_key);
self.statekey_shortstatekey.qry(&key).await.deserialized()
}
let short = self
.statekey_shortstatekey
.get(&statekey_vec)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
pub(super) async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 {
let key = (event_type.to_string(), state_key);
if let Ok(shortstatekey) = self.statekey_shortstatekey.qry(&key).await.deserialized() {
return shortstatekey;
}
let mut key = event_type.to_string().as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(state_key.as_bytes());
let shortstatekey = self.services.globals.next_count().unwrap();
self.statekey_shortstatekey
.insert(&key, &shortstatekey.to_be_bytes());
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &key);
shortstatekey
}
pub(super) async fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
self.shorteventid_eventid
.qry(&shorteventid)
.await
.deserialized()
.map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}")))
}
pub(super) async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
self.shortstatekey_statekey
.qry(&shortstatekey)
.await
.deserialized()
.map_err(|e| {
err!(Database(
"Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}"
))
})
.transpose()?;
Ok(short)
}
pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = self.services.globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey
};
Ok(short)
}
pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
Ok(event_id)
}
pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes)
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
let result = (event_type, state_key);
Ok(result)
}
/// Returns (shortstatehash, already_existed)
pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
)
} else {
let shortstatehash = self.services.globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
})
pub(super) async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) {
if let Ok(shortstatehash) = self
.statehash_shortstatehash
.qry(state_hash)
.await
.deserialized()
{
return (shortstatehash, true);
}
let shortstatehash = self.services.globals.next_count().unwrap();
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes());
(shortstatehash, false)
}
pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
pub(super) async fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
self.roomid_shortroomid.qry(room_id).await.deserialized()
}
pub(super) async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose()
}
pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = self.services.globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
})
.qry(room_id)
.await
.deserialized()
.unwrap_or_else(|_| {
let short = self.services.globals.next_count().unwrap();
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes());
short
})
}
}

View file

@ -22,38 +22,40 @@ impl crate::Service for Service {
}
impl Service {
pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
self.db.get_or_create_shorteventid(event_id)
pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 {
self.db.get_or_create_shorteventid(event_id).await
}
pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
self.db.multi_get_or_create_shorteventid(event_ids)
pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec<u64> {
self.db.multi_get_or_create_shorteventid(event_ids).await
}
pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
self.db.get_shortstatekey(event_type, state_key)
pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
self.db.get_shortstatekey(event_type, state_key).await
}
pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
self.db.get_or_create_shortstatekey(event_type, state_key)
pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 {
self.db
.get_or_create_shortstatekey(event_type, state_key)
.await
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
self.db.get_eventid_from_short(shorteventid)
pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
self.db.get_eventid_from_short(shorteventid).await
}
pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
self.db.get_statekey_from_short(shortstatekey)
pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
self.db.get_statekey_from_short(shortstatekey).await
}
/// Returns (shortstatehash, already_existed)
pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
self.db.get_or_create_shortstatehash(state_hash)
pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) {
self.db.get_or_create_shortstatehash(state_hash).await
}
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.get_shortroomid(room_id) }
pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> { self.db.get_shortroomid(room_id).await }
pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
self.db.get_or_create_shortroomid(room_id)
pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 {
self.db.get_or_create_shortroomid(room_id).await
}
}

View file

@ -7,7 +7,12 @@ use std::{
sync::Arc,
};
use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result};
use conduit::{
checked, debug, debug_info, err,
utils::{math::usize_from_f64, IterStream},
Error, Result,
};
use futures::{StreamExt, TryFutureExt};
use lru_cache::LruCache;
use ruma::{
api::{
@ -211,12 +216,15 @@ impl Service {
.as_ref()
{
return Ok(if let Some(cached) = cached {
if self.is_accessible_child(
current_room,
&cached.summary.join_rule,
&identifier,
&cached.summary.allowed_room_ids,
) {
if self
.is_accessible_child(
current_room,
&cached.summary.join_rule,
&identifier,
&cached.summary.allowed_room_ids,
)
.await
{
Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone())))
} else {
Some(SummaryAccessibility::Inaccessible)
@ -228,7 +236,9 @@ impl Service {
Ok(
if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? {
let summary = self.get_room_summary(current_room, children_pdus, &identifier);
let summary = self
.get_room_summary(current_room, children_pdus, &identifier)
.await;
if let Ok(summary) = summary {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
@ -322,12 +332,15 @@ impl Service {
);
}
}
if self.is_accessible_child(
current_room,
&response.room.join_rule,
&Identifier::UserId(user_id),
&response.room.allowed_room_ids,
) {
if self
.is_accessible_child(
current_room,
&response.room.join_rule,
&Identifier::UserId(user_id),
&response.room.allowed_room_ids,
)
.await
{
return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone()))));
}
@ -358,7 +371,7 @@ impl Service {
}
}
fn get_room_summary(
async fn get_room_summary(
&self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> {
@ -367,48 +380,43 @@ impl Service {
let join_rule = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")
.await
.map_or(JoinRule::Invite, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| c.join_rule)
.map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}"))))
})
.transpose()?
.unwrap_or(JoinRule::Invite);
.unwrap()
});
let allowed_room_ids = self
.services
.state_accessor
.allowed_room_ids(join_rule.clone());
if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) {
if !self
.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids)
.await
{
debug!("User is not allowed to see room {room_id}");
// This error will be caught later
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
}
let join_rule = join_rule.into();
Ok(SpaceHierarchyParentSummary {
canonical_alias: self
.services
.state_accessor
.get_canonical_alias(room_id)
.unwrap_or(None),
name: self
.services
.state_accessor
.get_name(room_id)
.unwrap_or(None),
.await
.ok(),
name: self.services.state_accessor.get_name(room_id).await.ok(),
num_joined_members: self
.services
.state_cache
.room_joined_count(room_id)
.unwrap_or_default()
.unwrap_or_else(|| {
warn!("Room {room_id} has no member count");
0
})
.await
.unwrap_or(0)
.try_into()
.expect("user count should not be that big"),
room_id: room_id.to_owned(),
@ -416,18 +424,29 @@ impl Service {
.services
.state_accessor
.get_room_topic(room_id)
.unwrap_or(None),
world_readable: self.services.state_accessor.is_world_readable(room_id)?,
guest_can_join: self.services.state_accessor.guest_can_join(room_id)?,
.await
.ok(),
world_readable: self
.services
.state_accessor
.is_world_readable(room_id)
.await,
guest_can_join: self.services.state_accessor.guest_can_join(room_id).await,
avatar_url: self
.services
.state_accessor
.get_avatar(room_id)?
.get_avatar(room_id)
.await
.into_option()
.unwrap_or_default()
.url,
join_rule,
room_type: self.services.state_accessor.get_room_type(room_id)?,
join_rule: join_rule.into(),
room_type: self
.services
.state_accessor
.get_room_type(room_id)
.await
.ok(),
children_state,
allowed_room_ids,
})
@ -474,21 +493,22 @@ impl Service {
results.push(summary_to_chunk(*summary.clone()));
} else {
children = children
.into_iter()
.rev()
.skip_while(|(room, _)| {
if let Ok(short) = self.services.short.get_shortroomid(room)
{
short.as_ref() != short_room_ids.get(parents.len())
} else {
false
}
})
.collect::<Vec<_>>()
// skip_while doesn't implement DoubleEndedIterator, which is needed for rev
.into_iter()
.rev()
.collect();
.iter()
.rev()
.stream()
.skip_while(|(room, _)| {
self.services
.short
.get_shortroomid(room)
.map_ok(|short| Some(&short) != short_room_ids.get(parents.len()))
.unwrap_or_else(|_| false)
})
.map(Clone::clone)
.collect::<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>()
.await
.into_iter()
.rev()
.collect();
if children.is_empty() {
return Err(Error::BadRequest(
@ -531,7 +551,7 @@ impl Service {
let mut short_room_ids = vec![];
for room in parents {
short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?);
short_room_ids.push(self.services.short.get_or_create_shortroomid(&room).await);
}
Some(
@ -554,7 +574,7 @@ impl Service {
async fn get_stripped_space_child_events(
&self, room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else {
let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else {
return Ok(None);
};
@ -562,10 +582,13 @@ impl Service {
.services
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
.await
.map_err(|e| err!(Database("State in space not found: {e}")))?;
let mut children_pdus = Vec::new();
for (key, id) in state {
let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?;
let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?;
if event_type != StateEventType::SpaceChild {
continue;
}
@ -573,8 +596,9 @@ impl Service {
let pdu = self
.services
.timeline
.get_pdu(&id)?
.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
.get_pdu(&id)
.await
.map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?;
if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
.ok()
@ -593,7 +617,7 @@ impl Service {
}
/// With the given identifier, checks if a room is accessable
fn is_accessible_child(
async fn is_accessible_child(
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
@ -607,6 +631,7 @@ impl Service {
.services
.event_handler
.acl_check(server_name, room_id)
.await
.is_err()
{
return false;
@ -617,12 +642,11 @@ impl Service {
.services
.state_cache
.is_joined(user_id, current_room)
.unwrap_or_default()
|| self
.services
.state_cache
.is_invited(user_id, current_room)
.unwrap_or_default()
.await || self
.services
.state_cache
.is_invited(user_id, current_room)
.await
{
return true;
}
@ -633,22 +657,12 @@ impl Service {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
if self
.services
.state_cache
.is_joined(user, room)
.unwrap_or_default()
{
if self.services.state_cache.is_joined(user, room).await {
return true;
}
},
Identifier::ServerName(server) => {
if self
.services
.state_cache
.server_in_room(server, room)
.unwrap_or_default()
{
if self.services.state_cache.server_in_room(server, room).await {
return true;
}
},

View file

@ -1,34 +1,31 @@
use std::{collections::HashSet, sync::Arc};
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use ruma::{EventId, OwnedEventId, RoomId};
use conduit::{
utils::{stream::TryIgnore, ReadyExt},
Result,
};
use database::{Database, Deserialized, Interfix, Map};
use ruma::{OwnedEventId, RoomId};
use super::RoomMutexGuard;
pub(super) struct Data {
shorteventid_shortstatehash: Arc<Map>,
roomid_pduleaves: Arc<Map>,
roomid_shortstatehash: Arc<Map>,
pub(super) roomid_pduleaves: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
roomid_pduleaves: db["roomid_pduleaves"].clone(),
roomid_shortstatehash: db["roomid_shortstatehash"].clone(),
roomid_pduleaves: db["roomid_pduleaves"].clone(),
}
}
pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<u64> {
self.roomid_shortstatehash.qry(room_id).await.deserialized()
}
#[inline]
@ -37,53 +34,35 @@ impl Data {
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
) {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes());
}
pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes());
}
pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
pub(super) fn set_forward_extremities(
pub(super) async fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
) {
let prefix = (room_id, Interfix);
self.roomid_pduleaves
.keys_raw_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.roomid_pduleaves.remove(key))
.await;
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
self.roomid_pduleaves.insert(&key, event_id.as_bytes());
}
Ok(())
}
}

View file

@ -7,12 +7,14 @@ use std::{
};
use conduit::{
utils::{calculate_hash, MutexMap, MutexMapGuard},
warn, Error, PduEvent, Result,
err,
utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard},
warn, PduEvent, Result,
};
use data::Data;
use database::{Ignore, Interfix};
use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{
api::client::error::ErrorKind,
events::{
room::{create::RoomCreateEventContent, member::RoomMemberEventContent},
AnyStrippedStateEvent, StateEventType, TimelineEventType,
@ -81,14 +83,16 @@ impl Service {
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| {
let event_ids = statediffnew.iter().stream().filter_map(|new| {
self.services
.state_compressor
.parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) {
let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else {
.map_ok_or_else(|_| None, |(_, event_id)| Some(event_id))
});
pin_mut!(event_ids);
while let Some(event_id) = event_ids.next().await {
let Ok(pdu) = self.services.timeline.get_pdu_json(&event_id).await else {
continue;
};
@ -113,15 +117,10 @@ impl Service {
continue;
};
self.services.state_cache.update_membership(
room_id,
&user_id,
membership_event,
&pdu.sender,
None,
None,
false,
)?;
self.services
.state_cache
.update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false)
.await?;
},
TimelineEventType::SpaceChild => {
self.services
@ -135,10 +134,9 @@ impl Service {
}
}
self.services.state_cache.update_joined_count(room_id)?;
self.services.state_cache.update_joined_count(room_id).await;
self.db
.set_room_state(room_id, shortstatehash, state_lock)?;
self.db.set_room_state(room_id, shortstatehash, state_lock);
Ok(())
}
@ -148,12 +146,16 @@ impl Service {
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")]
pub fn set_event_state(
pub async fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<u64> {
let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?;
let shorteventid = self
.services
.short
.get_or_create_shorteventid(event_id)
.await;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id).await;
let state_hash = calculate_hash(
&state_ids_compressed
@ -165,13 +167,18 @@ impl Service {
let (shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)?;
.get_or_create_shortstatehash(&state_hash)
.await;
if !already_existed {
let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()),
|p| self.services.state_compressor.load_shortstatehash_info(p),
)?;
let states_parents = if let Ok(p) = previous_shortstatehash {
self.services
.state_compressor
.load_shortstatehash_info(p)
.await?
} else {
Vec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
@ -198,7 +205,7 @@ impl Service {
)?;
}
self.db.set_event_state(shorteventid, shortstatehash)?;
self.db.set_event_state(shorteventid, shortstatehash);
Ok(shortstatehash)
}
@ -208,34 +215,40 @@ impl Service {
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu), level = "debug")]
pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = self
.services
.short
.get_or_create_shorteventid(&new_pdu.event_id)?;
.get_or_create_shorteventid(&new_pdu.event_id)
.await;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await;
if let Some(p) = previous_shortstatehash {
self.db.set_event_state(shorteventid, p)?;
if let Ok(p) = previous_shortstatehash {
self.db.set_event_state(shorteventid, p);
}
if let Some(state_key) = &new_pdu.state_key {
let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()),
#[inline]
|p| self.services.state_compressor.load_shortstatehash_info(p),
)?;
let states_parents = if let Ok(p) = previous_shortstatehash {
self.services
.state_compressor
.load_shortstatehash_info(p)
.await?
} else {
Vec::new()
};
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?;
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
.await;
let new = self
.services
.state_compressor
.compress_state_event(shortstatekey, &new_pdu.event_id)?;
.compress_state_event(shortstatekey, &new_pdu.event_id)
.await;
let replaces = states_parents
.last()
@ -276,49 +289,55 @@ impl Service {
}
#[tracing::instrument(skip(self, invite_event), level = "debug")]
pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
pub async fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomCanonicalAlias,
"",
)? {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCanonicalAlias, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomMember,
invite_event.sender.as_str(),
)? {
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str())
.await
{
state.push(e.to_stripped_state_event());
}
@ -333,101 +352,108 @@ impl Service {
room_id: &RoomId,
shortstatehash: u64,
mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.db.set_room_state(room_id, shortstatehash, mutex_lock)
) {
self.db.set_room_state(room_id, shortstatehash, mutex_lock);
}
/// Returns the room's version.
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = self
.services
pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
self.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: RoomCreateEventContent = create_event
.as_ref()
.map(|create_event| {
serde_json::from_str(create_event.content.get()).map_err(|e| {
warn!("Invalid create event: {}", e);
Error::bad_database("Invalid create event in db.")
})
})
.transpose()?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?;
Ok(create_event_content.room_version)
.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.map(|content: RoomCreateEventContent| content.room_version)
.map_err(|e| err!(Request(NotFound("No create event found: {e:?}"))))
}
#[inline]
pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.db.get_room_shortstatehash(room_id)
pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<u64> {
self.db.get_room_shortstatehash(room_id).await
}
pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
self.db.get_forward_extremities(room_id)
pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &EventId> + Send + '_ {
let prefix = (room_id, Interfix);
self.db
.roomid_pduleaves
.keys_prefix(&prefix)
.map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
.ignore_err()
}
pub fn set_forward_extremities(
pub async fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
) {
self.db
.set_forward_extremities(room_id, event_ids, state_lock)
.await;
}
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_auth_events(
pub async fn get_auth_events(
&self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? else {
let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
return Ok(HashMap::new());
};
let auth_events =
state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object");
let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)?;
let mut sauthevents = auth_events
.into_iter()
let mut sauthevents: HashMap<_, _> = auth_events
.iter()
.stream()
.filter_map(|(event_type, state_key)| {
self.services
.short
.get_shortstatekey(&event_type.to_string().into(), &state_key)
.ok()
.flatten()
.map(|s| (s, (event_type, state_key)))
.get_shortstatekey(event_type, state_key)
.map_ok(move |s| (s, (event_type, state_key)))
.map(Result::ok)
})
.collect::<HashMap<_, _>>();
.collect()
.await;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| {
err!(Database(
"Missing shortstatehash info for {room_id:?} at {shortstatehash:?}: {e:?}"
))
})?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.iter()
.filter_map(|compressed| {
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
})
.filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id)))
.filter_map(|(k, event_id)| {
self.services
.timeline
.get_pdu(&event_id)
.ok()
.flatten()
.map(|pdu| (k, pdu))
})
.collect())
let mut ret = HashMap::new();
for compressed in full_state.iter() {
let Ok((shortstatekey, event_id)) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)
.await
else {
continue;
};
let Some((ty, state_key)) = sauthevents.remove(&shortstatekey) else {
continue;
};
let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else {
continue;
};
ret.insert((ty.to_owned(), state_key.to_owned()), pdu);
}
Ok(ret)
}
}

View file

@ -1,7 +1,8 @@
use std::{collections::HashMap, sync::Arc};
use conduit::{utils, Error, PduEvent, Result};
use database::Map;
use conduit::{err, PduEvent, Result};
use database::{Deserialized, Map};
use futures::TryFutureExt;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{rooms, Dep};
@ -39,17 +40,22 @@ impl Data {
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database("Missing state IDs: {e}")))?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = self
.services
.state_compressor
.parse_compressed_state_event(compressed)?;
.parse_compressed_state_event(compressed)
.await?;
result.insert(parsed.0, parsed.1);
i = i.wrapping_add(1);
@ -57,6 +63,7 @@ impl Data {
tokio::task::yield_now().await;
}
}
Ok(result)
}
@ -67,7 +74,8 @@ impl Data {
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await?
.pop()
.expect("there is always one layer")
.1;
@ -78,18 +86,13 @@ impl Data {
let (_, eventid) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
.parse_compressed_state_event(compressed)
.await?;
if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await {
if let Some(state_key) = pdu.state_key.as_ref() {
result.insert((pdu.kind.to_string().into(), state_key.clone()), pdu);
}
}
i = i.wrapping_add(1);
@ -101,61 +104,63 @@ impl Data {
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
#[allow(clippy::unused_self)]
pub(super) fn state_get_id(
pub(super) async fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = self
) -> Result<Arc<EventId>> {
let shortstatekey = self
.services
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
.get_shortstatekey(event_type, state_key)
.await?;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
let compressed = full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
.ok_or(err!(Database("No shortstatekey in compressed state")))?;
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.map_ok(|(_, id)| id)
.map_err(|e| {
err!(Database(error!(
?event_type,
?state_key,
?shortstatekey,
"Failed to parse compressed: {e:?}"
)))
})
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
pub(super) fn state_get(
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id))
) -> Result<Arc<PduEvent>> {
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await })
.await
}
/// Returns the state hash for this pdu.
pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<u64> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
})
.transpose()
})
.qry(event_id)
.and_then(|shorteventid| self.shorteventid_shortstatehash.qry(&shorteventid))
.await
.deserialized()
}
/// Returns the full room state.
@ -163,34 +168,33 @@ impl Data {
pub(super) async fn room_state_full(
&self, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_full(shortstatehash))
.map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}")))
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
pub(super) fn room_state_get_id(
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
) -> Result<Arc<EventId>> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key))
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
pub(super) fn room_state_get(
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
) -> Result<Arc<PduEvent>> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
.await
}
}

View file

@ -6,8 +6,13 @@ use std::{
sync::{Arc, Mutex as StdMutex, Mutex},
};
use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result};
use data::Data;
use conduit::{
err, error,
pdu::PduBuilder,
utils::{math::usize_from_f64, ReadyExt},
Error, PduEvent, Result,
};
use futures::StreamExt;
use lru_cache::LruCache;
use ruma::{
events::{
@ -31,8 +36,10 @@ use ruma::{
EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName,
UserId,
};
use serde::Deserialize;
use serde_json::value::to_raw_value;
use self::data::Data;
use crate::{rooms, rooms::state::RoomMutexGuard, Dep};
pub struct Service {
@ -99,54 +106,58 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub fn state_get_id(
pub async fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
self.db.state_get_id(shortstatehash, event_type, state_key)
) -> Result<Arc<EventId>> {
self.db
.state_get_id(shortstatehash, event_type, state_key)
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[inline]
pub fn state_get(
pub async fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.db.state_get(shortstatehash, event_type, state_key)
) -> Result<Arc<PduEvent>> {
self.db
.state_get(shortstatehash, event_type, state_key)
.await
}
/// Get membership for given user in state
fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result<MembershipState> {
self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())?
.map_or(Ok(MembershipState::Leave), |s| {
async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState {
self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())
.await
.map_or(MembershipState::Leave, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomMemberEventContent| c.membership)
.map_err(|_| Error::bad_database("Invalid room membership event in database."))
.unwrap()
})
}
/// The user was a joined member at this state (potentially in the past)
#[inline]
fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.is_ok_and(|s| s == MembershipState::Join)
// Return sensible default, i.e.
// false
async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id).await == MembershipState::Join
}
/// The user was an invited or joined room member at this state (potentially
/// in the past)
#[inline]
fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite)
// Return sensible default, i.e. false
async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
let s = self.user_membership(shortstatehash, user_id).await;
s == MembershipState::Join || s == MembershipState::Invite
}
/// Whether a server is allowed to see an event through federation, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, origin, room_id, event_id))]
pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else {
pub async fn server_can_see_event(
&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId,
) -> Result<bool> {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return Ok(true);
};
@ -160,8 +171,9 @@ impl Service {
}
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
.await
.map_or(HistoryVisibility::Shared, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|e| {
@ -171,25 +183,28 @@ impl Service {
);
Error::bad_database("Invalid history visibility event in database.")
})
})
.unwrap_or(HistoryVisibility::Shared);
.unwrap()
});
let mut current_server_members = self
let current_server_members = self
.services
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
.filter(|member| member.server_name() == origin);
.ready_filter(|member| member.server_name() == origin);
let visibility = match history_visibility {
HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
current_server_members.any(|member| self.user_was_invited(shortstatehash, &member))
current_server_members
.any(|member| self.user_was_invited(shortstatehash, member))
.await
},
HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
current_server_members.any(|member| self.user_was_joined(shortstatehash, &member))
current_server_members
.any(|member| self.user_was_joined(shortstatehash, member))
.await
},
_ => {
error!("Unknown history visibility {history_visibility}");
@ -208,9 +223,9 @@ impl Service {
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id, event_id))]
pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else {
return Ok(true);
pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return true;
};
if let Some(visibility) = self
@ -219,14 +234,15 @@ impl Service {
.unwrap()
.get_mut(&(user_id.to_owned(), shortstatehash))
{
return Ok(*visibility);
return *visibility;
}
let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
let currently_member = self.services.state_cache.is_joined(user_id, room_id).await;
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
.await
.map_or(HistoryVisibility::Shared, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|e| {
@ -236,19 +252,19 @@ impl Service {
);
Error::bad_database("Invalid history visibility event in database.")
})
})
.unwrap_or(HistoryVisibility::Shared);
.unwrap()
});
let visibility = match history_visibility {
HistoryVisibility::WorldReadable => true,
HistoryVisibility::Shared => currently_member,
HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
self.user_was_invited(shortstatehash, user_id)
self.user_was_invited(shortstatehash, user_id).await
},
HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
self.user_was_joined(shortstatehash, user_id)
self.user_was_joined(shortstatehash, user_id).await
},
_ => {
error!("Unknown history visibility {history_visibility}");
@ -261,17 +277,18 @@ impl Service {
.unwrap()
.insert((user_id.to_owned(), shortstatehash), visibility);
Ok(visibility)
visibility
}
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id))]
pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let currently_member = self.services.state_cache.is_joined(user_id, room_id).await;
let history_visibility = self
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")
.await
.map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
@ -285,11 +302,13 @@ impl Service {
})
.unwrap_or(HistoryVisibility::Shared);
Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable)
currently_member || history_visibility == HistoryVisibility::WorldReadable
}
/// Returns the state hash for this pdu.
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { self.db.pdu_shortstatehash(event_id) }
pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<u64> {
self.db.pdu_shortstatehash(event_id).await
}
/// Returns the full room state.
#[tracing::instrument(skip(self), level = "debug")]
@ -300,47 +319,61 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_state_get_id(
pub async fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
self.db.room_state_get_id(room_id, event_type, state_key)
) -> Result<Arc<EventId>> {
self.db
.room_state_get_id(room_id, event_type, state_key)
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_state_get(
pub async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.db.room_state_get(room_id, event_type, state_key)
) -> Result<Arc<PduEvent>> {
self.db.room_state_get(room_id, event_type, state_key).await
}
pub fn get_name(&self, room_id: &RoomId) -> Result<Option<String>> {
self.room_state_get(room_id, &StateEventType::RoomName, "")?
.map_or(Ok(None), |s| {
Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name)))
})
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn room_state_get_content<T>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de> + Send,
{
use serde_json::from_str;
self.room_state_get(room_id, event_type, state_key)
.await
.and_then(|event| from_str::<T>(event.content.get()).map_err(Into::into))
}
pub fn get_avatar(&self, room_id: &RoomId) -> Result<ruma::JsOption<RoomAvatarEventContent>> {
self.room_state_get(room_id, &StateEventType::RoomAvatar, "")?
.map_or(Ok(ruma::JsOption::Undefined), |s| {
pub async fn get_name(&self, room_id: &RoomId) -> Result<String> {
self.room_state_get_content(room_id, &StateEventType::RoomName, "")
.await
.map(|c: RoomNameEventContent| c.name)
}
pub async fn get_avatar(&self, room_id: &RoomId) -> ruma::JsOption<RoomAvatarEventContent> {
self.room_state_get(room_id, &StateEventType::RoomAvatar, "")
.await
.map_or(ruma::JsOption::Undefined, |s| {
serde_json::from_str(s.content.get())
.map_err(|_| Error::bad_database("Invalid room avatar event in database."))
.unwrap()
})
}
pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<RoomMemberEventContent>> {
self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map_err(|_| Error::bad_database("Invalid room member event in database."))
})
pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<RoomMemberEventContent> {
self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
}
pub fn user_can_invite(
pub async fn user_can_invite(
&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard,
) -> Result<bool> {
) -> bool {
let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite))
.expect("Event content always serializes");
@ -353,122 +386,101 @@ impl Service {
timestamp: None,
};
Ok(self
.services
self.services
.timeline
.create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.is_ok())
.await
.is_ok()
}
/// Checks if guests are able to view room content without joining
pub fn is_world_readable(&self, room_id: &RoomId) -> Result<bool, Error> {
self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable
})
.map_err(|e| {
error!(
"Invalid room history visibility event in database for room {room_id}, assuming not world \
readable: {e} "
);
Error::bad_database("Invalid room history visibility event in database.")
})
})
pub async fn is_world_readable(&self, room_id: &RoomId) -> bool {
self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "")
.await
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable)
.unwrap_or(false)
}
/// Checks if guests are able to join a given room
pub fn guest_can_join(&self, room_id: &RoomId) -> Result<bool, Error> {
self.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
.map_err(|_| Error::bad_database("Invalid room guest access event in database."))
})
pub async fn guest_can_join(&self, room_id: &RoomId) -> bool {
self.room_state_get_content(room_id, &StateEventType::RoomGuestAccess, "")
.await
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
.unwrap_or(false)
}
/// Gets the primary alias from canonical alias event
pub fn get_canonical_alias(&self, room_id: &RoomId) -> Result<Option<OwnedRoomAliasId>, Error> {
self.room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| Error::bad_database("Invalid canonical alias event in database."))
pub async fn get_canonical_alias(&self, room_id: &RoomId) -> Result<OwnedRoomAliasId> {
self.room_state_get_content(room_id, &StateEventType::RoomCanonicalAlias, "")
.await
.and_then(|c: RoomCanonicalAliasEventContent| {
c.alias
.ok_or_else(|| err!(Request(NotFound("No alias found in event content."))))
})
}
/// Gets the room topic
pub fn get_room_topic(&self, room_id: &RoomId) -> Result<Option<String>, Error> {
self.room_state_get(room_id, &StateEventType::RoomTopic, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|e| {
error!("Invalid room topic event in database for room {room_id}: {e}");
Error::bad_database("Invalid room topic event in database.")
})
})
pub async fn get_room_topic(&self, room_id: &RoomId) -> Result<String> {
self.room_state_get_content(room_id, &StateEventType::RoomTopic, "")
.await
.map(|c: RoomTopicEventContent| c.topic)
}
/// Checks if a given user can redact a given event
///
/// If federation is true, it allows redaction events from any user of the
/// same server as the original event sender
pub fn user_can_redact(
pub async fn user_can_redact(
&self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool,
) -> Result<bool> {
self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map_or_else(
|| {
// Falling back on m.room.create to judge power level
if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? {
Ok(pdu.sender == sender
|| if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
pdu.sender == sender
} else {
false
})
if let Ok(event) = self
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")
.await
{
let Ok(event) = serde_json::from_str(event.content.get())
.map(|content: RoomPowerLevelsEventContent| content.into())
.map(|event: RoomPowerLevels| event)
else {
return Ok(false);
};
Ok(event.user_can_redact_event_of_other(sender)
|| event.user_can_redact_own_event(sender)
&& if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await {
if federation {
pdu.sender.server_name() == sender.server_name()
} else {
pdu.sender == sender
}
} else {
Err(Error::bad_database(
"No m.room.power_levels or m.room.create events in database for room",
))
}
},
|event| {
serde_json::from_str(event.content.get())
.map(|content: RoomPowerLevelsEventContent| content.into())
.map(|event: RoomPowerLevels| {
event.user_can_redact_event_of_other(sender)
|| event.user_can_redact_own_event(sender)
&& if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
if federation {
pdu.sender.server_name() == sender.server_name()
} else {
pdu.sender == sender
}
} else {
false
}
})
.map_err(|_| Error::bad_database("Invalid m.room.power_levels event in database"))
},
)
false
})
} else {
// Falling back on m.room.create to judge power level
if let Ok(pdu) = self
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await
{
Ok(pdu.sender == sender
|| if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await {
pdu.sender == sender
} else {
false
})
} else {
Err(Error::bad_database(
"No m.room.power_levels or m.room.create events in database for room",
))
}
}
}
/// Returns the join rule (`SpaceRoomJoinRule`) for a given room
pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>), Error> {
Ok(self
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| {
(c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))
})
.map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}"))))
})
.transpose()?
.unwrap_or((SpaceRoomJoinRule::Invite, vec![])))
pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
.await
.map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)))
.or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![])))
}
/// Returns an empty vec if not a restricted room
@ -487,25 +499,21 @@ impl Service {
room_ids
}
pub fn get_room_type(&self, room_id: &RoomId) -> Result<Option<RoomType>> {
Ok(self
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.map(|s| {
serde_json::from_str::<RoomCreateEventContent>(s.content.get())
.map_err(|e| err!(Database(error!("Invalid room create event in database: {e}"))))
pub async fn get_room_type(&self, room_id: &RoomId) -> Result<RoomType> {
self.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.and_then(|content: RoomCreateEventContent| {
content
.room_type
.ok_or_else(|| err!(Request(NotFound("No type found in event content"))))
})
.transpose()?
.and_then(|e| e.room_type))
}
/// Gets the room's encryption algorithm if `m.room.encryption` state event
/// is found
pub fn get_room_encryption(&self, room_id: &RoomId) -> Result<Option<EventEncryptionAlgorithm>> {
self.room_state_get(room_id, &StateEventType::RoomEncryption, "")?
.map_or(Ok(None), |s| {
serde_json::from_str::<RoomEncryptionEventContent>(s.content.get())
.map(|content| Some(content.algorithm))
.map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}"))))
})
pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result<EventEncryptionAlgorithm> {
self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "")
.await
.map(|content: RoomEncryptionEventContent| content.algorithm)
}
}

View file

@ -1,43 +1,42 @@
use std::{
collections::{HashMap, HashSet},
collections::HashMap,
sync::{Arc, RwLock},
};
use conduit::{utils, Error, Result};
use database::Map;
use itertools::Itertools;
use conduit::{utils, utils::stream::TryIgnore, Error, Result};
use database::{Deserialized, Interfix, Map};
use futures::{Stream, StreamExt};
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
OwnedRoomId, RoomId, UserId,
};
use crate::{appservice::RegistrationInfo, globals, users, Dep};
use crate::{globals, Dep};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>);
type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>);
pub(super) struct Data {
pub(super) appservice_in_room_cache: AppServiceInRoomCache,
roomid_invitedcount: Arc<Map>,
roomid_inviteviaservers: Arc<Map>,
roomid_joinedcount: Arc<Map>,
roomserverids: Arc<Map>,
roomuserid_invitecount: Arc<Map>,
roomuserid_joined: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
serverroomids: Arc<Map>,
userroomid_invitestate: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_leftstate: Arc<Map>,
pub(super) roomid_invitedcount: Arc<Map>,
pub(super) roomid_inviteviaservers: Arc<Map>,
pub(super) roomid_joinedcount: Arc<Map>,
pub(super) roomserverids: Arc<Map>,
pub(super) roomuserid_invitecount: Arc<Map>,
pub(super) roomuserid_joined: Arc<Map>,
pub(super) roomuserid_leftcount: Arc<Map>,
pub(super) roomuseroncejoinedids: Arc<Map>,
pub(super) serverroomids: Arc<Map>,
pub(super) userroomid_invitestate: Arc<Map>,
pub(super) userroomid_joined: Arc<Map>,
pub(super) userroomid_leftstate: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
impl Data {
@ -59,19 +58,18 @@ impl Data {
userroomid_leftstate: db["userroomid_leftstate"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
}
}
pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
self.roomuseroncejoinedids.insert(&userroom_id, &[]);
}
pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
let roomid = room_id.as_bytes().to_vec();
let mut roomuser_id = roomid.clone();
@ -82,64 +80,17 @@ impl Data {
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
self.userroomid_joined.insert(&userroom_id, &[]);
self.roomuserid_joined.insert(&roomuser_id, &[]);
self.userroomid_invitestate.remove(&userroom_id);
self.roomuserid_invitecount.remove(&roomuser_id);
self.userroomid_leftstate.remove(&userroom_id);
self.roomuserid_leftcount.remove(&roomuser_id);
self.roomid_inviteviaservers.remove(&roomid)?;
Ok(())
self.roomid_inviteviaservers.remove(&roomid);
}
pub(super) fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
}
Ok(())
}
pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
let roomid = room_id.as_bytes().to_vec();
let mut roomuser_id = roomid.clone();
@ -153,115 +104,20 @@ impl Data {
self.userroomid_leftstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
); // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
.insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes());
self.userroomid_joined.remove(&userroom_id);
self.roomuserid_joined.remove(&roomuser_id);
self.userroomid_invitestate.remove(&userroom_id);
self.roomuserid_invitecount.remove(&roomuser_id);
self.roomid_inviteviaservers.remove(&roomid)?;
Ok(())
}
pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
for joined in self.room_members(room_id).filter_map(Result::ok) {
joined_servers.insert(joined.server_name().to_owned());
joinedcount = joinedcount.saturating_add(1);
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
invitedcount = invitedcount.saturating_add(1);
}
self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
Ok(b)
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
self.services.globals.server_name(),
)
.ok();
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self
.room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
Ok(in_room)
}
self.roomid_inviteviaservers.remove(&roomid);
}
/// Makes a user forget a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
@ -270,397 +126,69 @@ impl Data {
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(())
}
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_servers<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
}))
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn server_rooms<'a>(
&'a self, server: &ServerName,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
}))
}
/// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_members<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + Send + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
}))
}
/// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests
pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new(
self.room_members(room_id)
.filter_map(Result::ok)
.filter(|user| self.services.globals.user_is_local(user)),
)
}
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn active_local_users_in_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new(
self.local_users_in_room(room_id)
.filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)),
)
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_useroncejoined<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_members_invited<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
))
})
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
.transpose()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> {
Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
}),
)
self.userroomid_leftstate.remove(&userroom_id);
self.roomuserid_leftcount.remove(&roomuser_id);
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn invite_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
#[inline]
pub(super) fn rooms_invited<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
let prefix = (user_id, Interfix);
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
.stream_raw_prefix(&prefix)
.ignore_err()
.map(|(key, val)| {
let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap();
let room_id = utils::string_from_bytes(room_id).unwrap();
let room_id = RoomId::parse(room_id).unwrap();
let state = serde_json::from_slice(val)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))
.unwrap();
Ok(state)
(room_id, state)
})
.transpose()
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn left_state(
pub(super) async fn invite_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.userroomid_invitestate
.qry(&key)
.await
.deserialized_json()
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) async fn left_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
.qry(&key)
.await
.deserialized_json()
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
#[inline]
pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
let prefix = (user_id, Interfix);
self.userroomid_leftstate
.stream_raw_prefix(&prefix)
.ignore_err()
.map(|(key, val)| {
let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap();
let room_id = utils::string_from_bytes(room_id).unwrap();
let room_id = RoomId::parse(room_id).unwrap();
let state = serde_json::from_slice(val)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))
.unwrap();
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
	// Composite key: user_id 0xFF room_id
	let mut key = user_id.as_bytes().to_vec();
	key.push(0xFF);
	key.extend_from_slice(room_id.as_bytes());

	// A left-state entry for the key means the user has left (or been
	// removed from) this room.
	self.userroomid_leftstate
		.get(&key)
		.map(|value| value.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn servers_invite_via<'a>(
	&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
	let prefix = room_id.as_bytes().to_vec();

	let iter = self
		.roomid_inviteviaservers
		.scan_prefix(prefix)
		.map(|(_, value)| {
			// NOTE(review): the stored value is a 0xFF-joined list, but only
			// the last segment is yielded per entry — matches the refactored
			// implementation elsewhere; confirm this is intentional.
			let raw = value
				.rsplit(|&byte| byte == 0xFF)
				.next()
				.expect("rsplit always returns an element");

			let name = utils::string_from_bytes(raw).map_err(|_| {
				Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.")
			})?;

			ServerName::parse(name)
				.map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid."))
		});

	Box::new(iter)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
Ok(())
(room_id, state)
})
}
}

View file

@ -1,9 +1,15 @@
mod data;
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use conduit::{err, error, warn, Error, Result};
use conduit::{
err,
utils::{stream::TryIgnore, ReadyExt},
warn, Result,
};
use data::Data;
use database::{Deserialized, Ignore, Interfix};
use futures::{Stream, StreamExt};
use itertools::Itertools;
use ruma::{
events::{
@ -18,7 +24,7 @@ use ruma::{
},
int,
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId,
};
use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep};
@ -55,7 +61,7 @@ impl Service {
/// Update current membership data.
#[tracing::instrument(skip(self, last_state))]
#[allow(clippy::too_many_arguments)]
pub fn update_membership(
pub async fn update_membership(
&self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool,
@ -68,7 +74,7 @@ impl Service {
// update
#[allow(clippy::collapsible_if)]
if !self.services.globals.user_is_local(user_id) {
if !self.services.users.exists(user_id)? {
if !self.services.users.exists(user_id).await {
self.services.users.create(user_id, None)?;
}
@ -100,17 +106,17 @@ impl Service {
match &membership {
MembershipState::Join => {
// Check if the user never joined this room
if !self.once_joined(user_id, room_id)? {
if !self.once_joined(user_id, room_id).await {
// Add the user ID to the join list then
self.db.mark_as_once_joined(user_id, room_id)?;
self.db.mark_as_once_joined(user_id, room_id);
// Check if the room has a predecessor
if let Some(predecessor) = self
if let Ok(Some(predecessor)) = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.and_then(|create| serde_json::from_str(create.content.get()).ok())
.and_then(|content: RoomCreateEventContent| content.predecessor)
.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.map(|content: RoomCreateEventContent| content.predecessor)
{
// Copy user settings from predecessor to the current room:
// - Push rules
@ -138,32 +144,33 @@ impl Service {
// .ok();
// Copy old tags to new room
if let Some(tag_event) = self
if let Ok(tag_event) = self
.services
.account_data
.get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)?
.map(|event| {
.get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)
.await
.and_then(|event| {
serde_json::from_str(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
}) {
self.services
.account_data
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?)
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event)
.await
.ok();
};
// Copy direct chat flag
if let Some(direct_event) = self
if let Ok(mut direct_event) = self
.services
.account_data
.get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())?
.map(|event| {
.get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())
.await
.and_then(|event| {
serde_json::from_str::<DirectEvent>(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
}) {
let mut direct_event = direct_event?;
let mut room_ids_updated = false;
for room_ids in direct_event.content.0.values_mut() {
if room_ids.iter().any(|r| r == &predecessor.room_id) {
room_ids.push(room_id.to_owned());
@ -172,18 +179,21 @@ impl Service {
}
if room_ids_updated {
self.services.account_data.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event).expect("to json always works"),
)?;
self.services
.account_data
.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event).expect("to json always works"),
)
.await?;
}
};
}
}
self.db.mark_as_joined(user_id, room_id)?;
self.db.mark_as_joined(user_id, room_id);
},
MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
@ -196,12 +206,12 @@ impl Service {
GlobalAccountDataEventType::IgnoredUserList
.to_string()
.into(),
)?
.map(|event| {
)
.await
.and_then(|event| {
serde_json::from_str::<IgnoredUserListEvent>(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
})
.transpose()?
.map_or(false, |ignored| {
ignored
.content
@ -214,194 +224,282 @@ impl Service {
return Ok(());
}
self.db
.mark_as_invited(user_id, room_id, last_state, invite_via)?;
self.mark_as_invited(user_id, room_id, last_state, invite_via)
.await;
},
MembershipState::Leave | MembershipState::Ban => {
self.db.mark_as_left(user_id, room_id)?;
self.db.mark_as_left(user_id, room_id);
},
_ => {},
}
if update_joined_count {
self.update_joined_count(room_id)?;
self.update_joined_count(room_id).await;
}
Ok(())
}
#[tracing::instrument(skip(self, room_id), level = "debug")]
pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) }
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
self.db.appservice_in_room(room_id, appservice)
pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool {
let maybe = self
.db
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
b
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
self.services.globals.server_name(),
)
.ok();
let in_room = if let Some(id) = &bridge_user_id {
self.is_joined(id, room_id).await
} else {
false
};
let in_room = in_room
|| self
.room_members(room_id)
.ready_any(|userid| appservice.users.is_match(userid.as_str()))
.await;
self.db
.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
in_room
}
}
/// Direct DB function to directly mark a user as left. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.mark_as_left(user_id, room_id)
}
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_left(user_id, room_id); }
/// Direct DB function to directly mark a user as joined. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.mark_as_joined(user_id, room_id)
}
pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_joined(user_id, room_id); }
/// Makes a user forget a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) }
pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { self.db.forget(room_id, user_id); }
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.room_servers(room_id)
pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomserverids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, server): (Ignore, &ServerName)| server)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
self.db.server_in_room(server, room_id)
pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool {
let key = (server, room_id);
self.db.serverroomids.qry(&key).await.is_ok()
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
self.db.server_rooms(server)
pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream<Item = &RoomId> + Send + 'a {
let prefix = (server, Interfix);
self.db
.serverroomids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, room_id): (Ignore, &RoomId)| room_id)
}
/// Returns true if server can see user by sharing at least one room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result<bool> {
Ok(self
.server_rooms(server)
.filter_map(Result::ok)
.any(|room_id: OwnedRoomId| self.is_joined(user_id, &room_id).unwrap_or(false)))
pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool {
self.server_rooms(server)
.any(|room_id| self.is_joined(user_id, room_id))
.await
}
/// Returns true if user_a and user_b share at least one room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result<bool> {
pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool {
// Minimize number of point-queries by iterating user with least nr rooms
let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() {
let (a, b) = if self.rooms_joined(user_a).count().await < self.rooms_joined(user_b).count().await {
(user_a, user_b)
} else {
(user_b, user_a)
};
Ok(self
.rooms_joined(a)
.filter_map(Result::ok)
.any(|room_id| self.is_joined(b, &room_id).unwrap_or(false)))
self.rooms_joined(a)
.any(|room_id| self.is_joined(b, room_id))
.await
}
/// Returns an iterator over all joined members of a room.
/// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + Send + '_ {
self.db.room_members(room_id)
pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_joined
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_joined_count(room_id) }
pub async fn room_joined_count(&self, room_id: &RoomId) -> Result<u64> {
self.db.roomid_joinedcount.qry(room_id).await.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests
pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a {
self.db.local_users_in_room(room_id)
pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
self.room_members(room_id)
.ready_filter(|user| self.services.globals.user_is_local(user))
}
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a {
self.db.active_local_users_in_room(room_id)
pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
self.local_users_in_room(room_id)
.filter(|user| self.services.users.is_active(user))
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_invited_count(room_id) }
pub async fn room_invited_count(&self, room_id: &RoomId) -> Result<u64> {
self.db
.roomid_invitedcount
.qry(room_id)
.await
.deserialized()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ {
self.db.room_useroncejoined(room_id)
pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuseroncejoinedids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ {
self.db.room_members_invited(room_id)
pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_invitecount
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
self.db.get_invite_count(room_id, user_id)
pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db
.roomuserid_invitecount
.qry(&key)
.await
.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
self.db.get_left_count(room_id, user_id)
pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db.roomuserid_leftcount.qry(&key).await.deserialized()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
self.db.rooms_joined(user_id)
pub fn rooms_joined(&self, user_id: &UserId) -> impl Stream<Item = &RoomId> + Send {
self.db
.userroomid_joined
.keys_prefix(user_id)
.ignore_err()
.map(|(_, room_id): (Ignore, &RoomId)| room_id)
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_invited(
&self, user_id: &UserId,
) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + '_ {
pub fn rooms_invited<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)> + Send + 'a {
self.db.rooms_invited(user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
self.db.invite_state(user_id, room_id)
pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
self.db.invite_state(user_id, room_id).await
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
self.db.left_state(user_id, room_id)
pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
self.db.left_state(user_id, room_id).await
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_left(
&self, user_id: &UserId,
) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + '_ {
pub fn rooms_left<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)> + Send + 'a {
self.db.rooms_left(user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.once_joined(user_id, room_id)
pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.roomuseroncejoinedids.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_joined(user_id, room_id) }
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.is_invited(user_id, room_id)
pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_joined.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) }
pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_invitestate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.servers_invite_via(room_id)
pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_leftstate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
self.db
.roomid_inviteviaservers
.stream_prefix(room_id)
.ignore_err()
.map(|(_, servers): (Ignore, Vec<&ServerName>)| &**(servers.last().expect("at least one servername")))
}
/// Gets up to three servers that are likely to be in the room in the
@ -409,37 +507,27 @@ impl Service {
///
/// See <https://spec.matrix.org/v1.10/appendices/#routing>
#[tracing::instrument(skip(self))]
pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
pub async fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|pdu| {
serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| {
conent
.users
.iter()
.max_by_key(|(_, power)| *power)
.and_then(|x| {
if x.1 >= &int!(50) {
Some(x)
} else {
None
}
})
.map(|(user, _power)| user.server_name().to_owned())
})
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.map(|content: RoomPowerLevelsEventContent| {
content
.users
.iter()
.max_by_key(|(_, power)| *power)
.and_then(|x| (x.1 >= &int!(50)).then_some(x))
.map(|(user, _power)| user.server_name().to_owned())
})
.transpose()
.map_err(|e| {
error!("Invalid power levels event content in database: {e}");
Error::bad_database("Invalid power levels event content in database")
})?
.flatten();
.map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?;
let mut servers: Vec<OwnedServerName> = self
.room_members(room_id)
.filter_map(Result::ok)
.collect::<Vec<_>>()
.await
.iter()
.counts_by(|user| user.server_name().to_owned())
.iter()
.sorted_by_key(|(_, users)| *users)
@ -468,4 +556,139 @@ impl Service {
.expect("locked")
.clear();
}
pub async fn update_joined_count(&self, room_id: &RoomId) {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
self.room_members(room_id)
.ready_for_each(|joined| {
joined_servers.insert(joined.server_name().to_owned());
joinedcount = joinedcount.saturating_add(1);
})
.await;
invitedcount = invitedcount.saturating_add(
self.room_members_invited(room_id)
.count()
.await
.try_into()
.unwrap_or(0),
);
self.db
.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes());
self.db
.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes());
self.room_servers(room_id)
.ready_for_each(|old_joined_server| {
if !joined_servers.remove(old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.db.roomserverids.remove(&roomserver_id);
self.db.serverroomids.remove(&serverroom_id);
}
})
.await;
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.db.roomserverids.insert(&roomserver_id, &[]);
self.db.serverroomids.insert(&serverroom_id, &[]);
}
self.db
.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
}
pub async fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.db.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
);
self.db
.roomuserid_invitecount
.insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
if let Some(servers) = invite_via {
let mut prev_servers = self
.servers_invite_via(room_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.db
.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers);
}
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) {
let mut prev_servers = self
.servers_invite_via(room_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.db
.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers);
}
}

View file

@ -1,6 +1,6 @@
use std::{collections::HashSet, mem::size_of, sync::Arc};
use conduit::{checked, utils, Error, Result};
use conduit::{err, expected, utils, Result};
use database::{Database, Map};
use super::CompressedStateEvent;
@ -22,11 +22,13 @@ impl Data {
}
}
pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
.qry(&shortstatehash)
.await
.map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
@ -40,10 +42,10 @@ impl Data {
let stride = size_of::<u64>();
let mut i = stride;
while let Some(v) = value.get(i..checked!(i + 2 * stride)?) {
while let Some(v) = value.get(i..expected!(i + 2 * stride)) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i = checked!(i + stride)?;
i = expected!(i + stride);
continue;
}
if add_mode {
@ -51,7 +53,7 @@ impl Data {
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i = checked!(i + 2 * stride)?;
i = expected!(i + 2 * stride);
}
Ok(StateDiff {
@ -61,7 +63,7 @@ impl Data {
})
}
pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) -> Result<()> {
pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
@ -75,6 +77,6 @@ impl Data {
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
.insert(&shortstatehash.to_be_bytes(), &value);
}
}

View file

@ -27,14 +27,12 @@ type StateInfoLruCache = Mutex<
>,
>;
type ShortStateInfoResult = Result<
Vec<(
u64, // sstatehash
Arc<HashSet<CompressedStateEvent>>, // full state
Arc<HashSet<CompressedStateEvent>>, // added
Arc<HashSet<CompressedStateEvent>>, // removed
)>,
>;
type ShortStateInfoResult = Vec<(
u64, // sstatehash
Arc<HashSet<CompressedStateEvent>>, // full state
Arc<HashSet<CompressedStateEvent>>, // added
Arc<HashSet<CompressedStateEvent>>, // removed
)>;
type ParentStatesVec = Vec<(
u64, // sstatehash
@ -43,7 +41,7 @@ type ParentStatesVec = Vec<(
Arc<HashSet<CompressedStateEvent>>, // removed
)>;
type HashSetCompressStateEvent = Result<(u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>)>;
type HashSetCompressStateEvent = (u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>);
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Service {
@ -86,12 +84,11 @@ impl crate::Service for Service {
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(skip(self), level = "debug")]
pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult {
pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result<ShortStateInfoResult> {
if let Some(r) = self
.stateinfo_cache
.lock()
.unwrap()
.expect("locked")
.get_mut(&shortstatehash)
{
return Ok(r.clone());
@ -101,11 +98,11 @@ impl Service {
parent,
added,
removed,
} = self.db.get_statediff(shortstatehash)?;
} = self.db.get_statediff(shortstatehash).await?;
if let Some(parent) = parent {
let mut response = self.load_shortstatehash_info(parent)?;
let mut state = (*response.last().unwrap().1).clone();
let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?;
let mut state = (*response.last().expect("at least one response").1).clone();
state.extend(added.iter().copied());
let removed = (*removed).clone();
for r in &removed {
@ -116,7 +113,7 @@ impl Service {
self.stateinfo_cache
.lock()
.unwrap()
.expect("locked")
.insert(shortstatehash, response.clone());
Ok(response)
@ -124,33 +121,42 @@ impl Service {
let response = vec![(shortstatehash, added.clone(), added, removed)];
self.stateinfo_cache
.lock()
.unwrap()
.expect("locked")
.insert(shortstatehash, response.clone());
Ok(response)
}
}
pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> {
pub async fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> CompressedStateEvent {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&self
.services
.short
.get_or_create_shorteventid(event_id)?
.get_or_create_shorteventid(event_id)
.await
.to_be_bytes(),
);
Ok(v.try_into().expect("we checked the size above"))
v.try_into().expect("we checked the size above")
}
/// Returns shortstatekey, event id
#[inline]
pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> {
Ok((
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"),
self.services.short.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"),
)?,
))
pub async fn parse_compressed_state_event(
&self, compressed_event: &CompressedStateEvent,
) -> Result<(u64, Arc<EventId>)> {
use utils::u64_from_u8;
let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<u64>()]);
let event_id = self
.services
.short
.get_eventid_from_short(u64_from_u8(&compressed_event[size_of::<u64>()..]))
.await?;
Ok((shortstatekey, event_id))
}
/// Creates a new shortstatehash that often is just a diff to an already
@ -227,7 +233,7 @@ impl Service {
added: statediffnew,
removed: statediffremoved,
},
)?;
);
return Ok(());
};
@ -280,7 +286,7 @@ impl Service {
added: statediffnew,
removed: statediffremoved,
},
)?;
);
}
Ok(())
@ -288,10 +294,15 @@ impl Service {
/// Returns the new shortstatehash, and the state diff from the previous
/// room state
pub fn save_state(
pub async fn save_state(
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> HashSetCompressStateEvent {
let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?;
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
.state
.get_room_shortstatehash(room_id)
.await
.ok();
let state_hash = utils::calculate_hash(
&new_state_ids_compressed
@ -303,14 +314,18 @@ impl Service {
let (new_shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)?;
.get_or_create_shortstatehash(&state_hash)
.await;
if Some(new_shortstatehash) == previous_shortstatehash {
return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new())));
}
let states_parents =
previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let states_parents = if let Some(p) = previous_shortstatehash {
self.load_shortstatehash_info(p).await.unwrap_or_default()
} else {
ShortStateInfoResult::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed

View file

@ -1,13 +1,18 @@
use std::{mem::size_of, sync::Arc};
use conduit::{checked, utils, Error, PduEvent, Result};
use database::Map;
use conduit::{
checked,
result::LogErr,
utils,
utils::{stream::TryIgnore, ReadyExt},
PduEvent, Result,
};
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{rooms, Dep};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
pub(super) struct Data {
threadid_userids: Arc<Map>,
services: Services,
@ -30,38 +35,37 @@ impl Data {
}
}
pub(super) fn threads_until<'a>(
pub(super) async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
) -> Result<impl Stream<Item = (u64, PduEvent)> + Send + 'a> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists")
.get_shortroomid(room_id)
.await?
.to_be_bytes()
.to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes());
Ok(Box::new(
self.threadid_userids
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((count, pdu))
}),
))
let stream = self
.threadid_userids
.rev_raw_keys_from(&current)
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|pduid| (utils::u64_from_u8(&pduid[(size_of::<u64>())..]), pduid))
.filter_map(move |(count, pduid)| async move {
let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?;
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
Some((count, pdu))
});
Ok(stream)
}
pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
@ -71,28 +75,12 @@ impl Data {
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
self.threadid_userids.insert(root_id, &users);
Ok(())
}
pub(super) fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
})
.filter_map(Result::ok)
.collect(),
))
} else {
Ok(None)
}
pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result<Vec<OwnedUserId>> {
self.threadid_userids.qry(root_id).await.deserialized()
}
}

View file

@ -2,12 +2,12 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::{Error, PduEvent, Result};
use conduit::{err, PduEvent, Result};
use data::Data;
use futures::Stream;
use ruma::{
api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads},
events::relation::BundledThread,
uint, CanonicalJsonValue, EventId, RoomId, UserId,
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue,
EventId, RoomId, UserId,
};
use serde_json::json;
@ -36,30 +36,35 @@ impl crate::Service for Service {
}
impl Service {
pub fn threads_until<'a>(
pub async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads,
) -> Result<impl Iterator<Item = Result<(u64, PduEvent)>> + 'a> {
self.db.threads_until(user_id, room_id, until, include)
) -> Result<impl Stream<Item = (u64, PduEvent)> + Send + 'a> {
self.db
.threads_until(user_id, room_id, until, include)
.await
}
pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
let root_id = self
.services
.timeline
.get_pdu_id(root_event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?;
.get_pdu_id(root_event_id)
.await
.map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?;
let root_pdu = self
.services
.timeline
.get_pdu_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
.get_pdu_from_id(&root_id)
.await
.map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?;
let mut root_pdu_json = self
.services
.timeline
.get_pdu_json_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
.get_pdu_json_from_id(&root_id)
.await
.map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?;
if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
.entry("unsigned".to_owned())
@ -103,11 +108,12 @@ impl Service {
self.services
.timeline
.replace_pdu(&root_id, &root_pdu_json, &root_pdu)?;
.replace_pdu(&root_id, &root_pdu_json, &root_pdu)
.await?;
}
let mut users = Vec::new();
if let Some(userids) = self.db.get_participants(&root_id)? {
if let Ok(userids) = self.db.get_participants(&root_id).await {
users.extend_from_slice(&userids);
} else {
users.push(root_pdu.sender);

View file

@ -1,12 +1,20 @@
use std::{
collections::{hash_map, HashMap},
mem::size_of,
sync::{Arc, Mutex},
sync::Arc,
};
use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result};
use database::{Database, Map};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use conduit::{
err, expected,
result::{LogErr, NotFound},
utils,
utils::{stream::TryIgnore, u64_from_u8, ReadyExt},
Err, PduCount, PduEvent, Result,
};
use database::{Database, Deserialized, KeyVal, Map};
use futures::{FutureExt, Stream, StreamExt};
use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use crate::{rooms, Dep};
@ -25,8 +33,7 @@ struct Services {
short: Dep<rooms::short::Service>,
}
type PdusIterItem = Result<(PduCount, PduEvent)>;
type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
pub type PdusIterItem = (PduCount, PduEvent);
type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>;
impl Data {
@ -46,23 +53,20 @@ impl Data {
}
}
pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
pub(super) async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.expect("locked")
.await
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
}) {
.pdus_until(sender_user, room_id, PduCount::max())
.await?
.next()
.await
{
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
@ -73,232 +77,215 @@ impl Data {
}
/// Returns the `count` of this pdu's id.
pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
self.eventid_pduid
.get(event_id.as_bytes())?
.qry(event_id)
.await
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu.
pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)
pub(super) async fn get_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
if let Ok(pdu) = self.get_non_outlier_pdu_json(event_id).await {
return Ok(pdu);
}
self.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
}
/// Returns the json of a pdu.
pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.qry(&pduid).await.deserialized_json()
}
/// Returns the pdu's id.
#[inline]
pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<database::Handle<'_>>> {
self.eventid_pduid.get(event_id.as_bytes())
pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result<database::Handle<'_>> {
self.eventid_pduid.qry(event_id).await
}
/// Returns the pdu directly from `eventid_pduid` only.
pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.qry(&pduid).await.deserialized_json()
}
/// Like get_non_outlier_pdu(), but without the expense of fetching and
/// parsing the PduEvent
pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.qry(&pduid).await?;
Ok(())
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub(super) fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self
.get_non_outlier_pdu(event_id)?
.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)?
.map(Arc::new)
{
Ok(Some(pdu))
} else {
Ok(None)
pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result<Arc<PduEvent>> {
if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await {
return Ok(Arc::new(pdu));
}
self.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
.map(Arc::new)
}
/// Like get_non_outlier_pdu(), but without the expense of fetching and
/// parsing the PduEvent
pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> {
self.eventid_outlierpdu.qry(event_id).await?;
Ok(())
}
/// Like get_pdu(), but without the expense of fetching and parsing the data
pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool {
let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok());
let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok());
//TODO: parallelize
non_outlier.await || outlier.await
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<PduEvent> {
self.pduid_pdu.qry(pdu_id).await.deserialized_json()
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<CanonicalJsonObject> {
self.pduid_pdu.qry(pdu_id).await.deserialized_json()
}
pub(super) fn append_pdu(
&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64,
) -> Result<()> {
pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
);
self.lasttimelinecount_cache
.lock()
.expect("locked")
.await
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
}
pub(super) fn prepend_backfill_pdu(
&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject,
) -> Result<()> {
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
);
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(event_id.as_bytes());
}
/// Removes a pdu and creates a new one with the same id.
pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
pub(super) async fn replace_pdu(
&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent,
) -> Result<()> {
if self.pduid_pdu.qry(pdu_id).await.is_not_found() {
return Err!(Request(NotFound("PDU does not exist.")));
}
let pdu = serde_json::to_vec(pdu_json)?;
self.pduid_pdu.insert(pdu_id, &pdu);
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = self.count_to_id(room_id, until, 1, true)?;
pub(super) async fn pdus_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?;
let stream = self
.pduid_pdu
.rev_raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
Ok(stream)
}
pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = self.count_to_id(room_id, from, 1, false)?;
pub(super) async fn pdus_after<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?;
let stream = self
.pduid_pdu
.raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
let user_id = user_id.to_owned();
Ok(stream)
}
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem {
let mut pdu =
serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON");
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
pdu.add_age().log_err().ok();
let count = pdu_count(pdu_id);
(count, pdu)
}
pub(super) fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> {
let mut notifies_batch = Vec::new();
let mut highlights_batch = Vec::new();
) {
let _cork = self.db.cork();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
increment(&self.userroomid_notificationcount, &userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
increment(&self.userroomid_highlightcount, &userroom_id);
}
self.userroomid_notificationcount
.increment_batch(notifies_batch.iter().map(Vec::as_slice))?;
self.userroomid_highlightcount
.increment_batch(highlights_batch.iter().map(Vec::as_slice))?;
Ok(())
}
pub(super) fn count_to_id(
pub(super) async fn count_to_id(
&self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.get_shortroomid(room_id)
.await
.map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
@ -326,17 +313,23 @@ impl Data {
}
/// Returns the `count` of this pdu's id.
pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let stride = size_of::<u64>();
let pdu_id_len = pdu_id.len();
let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 =
utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]);
pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount {
const STRIDE: usize = size_of::<u64>();
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)))
let pdu_id_len = pdu_id.len();
let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]);
let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]);
if second_last_u64 == 0 {
PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))
} else {
Ok(PduCount::Normal(last_u64))
PduCount::Normal(last_u64)
}
}
//TODO: this is an ABA
fn increment(db: &Arc<Map>, key: &[u8]) {
let old = db.get(key);
let new = utils::increment(old.ok().as_deref());
db.insert(key, &new);
}

File diff suppressed because it is too large Load diff

View file

@ -46,7 +46,7 @@ impl Service {
/// Sets a user as typing until the timeout timestamp is reached or
/// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout);
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
// update clients
self.typing
.write()
@ -54,17 +54,19 @@ impl Service {
.entry(room_id.to_owned())
.or_default()
.insert(user_id.to_owned(), timeout);
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if self.services.globals.user_is_local(user_id) {
self.federation_send(room_id, user_id, true)?;
self.federation_send(room_id, user_id, true).await?;
}
Ok(())
@ -72,7 +74,7 @@ impl Service {
/// Removes a user from typing before the timeout is reached.
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
debug_info!("typing stopped {:?} in {:?}", user_id, room_id);
debug_info!("typing stopped {user_id:?} in {room_id:?}");
// update clients
self.typing
.write()
@ -80,31 +82,31 @@ impl Service {
.entry(room_id.to_owned())
.or_default()
.remove(user_id);
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if self.services.globals.user_is_local(user_id) {
self.federation_send(room_id, user_id, false)?;
self.federation_send(room_id, user_id, false).await?;
}
Ok(())
}
pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> {
pub async fn wait_for_update(&self, room_id: &RoomId) {
let mut receiver = self.typing_update_sender.subscribe();
while let Ok(next) = receiver.recv().await {
if next == room_id {
break;
}
}
Ok(())
}
/// Makes sure that typing events with old timestamps get removed.
@ -123,30 +125,30 @@ impl Service {
removable.push(user.clone());
}
}
drop(typing);
};
if !removable.is_empty() {
let typing = &mut self.typing.write().await;
let room = typing.entry(room_id.to_owned()).or_default();
for user in &removable {
debug_info!("typing timeout {:?} in {:?}", &user, room_id);
debug_info!("typing timeout {user:?} in {room_id:?}");
room.remove(user);
}
// update clients
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
for user in removable {
if self.services.globals.user_is_local(&user) {
self.federation_send(room_id, &user, false)?;
for user in &removable {
if self.services.globals.user_is_local(user) {
self.federation_send(room_id, user, false).await?;
}
}
}
@ -183,7 +185,7 @@ impl Service {
})
}
fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
debug_assert!(
self.services.globals.user_is_local(user_id),
"tried to broadcast typing status of remote user",
@ -197,7 +199,8 @@ impl Service {
self.services
.sending
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?;
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))
.await?;
Ok(())
}

View file

@ -1,8 +1,9 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use conduit::Result;
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{RoomId, UserId};
use crate::{globals, rooms, Dep};
@ -11,13 +12,13 @@ pub(super) struct Data {
userroomid_highlightcount: Arc<Map>,
roomuserid_lastnotificationread: Arc<Map>,
roomsynctoken_shortstatehash: Arc<Map>,
userroomid_joined: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
impl Data {
@ -28,15 +29,15 @@ impl Data {
userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit
roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(),
userroomid_joined: db["userroomid_joined"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
}
}
pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
@ -45,128 +46,73 @@ impl Data {
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
.insert(&userroom_id, &0_u64.to_be_bytes());
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
.insert(&userroom_id, &0_u64.to_be_bytes());
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(())
.insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes());
}
pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
pub(super) async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (user_id, room_id);
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
})
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
pub(super) async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (user_id, room_id);
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastnotificationread
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
pub(super) async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (room_id, user_id);
self.roomuserid_lastnotificationread
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
pub(super) fn associate_token_shortstatehash(
&self, room_id: &RoomId, token: u64, shortstatehash: u64,
) -> Result<()> {
pub(super) async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) {
let shortroomid = self
.services
.short
.get_shortroomid(room_id)?
.get_shortroomid(room_id)
.await
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
.insert(&key, &shortstatehash.to_be_bytes());
}
pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<u64> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let key: &[u64] = &[shortroomid, token];
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
})
.transpose()
.qry(key)
.await
.deserialized()
}
//TODO: optimize; replace point-queries with dual iteration
pub(super) fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
.saturating_add(1); // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not
// reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}),
))
&'a self, user_a: &'a UserId, user_b: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
self.services
.state_cache
.rooms_joined(user_a)
.filter(|room_id| self.services.state_cache.is_joined(user_b, room_id))
}
}

View file

@ -3,7 +3,8 @@ mod data;
use std::sync::Arc;
use conduit::Result;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use futures::{pin_mut, Stream, StreamExt};
use ruma::{RoomId, UserId};
use self::data::Data;
@ -22,32 +23,49 @@ impl crate::Service for Service {
}
impl Service {
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.reset_notification_counts(user_id, room_id)
#[inline]
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
self.db.reset_notification_counts(user_id, room_id);
}
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.notification_count(user_id, room_id)
#[inline]
pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.notification_count(user_id, room_id).await
}
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.highlight_count(user_id, room_id)
#[inline]
pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.highlight_count(user_id, room_id).await
}
pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.last_notification_read(user_id, room_id)
#[inline]
pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.last_notification_read(user_id, room_id).await
}
pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
#[inline]
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) {
self.db
.associate_token_shortstatehash(room_id, token, shortstatehash)
.await;
}
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
self.db.get_token_shortstatehash(room_id, token)
#[inline]
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<u64> {
self.db.get_token_shortstatehash(room_id, token).await
}
pub fn get_shared_rooms(&self, users: Vec<OwnedUserId>) -> Result<impl Iterator<Item = Result<OwnedRoomId>> + '_> {
self.db.get_shared_rooms(users)
#[inline]
pub fn get_shared_rooms<'a>(
&'a self, user_a: &'a UserId, user_b: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
self.db.get_shared_rooms(user_a, user_b)
}
pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool {
let get_shared_rooms = self.get_shared_rooms(user_a, user_b);
pin_mut!(get_shared_rooms);
get_shared_rooms.next().await.is_some()
}
}

View file

@ -1,14 +1,21 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use conduit::{
utils,
utils::{stream::TryIgnore, ReadyExt},
Error, Result,
};
use database::{Database, Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{ServerName, UserId};
use super::{Destination, SendingEvent};
use crate::{globals, Dep};
type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>;
type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>;
pub(super) type OutgoingItem = (Key, SendingEvent, Destination);
pub(super) type SendingItem = (Key, SendingEvent);
pub(super) type QueueItem = (Key, SendingEvent);
pub(super) type Key = Vec<u8>;
pub struct Data {
servercurrentevent_data: Arc<Map>,
@ -36,85 +43,34 @@ impl Data {
}
}
#[inline]
pub fn active_requests(&self) -> OutgoingSendingIter<'_> {
Box::new(
self.servercurrentevent_data
.iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
)
}
pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); }
#[inline]
pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> {
pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) {
let prefix = destination.get_prefix();
Box::new(
self.servercurrentevent_data
.scan_prefix(prefix)
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
self.servercurrentevent_data
.raw_keys_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.servercurrentevent_data.remove(key))
.await;
}
pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) }
pub(super) fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> {
pub(super) async fn delete_all_requests_for(&self, destination: &Destination) {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?;
}
self.servercurrentevent_data
.raw_keys_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.servercurrentevent_data.remove(key))
.await;
Ok(())
}
pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap();
}
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
self.servernameevent_data.remove(&key).unwrap();
}
Ok(())
}
pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> {
let mut batch = Vec::new();
let mut keys = Vec::new();
for (destination, event) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
} else {
&[]
};
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
self.servernameevent_data
.insert_batch(batch.iter().map(database::KeyVal::from))?;
Ok(keys)
.raw_keys_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.servernameevent_data.remove(key))
.await;
}
pub fn queued_requests<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> {
let prefix = destination.get_prefix();
return Box::new(
self.servernameevent_data
.scan_prefix(prefix)
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
}
pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> {
for (e, key) in events {
pub(super) fn mark_as_active(&self, events: &[QueueItem]) {
for (key, e) in events {
if key.is_empty() {
continue;
}
@ -124,29 +80,87 @@ impl Data {
} else {
&[]
};
self.servercurrentevent_data.insert(key, value)?;
self.servernameevent_data.remove(key)?;
self.servercurrentevent_data.insert(key, value);
self.servernameevent_data.remove(key);
}
}
#[inline]
pub fn active_requests(&self) -> impl Stream<Item = OutgoingItem> + Send + '_ {
self.servercurrentevent_data
.raw_stream()
.ignore_err()
.map(|(key, val)| {
let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
(key.to_vec(), event, dest)
})
}
#[inline]
pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> impl Stream<Item = SendingItem> + Send + 'a {
let prefix = destination.get_prefix();
self.servercurrentevent_data
.stream_raw_prefix(&prefix)
.ignore_err()
.map(|(key, val)| {
let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
(key.to_vec(), event)
})
}
pub(super) fn queue_requests(&self, requests: &[(&SendingEvent, &Destination)]) -> Vec<Vec<u8>> {
let mut batch = Vec::new();
let mut keys = Vec::new();
for (event, destination) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
} else {
&[]
};
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
Ok(())
self.servernameevent_data.insert_batch(batch.iter());
keys
}
pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
}
pub fn queued_requests<'a>(&'a self, destination: &Destination) -> impl Stream<Item = QueueItem> + Send + 'a {
let prefix = destination.get_prefix();
self.servernameevent_data
.stream_raw_prefix(&prefix)
.ignore_err()
.map(|(key, val)| {
let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent");
pub fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
(key.to_vec(), event)
})
}
pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes());
}
pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 {
self.servername_educount
.qry(server_name)
.await
.deserialized()
.unwrap_or(0)
}
}
#[tracing::instrument(skip(key), level = "debug")]
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> {
fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, SendingEvent)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
@ -164,7 +178,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination,
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
SendingEvent::Edu(value.to_vec())
},
)
} else if key.starts_with(b"$") {
@ -192,7 +206,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination,
SendingEvent::Pdu(event.to_vec())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value)
SendingEvent::Edu(value.to_vec())
},
)
} else {
@ -214,7 +228,7 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination,
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
SendingEvent::Edu(value.to_vec())
},
)
})

View file

@ -7,10 +7,11 @@ mod sender;
use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use conduit::{err, warn, Result, Server};
use conduit::{err, utils::ReadyExt, warn, Result, Server};
use futures::{future::ready, Stream, StreamExt, TryStreamExt};
use ruma::{
api::{appservice::Registration, OutgoingRequest},
OwnedServerName, RoomId, ServerName, UserId,
RoomId, ServerName, UserId,
};
use tokio::sync::Mutex;
@ -104,7 +105,7 @@ impl Service {
let dest = Destination::Push(user.to_owned(), pushkey);
let event = SendingEvent::Pdu(pdu_id.to_owned());
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(&[(&dest, event.clone())])?;
let keys = self.db.queue_requests(&[(&event, &dest)]);
self.dispatch(Msg {
dest,
event,
@ -117,7 +118,7 @@ impl Service {
let dest = Destination::Appservice(appservice_id);
let event = SendingEvent::Pdu(pdu_id);
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(&[(&dest, event.clone())])?;
let keys = self.db.queue_requests(&[(&event, &dest)]);
self.dispatch(Msg {
dest,
event,
@ -126,30 +127,31 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> {
pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> {
let servers = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.filter(|server_name| !self.services.globals.server_is_ours(server_name));
.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
self.send_pdu_servers(servers, pdu_id)
self.send_pdu_servers(servers, pdu_id).await
}
#[tracing::instrument(skip(self, servers, pdu_id), level = "debug")]
pub fn send_pdu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, pdu_id: &[u8]) -> Result<()> {
let requests = servers
.into_iter()
.map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned())))
.collect::<Vec<_>>();
pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()>
where
S: Stream<Item = &'a ServerName> + Send + 'a,
{
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(
&requests
.iter()
.map(|(o, e)| (o, e.clone()))
.collect::<Vec<_>>(),
)?;
let requests = servers
.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into())))
.collect::<Vec<_>>()
.await;
let keys = self
.db
.queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::<Vec<_>>());
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
self.dispatch(Msg {
dest,
@ -166,7 +168,7 @@ impl Service {
let dest = Destination::Normal(server.to_owned());
let event = SendingEvent::Edu(serialized);
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(&[(&dest, event.clone())])?;
let keys = self.db.queue_requests(&[(&event, &dest)]);
self.dispatch(Msg {
dest,
event,
@ -175,30 +177,30 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, serialized), level = "debug")]
pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> {
pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> {
let servers = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.filter(|server_name| !self.services.globals.server_is_ours(server_name));
.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
self.send_edu_servers(servers, serialized)
self.send_edu_servers(servers, serialized).await
}
#[tracing::instrument(skip(self, servers, serialized), level = "debug")]
pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, serialized: Vec<u8>) -> Result<()> {
let requests = servers
.into_iter()
.map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone())))
.collect::<Vec<_>>();
pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec<u8>) -> Result<()>
where
S: Stream<Item = &'a ServerName> + Send + 'a,
{
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(
&requests
.iter()
.map(|(o, e)| (o, e.clone()))
.collect::<Vec<_>>(),
)?;
let requests = servers
.map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone())))
.collect::<Vec<_>>()
.await;
let keys = self
.db
.queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::<Vec<_>>());
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
self.dispatch(Msg {
@ -212,29 +214,33 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id), level = "debug")]
pub fn flush_room(&self, room_id: &RoomId) -> Result<()> {
pub async fn flush_room(&self, room_id: &RoomId) -> Result<()> {
let servers = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.filter(|server_name| !self.services.globals.server_is_ours(server_name));
.ready_filter(|server_name| !self.services.globals.server_is_ours(server_name));
self.flush_servers(servers)
self.flush_servers(servers).await
}
#[tracing::instrument(skip(self, servers), level = "debug")]
pub fn flush_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I) -> Result<()> {
let requests = servers.into_iter().map(Destination::Normal);
for dest in requests {
self.dispatch(Msg {
dest,
event: SendingEvent::Flush,
queue_id: Vec::<u8>::new(),
})?;
}
Ok(())
pub async fn flush_servers<'a, S>(&self, servers: S) -> Result<()>
where
S: Stream<Item = &'a ServerName> + Send + 'a,
{
servers
.map(ToOwned::to_owned)
.map(Destination::Normal)
.map(Ok)
.try_for_each(|dest| {
ready(self.dispatch(Msg {
dest,
event: SendingEvent::Flush,
queue_id: Vec::<u8>::new(),
}))
})
.await
}
#[tracing::instrument(skip_all, name = "request")]
@ -263,11 +269,10 @@ impl Service {
/// Cleanup event data
/// Used for instance after we remove an appservice registration
#[tracing::instrument(skip(self), level = "debug")]
pub fn cleanup_events(&self, appservice_id: String) -> Result<()> {
pub async fn cleanup_events(&self, appservice_id: String) {
self.db
.delete_all_requests_for(&Destination::Appservice(appservice_id))?;
Ok(())
.delete_all_requests_for(&Destination::Appservice(appservice_id))
.await;
}
fn dispatch(&self, msg: Msg) -> Result<()> {

View file

@ -7,18 +7,15 @@ use std::{
use base64::{engine::general_purpose, Engine as _};
use conduit::{
debug, debug_warn, error, trace,
utils::{calculate_hash, math::continue_exponential_backoff_secs},
debug, debug_warn, err, trace,
utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt},
warn, Error, Result,
};
use federation::transactions::send_transaction_message;
use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt};
use ruma::{
api::federation::{
self,
transactions::edu::{
DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap,
},
api::federation::transactions::{
edu::{DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap},
send_transaction_message,
},
device_id,
events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType},
@ -28,7 +25,7 @@ use ruma::{
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::time::sleep_until;
use super::{appservice, Destination, Msg, SendingEvent, Service};
use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service};
#[derive(Debug)]
enum TransactionStatus {
@ -50,20 +47,20 @@ const CLEANUP_TIMEOUT_MS: u64 = 3500;
impl Service {
#[tracing::instrument(skip_all, name = "sender")]
pub(super) async fn sender(&self) -> Result<()> {
let receiver = self.receiver.lock().await;
let mut futures: SendingFutures<'_> = FuturesUnordered::new();
let mut statuses: CurTransactionStatus = CurTransactionStatus::new();
let mut futures: SendingFutures<'_> = FuturesUnordered::new();
let receiver = self.receiver.lock().await;
self.initial_requests(&futures, &mut statuses);
self.initial_requests(&mut futures, &mut statuses).await;
loop {
debug_assert!(!receiver.is_closed(), "channel error");
tokio::select! {
request = receiver.recv_async() => match request {
Ok(request) => self.handle_request(request, &futures, &mut statuses),
Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await,
Err(_) => break,
},
Some(response) = futures.next() => {
self.handle_response(response, &futures, &mut statuses);
self.handle_response(response, &mut futures, &mut statuses).await;
},
}
}
@ -72,18 +69,16 @@ impl Service {
Ok(())
}
fn handle_response<'a>(
&'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus,
async fn handle_response<'a>(
&'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) {
match response {
Ok(dest) => self.handle_response_ok(&dest, futures, statuses),
Err((dest, e)) => Self::handle_response_err(dest, futures, statuses, &e),
Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
};
}
fn handle_response_err(
dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error,
) {
fn handle_response_err(dest: Destination, statuses: &mut CurTransactionStatus, e: &Error) {
debug!(dest = ?dest, "{e:?}");
statuses.entry(dest).and_modify(|e| {
*e = match e {
@ -94,39 +89,40 @@ impl Service {
});
}
fn handle_response_ok<'a>(
&'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus,
#[allow(clippy::needless_pass_by_ref_mut)]
async fn handle_response_ok<'a>(
&'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) {
let _cork = self.db.db.cork();
self.db
.delete_all_active_requests_for(dest)
.expect("all active requests deleted");
self.db.delete_all_active_requests_for(dest).await;
// Find events that have been added since starting the last request
let new_events = self
.db
.queued_requests(dest)
.filter_map(Result::ok)
.take(DEQUEUE_LIMIT)
.collect::<Vec<_>>();
.collect::<Vec<_>>()
.await;
// Insert any pdus we found
if !new_events.is_empty() {
self.db
.mark_as_active(&new_events)
.expect("marked as active");
let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect();
futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec)));
self.db.mark_as_active(&new_events);
let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect();
futures.push(self.send_events(dest.clone(), new_events_vec).boxed());
} else {
statuses.remove(dest);
}
}
fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let iv = vec![(msg.event, msg.queue_id)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) {
#[allow(clippy::needless_pass_by_ref_mut)]
async fn handle_request<'a>(
&'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) {
let iv = vec![(msg.queue_id, msg.event)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
if !events.is_empty() {
futures.push(Box::pin(self.send_events(msg.dest, events)));
futures.push(self.send_events(msg.dest, events).boxed());
} else {
statuses.remove(&msg.dest);
}
@ -142,7 +138,7 @@ impl Service {
tokio::select! {
() = sleep_until(deadline.into()) => break,
response = futures.next() => match response {
Some(response) => self.handle_response(response, futures, statuses),
Some(response) => self.handle_response(response, futures, statuses).await,
None => return,
}
}
@ -151,16 +147,17 @@ impl Service {
debug_warn!("Leaving with {} unfinished requests...", futures.len());
}
fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
#[allow(clippy::needless_pass_by_ref_mut)]
async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new();
for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) {
let mut active = self.db.active_requests().boxed();
while let Some((key, event, dest)) = active.next().await {
let entry = txns.entry(dest.clone()).or_default();
if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep {
warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key));
self.db
.delete_active_request(&key)
.expect("active request deleted");
warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key));
self.db.delete_active_request(&key);
} else {
entry.push(event);
}
@ -169,16 +166,16 @@ impl Service {
for (dest, events) in txns {
if self.server.config.startup_netburst && !events.is_empty() {
statuses.insert(dest.clone(), TransactionStatus::Running);
futures.push(Box::pin(self.send_events(dest.clone(), events)));
futures.push(self.send_events(dest.clone(), events).boxed());
}
}
}
#[tracing::instrument(skip_all, level = "debug")]
fn select_events(
async fn select_events(
&self,
dest: &Destination,
new_events: Vec<(SendingEvent, Vec<u8>)>, // Events we want to send: event and full key
new_events: Vec<QueueItem>, // Events we want to send: event and full key
statuses: &mut CurTransactionStatus,
) -> Result<Option<Vec<SendingEvent>>> {
let (allow, retry) = self.select_events_current(dest.clone(), statuses)?;
@ -195,8 +192,8 @@ impl Service {
if retry {
self.db
.active_requests_for(dest)
.filter_map(Result::ok)
.for_each(|(_, e)| events.push(e));
.ready_for_each(|(_, e)| events.push(e))
.await;
return Ok(Some(events));
}
@ -204,17 +201,17 @@ impl Service {
// Compose the next transaction
let _cork = self.db.db.cork();
if !new_events.is_empty() {
self.db.mark_as_active(&new_events)?;
for (e, _) in new_events {
self.db.mark_as_active(&new_events);
for (_, e) in new_events {
events.push(e);
}
}
// Add EDU's into the transaction
if let Destination::Normal(server_name) = dest {
if let Ok((select_edus, last_count)) = self.select_edus(server_name) {
if let Ok((select_edus, last_count)) = self.select_edus(server_name).await {
events.extend(select_edus.into_iter().map(SendingEvent::Edu));
self.db.set_latest_educount(server_name, last_count)?;
self.db.set_latest_educount(server_name, last_count);
}
}
@ -248,26 +245,32 @@ impl Service {
}
#[tracing::instrument(skip_all, level = "debug")]
fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
// u64: count of last edu
let since = self.db.get_latest_educount(server_name)?;
let since = self.db.get_latest_educount(server_name).await;
let mut events = Vec::new();
let mut max_edu_count = since;
let mut device_list_changes = HashSet::new();
for room_id in self.services.state_cache.server_rooms(server_name) {
let room_id = room_id?;
let server_rooms = self.services.state_cache.server_rooms(server_name);
pin_mut!(server_rooms);
while let Some(room_id) = server_rooms.next().await {
// Look for device list updates in this room
device_list_changes.extend(
self.services
.users
.keys_changed(room_id.as_ref(), since, None)
.filter_map(Result::ok)
.filter(|user_id| self.services.globals.user_is_local(user_id)),
.keys_changed(room_id.as_str(), since, None)
.ready_filter(|user_id| self.services.globals.user_is_local(user_id))
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
if self.server.config.allow_outgoing_read_receipts
&& !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)?
&& !self
.select_edus_receipts(room_id, since, &mut max_edu_count, &mut events)
.await?
{
break;
}
@ -290,19 +293,22 @@ impl Service {
}
if self.server.config.allow_outgoing_presence {
self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?;
self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)
.await?;
}
Ok((events, max_edu_count))
}
/// Look for presence
fn select_edus_presence(
async fn select_edus_presence(
&self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
// Look for presence updates for this server
let presence_since = self.services.presence.presence_since(since);
pin_mut!(presence_since);
let mut presence_updates = Vec::new();
for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) {
while let Some((user_id, count, presence_bytes)) = presence_since.next().await {
*max_edu_count = cmp::max(count, *max_edu_count);
if !self.services.globals.user_is_local(&user_id) {
@ -312,7 +318,8 @@ impl Service {
if !self
.services
.state_cache
.server_sees_user(server_name, &user_id)?
.server_sees_user(server_name, &user_id)
.await
{
continue;
}
@ -320,7 +327,9 @@ impl Service {
let presence_event = self
.services
.presence
.from_json_bytes_to_event(&presence_bytes, &user_id)?;
.from_json_bytes_to_event(&presence_bytes, &user_id)
.await?;
presence_updates.push(PresenceUpdate {
user_id,
presence: presence_event.content.presence,
@ -346,32 +355,33 @@ impl Service {
}
/// Look for read receipts in this room
fn select_edus_receipts(
async fn select_edus_receipts(
&self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
for r in self
let receipts = self
.services
.read_receipt
.readreceipts_since(room_id, since)
{
let (user_id, count, read_receipt) = r?;
*max_edu_count = cmp::max(count, *max_edu_count);
.readreceipts_since(room_id, since);
pin_mut!(receipts);
while let Some((user_id, count, read_receipt)) = receipts.next().await {
*max_edu_count = cmp::max(count, *max_edu_count);
if !self.services.globals.user_is_local(&user_id) {
continue;
}
let event = serde_json::from_str(read_receipt.json().get())
.map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?;
let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
let mut read = BTreeMap::new();
let (event_id, mut receipt) = r
.content
.0
.into_iter()
.next()
.expect("we only use one event per read receipt");
let receipt = receipt
.remove(&ReceiptType::Read)
.expect("our read receipts always set this")
@ -427,24 +437,17 @@ impl Service {
async fn send_events_dest_appservice(
&self, dest: &Destination, id: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::new();
let Some(appservice) = self.services.appservice.get_registration(id).await else {
return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration")))));
};
let mut pdu_jsons = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdu_jsons.push(
self.services
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Event in servernameevent_data not found in db."),
)
})?
.to_room_event(),
);
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
pdu_jsons.push(pdu.to_room_event());
}
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Appservices don't need EDUs (?) and flush only;
@ -453,32 +456,24 @@ impl Service {
}
}
let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
));
//debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction");
let client = &self.services.client.appservice;
match appservice::send_request(
client,
self.services
.appservice
.get_registration(id)
.await
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Could not load registration from db."),
)
})?,
appservice,
ruma::api::appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
)))
.into(),
txn_id: txn_id.into(),
ephemeral: Vec::new(),
to_device: Vec::new(),
},
@ -494,23 +489,17 @@ impl Service {
async fn send_events_dest_push(
&self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdus = Vec::new();
let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else {
return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher")))));
};
let mut pdus = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdus.push(
self.services
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Push] Event in servernameevent_data not found in db."),
)
})?,
);
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
pdus.push(pdu);
}
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Push gateways don't need EDUs (?) and flush only;
@ -529,28 +518,22 @@ impl Service {
}
}
let Some(pusher) = self
.services
.pusher
.get_pusher(userid, pushkey)
.map_err(|e| (dest.clone(), e))?
else {
continue;
};
let rules_for_user = self
.services
.account_data
.get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap_or_default()
.and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok())
.map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global);
.await
.and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).map_err(Into::into))
.map_or_else(
|_| push::Ruleset::server_default(userid),
|ev: PushRulesEvent| ev.content.global,
);
let unread: UInt = self
.services
.user
.notification_count(userid, &pdu.room_id)
.map_err(|e| (dest.clone(), e))?
.await
.try_into()
.expect("notification count can't go that high");
@ -559,7 +542,6 @@ impl Service {
.pusher
.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu)
.await
.map(|_response| dest.clone())
.map_err(|e| (dest.clone(), e));
}
@ -586,21 +568,11 @@ impl Service {
for event in &events {
match event {
// TODO: check room version and remove event_id if needed
SendingEvent::Pdu(pdu_id) => pdu_jsons.push(
self.convert_to_outgoing_federation_event(
self.services
.timeline
.get_pdu_json_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
error!(?dest, ?server, ?pdu_id, "event not found");
(
dest.clone(),
Error::bad_database("[Normal] Event in servernameevent_data not found in db."),
)
})?,
),
),
SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await {
pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await);
}
},
SendingEvent::Edu(edu) => {
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
@ -647,7 +619,7 @@ impl Service {
}
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
@ -660,7 +632,7 @@ impl Service {
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match self.services.state.get_room_version(&room_id) {
match self.services.state.get_room_version(&room_id).await {
Ok(room_version_id) => match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => _ = pdu_json.remove("event_id"),

View file

@ -5,7 +5,7 @@ use std::{
};
use conduit::{debug, debug_error, debug_warn, err, error, info, trace, warn, Err, Result};
use futures_util::{stream::FuturesUnordered, StreamExt};
use futures::{stream::FuturesUnordered, StreamExt};
use ruma::{
api::federation::{
discovery::{
@ -179,7 +179,8 @@ impl Service {
let result: BTreeMap<_, _> = self
.services
.globals
.verify_keys_for(origin)?
.verify_keys_for(origin)
.await?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
@ -236,7 +237,8 @@ impl Service {
.services
.globals
.db
.add_signing_key(&k.server_name, k.clone())?
.add_signing_key(&k.server_name, k.clone())
.await
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>();
@ -283,7 +285,8 @@ impl Service {
.services
.globals
.db
.add_signing_key(&origin, key)?
.add_signing_key(&origin, key)
.await
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
@ -384,7 +387,8 @@ impl Service {
let mut result: BTreeMap<_, _> = self
.services
.globals
.verify_keys_for(origin)?
.verify_keys_for(origin)
.await?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
@ -431,7 +435,8 @@ impl Service {
self.services
.globals
.db
.add_signing_key(origin, k.clone())?;
.add_signing_key(origin, k.clone())
.await;
result.extend(
k.verify_keys
.into_iter()
@ -462,7 +467,8 @@ impl Service {
self.services
.globals
.db
.add_signing_key(origin, server_key.clone())?;
.add_signing_key(origin, server_key.clone())
.await;
result.extend(
server_key
@ -495,7 +501,8 @@ impl Service {
self.services
.globals
.db
.add_signing_key(origin, server_key.clone())?;
.add_signing_key(origin, server_key.clone())
.await;
result.extend(
server_key
@ -545,7 +552,8 @@ impl Service {
self.services
.globals
.db
.add_signing_key(origin, k.clone())?;
.add_signing_key(origin, k.clone())
.await;
result.extend(
k.verify_keys
.into_iter()

View file

@ -14,7 +14,7 @@ use crate::{
manager::Manager,
media, presence, pusher, resolver, rooms, sending, server_keys, service,
service::{Args, Map, Service},
transaction_ids, uiaa, updates, users,
sync, transaction_ids, uiaa, updates, users,
};
pub struct Services {
@ -32,6 +32,7 @@ pub struct Services {
pub rooms: rooms::Service,
pub sending: Arc<sending::Service>,
pub server_keys: Arc<server_keys::Service>,
pub sync: Arc<sync::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub uiaa: Arc<uiaa::Service>,
pub updates: Arc<updates::Service>,
@ -96,6 +97,7 @@ impl Services {
},
sending: build!(sending::Service),
server_keys: build!(server_keys::Service),
sync: build!(sync::Service),
transaction_ids: build!(transaction_ids::Service),
uiaa: build!(uiaa::Service),
updates: build!(updates::Service),

233
src/service/sync/mod.rs Normal file
View file

@ -0,0 +1,233 @@
use std::{
collections::{BTreeMap, BTreeSet},
sync::{Arc, Mutex, Mutex as StdMutex},
};
use conduit::Result;
use ruma::{
api::client::sync::sync_events::{
self,
v4::{ExtensionsConfig, SyncRequestList},
},
OwnedDeviceId, OwnedRoomId, OwnedUserId,
};
/// Sliding-sync connection-cache service.
///
/// Holds per-(user, device, connection) cached sliding sync request state so
/// incremental requests can be merged with what the client sent previously.
pub struct Service {
	connections: DbConnections,
}
/// Cached state for a single sliding sync connection.
struct SlidingSyncCache {
	// Last merged request list, per list id.
	lists: BTreeMap<String, SyncRequestList>,
	// Room subscriptions accumulated over the lifetime of the connection.
	subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
	// For every list, for every room, the roomsince number.
	known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>,
	// Extension configuration from the last request.
	extensions: ExtensionsConfig,
}
// Outer map of all live sliding sync connections, keyed by (user, device, conn_id).
type DbConnections = Mutex<BTreeMap<DbConnectionsKey, DbConnectionsVal>>;
type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String);
// Shared handle so a per-connection cache can be mutated after releasing the outer lock.
type DbConnectionsVal = Arc<Mutex<SlidingSyncCache>>;
impl crate::Service for Service {
	/// Construct the service with an empty connection cache.
	fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> {
		Ok(Arc::new(Self {
			connections: StdMutex::new(BTreeMap::new()),
		}))
	}

	fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Returns whether a cache entry exists for this (user, device, connection).
#[must_use]
pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool {
	self.connections
		.lock()
		// consistency: every other accessor of this lock uses expect("locked")
		.expect("locked")
		.contains_key(&(user_id, device_id, conn_id))
}
/// Drop any cached sliding sync state for this (user, device, connection).
pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) {
	let key = (user_id, device_id, conn_id);
	let mut connections = self.connections.lock().expect("locked");
	connections.remove(&key);
}
/// Merge this request with the cached state of the same connection, then
/// refresh the cache from the merged request.
///
/// Fields the client omitted are filled in from the cache; the request's
/// own values always win. Returns a snapshot of the cached known_rooms
/// (per list, per room, the roomsince number). When the request carries no
/// conn_id there is nothing to merge and an empty map is returned.
pub fn update_sync_request_with_cache(
	&self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request,
) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> {
	let Some(conn_id) = request.conn_id.clone() else {
		return BTreeMap::new();
	};

	// Fetch (or lazily create) this connection's cache entry, then drop the
	// outer map lock while keeping the per-connection lock.
	let mut cache = self.connections.lock().expect("locked");
	let cached = Arc::clone(
		cache
			.entry((user_id, device_id, conn_id))
			.or_insert_with(|| {
				Arc::new(Mutex::new(SlidingSyncCache {
					lists: BTreeMap::new(),
					subscriptions: BTreeMap::new(),
					known_rooms: BTreeMap::new(),
					extensions: ExtensionsConfig::default(),
				}))
			}),
	);
	let cached = &mut cached.lock().expect("locked");
	drop(cache);

	// For every requested list, backfill omitted fields from the cached copy
	// of that list, then re-cache the merged list.
	for (list_id, list) in &mut request.lists {
		if let Some(cached_list) = cached.lists.get(list_id) {
			if list.sort.is_empty() {
				list.sort.clone_from(&cached_list.sort);
			};
			if list.room_details.required_state.is_empty() {
				list.room_details
					.required_state
					.clone_from(&cached_list.room_details.required_state);
			};
			list.room_details.timeline_limit = list
				.room_details
				.timeline_limit
				.or(cached_list.room_details.timeline_limit);
			list.include_old_rooms = list
				.include_old_rooms
				.clone()
				.or_else(|| cached_list.include_old_rooms.clone());
			match (&mut list.filters, cached_list.filters.clone()) {
				// Both present: merge field by field; request values win.
				(Some(list_filters), Some(cached_filters)) => {
					list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm);
					if list_filters.spaces.is_empty() {
						list_filters.spaces = cached_filters.spaces;
					}
					list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted);
					list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite);
					if list_filters.room_types.is_empty() {
						list_filters.room_types = cached_filters.room_types;
					}
					if list_filters.not_room_types.is_empty() {
						list_filters.not_room_types = cached_filters.not_room_types;
					}
					list_filters.room_name_like = list_filters
						.room_name_like
						.clone()
						.or(cached_filters.room_name_like);
					if list_filters.tags.is_empty() {
						list_filters.tags = cached_filters.tags;
					}
					if list_filters.not_tags.is_empty() {
						list_filters.not_tags = cached_filters.not_tags;
					}
				},
				// Only one side present: take whichever exists.
				(_, Some(cached_filters)) => list.filters = Some(cached_filters),
				(Some(list_filters), _) => list.filters = Some(list_filters.clone()),
				(..) => {},
			}
			if list.bump_event_types.is_empty() {
				list.bump_event_types
					.clone_from(&cached_list.bump_event_types);
			};
		}
		cached.lists.insert(list_id.clone(), list.clone());
	}

	// Union the room subscriptions in both directions.
	cached
		.subscriptions
		.extend(request.room_subscriptions.clone());
	request
		.room_subscriptions
		.extend(cached.subscriptions.clone());

	// Extensions: fall back to the cached setting for anything omitted,
	// then re-cache the merged configuration.
	request.extensions.e2ee.enabled = request
		.extensions
		.e2ee
		.enabled
		.or(cached.extensions.e2ee.enabled);
	request.extensions.to_device.enabled = request
		.extensions
		.to_device
		.enabled
		.or(cached.extensions.to_device.enabled);
	request.extensions.account_data.enabled = request
		.extensions
		.account_data
		.enabled
		.or(cached.extensions.account_data.enabled);
	request.extensions.account_data.lists = request
		.extensions
		.account_data
		.lists
		.clone()
		.or_else(|| cached.extensions.account_data.lists.clone());
	request.extensions.account_data.rooms = request
		.extensions
		.account_data
		.rooms
		.clone()
		.or_else(|| cached.extensions.account_data.rooms.clone());
	cached.extensions = request.extensions.clone();

	cached.known_rooms.clone()
}
pub fn update_sync_subscriptions(
&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String,
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
) {
let mut cache = self.connections.lock().expect("locked");
let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(),
subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(),
}))
}),
);
let cached = &mut cached.lock().expect("locked");
drop(cache);
cached.subscriptions = subscriptions;
}
/// Merge a list's freshly computed room set into the connection cache.
///
/// Rooms no longer present in `new_cached_rooms` have their since-token
/// reset to 0; rooms present are recorded at `globalsince`.
pub fn update_sync_known_rooms(
	&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String,
	new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64,
) {
	let mut cache = self.connections.lock().expect("locked");
	let cached = Arc::clone(
		cache
			.entry((user_id, device_id, conn_id))
			.or_insert_with(|| {
				Arc::new(Mutex::new(SlidingSyncCache {
					lists: BTreeMap::new(),
					subscriptions: BTreeMap::new(),
					known_rooms: BTreeMap::new(),
					extensions: ExtensionsConfig::default(),
				}))
			}),
	);
	let cached = &mut cached.lock().expect("locked");
	drop(cache);

	// Single entry() lookup; previously the same entry was fetched twice
	// (with a redundant list_id.clone()) for the reset and insert passes.
	let list = cached.known_rooms.entry(list_id).or_default();
	for (room_id, since) in list.iter_mut() {
		if !new_cached_rooms.contains(room_id) {
			*since = 0;
		}
	}

	for room_id in new_cached_rooms {
		list.insert(room_id, globalsince);
	}
}
}

View file

@ -1,44 +0,0 @@
use std::sync::Arc;
use conduit::Result;
use database::{Database, Map};
use ruma::{DeviceId, TransactionId, UserId};
/// Storage for transaction-id idempotency responses.
pub struct Data {
	// Key: user_id 0xFF device_id 0xFF txn_id -> cached response bytes.
	userdevicetxnid_response: Arc<Map>,
}
impl Data {
	/// Construct, opening the userdevicetxnid_response column.
	pub(super) fn new(db: &Arc<Database>) -> Self {
		Self {
			userdevicetxnid_response: db["userdevicetxnid_response"].clone(),
		}
	}

	/// Cache a response under the key user_id 0xFF device_id 0xFF txn_id.
	/// A missing device id contributes an empty byte segment.
	pub(super) fn add_txnid(
		&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
	) -> Result<()> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
		key.push(0xFF);
		key.extend_from_slice(txn_id.as_bytes());

		self.userdevicetxnid_response.insert(&key, data)?;

		Ok(())
	}

	/// Look up a previously cached response for this transaction, using the
	/// same key layout as add_txnid.
	pub(super) fn existing_txnid(
		&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
	) -> Result<Option<database::Handle<'_>>> {
		let mut key = user_id.as_bytes().to_vec();
		key.push(0xFF);
		key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
		key.push(0xFF);
		key.extend_from_slice(txn_id.as_bytes());

		// If there's no entry, this is a new transaction
		self.userdevicetxnid_response.get(&key)
	}
}

View file

@ -1,35 +1,45 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use conduit::{implement, Result};
use database::{Handle, Map};
use ruma::{DeviceId, TransactionId, UserId};
pub struct Service {
pub db: Data,
db: Data,
}
struct Data {
userdevicetxnid_response: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data {
userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
pub fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> {
self.db.add_txnid(user_id, device_id, txn_id, data)
}
#[implement(Service)]
pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
pub fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<database::Handle<'_>>> {
self.db.existing_txnid(user_id, device_id, txn_id)
}
self.db.userdevicetxnid_response.insert(&key, data);
}
// If there's no entry, this is a new transaction
#[implement(Service)]
/// Look up a previously stored response for this transaction id.
///
/// Errs when no entry exists, i.e. this is a new transaction.
pub async fn existing_txnid(
	&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Handle<'_>> {
	let key = (user_id, device_id, txn_id);
	self.db.userdevicetxnid_response.qry(&key).await
}

View file

@ -1,87 +0,0 @@
use std::{
collections::BTreeMap,
sync::{Arc, RwLock},
};
use conduit::{Error, Result};
use database::{Database, Map};
use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
};
/// Storage for UIAA session state.
pub struct Data {
	// In-memory map of the original request body, keyed by (user, device, session id).
	userdevicesessionid_uiaarequest: RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
	// Persisted UiaaInfo per session; key: user_id 0xFF device_id 0xFF session.
	userdevicesessionid_uiaainfo: Arc<Map>,
}
impl Data {
	/// Construct, opening the uiaainfo column and an empty in-memory request map.
	pub(super) fn new(db: &Arc<Database>) -> Self {
		Self {
			userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
			userdevicesessionid_uiaainfo: db["userdevicesessionid_uiaainfo"].clone(),
		}
	}

	/// Remember the original request body for a UIAA session (in memory only).
	pub(super) fn set_uiaa_request(
		&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
	) -> Result<()> {
		self.userdevicesessionid_uiaarequest
			.write()
			.unwrap()
			.insert(
				(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
				request.to_owned(),
			);

		Ok(())
	}

	/// Fetch the stored original request body for a UIAA session, if any.
	pub(super) fn get_uiaa_request(
		&self, user_id: &UserId, device_id: &DeviceId, session: &str,
	) -> Option<CanonicalJsonValue> {
		self.userdevicesessionid_uiaarequest
			.read()
			.unwrap()
			.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
			.map(ToOwned::to_owned)
	}

	/// Persist (Some) or delete (None) the UiaaInfo for a session.
	/// Key layout: user_id 0xFF device_id 0xFF session.
	pub(super) fn update_uiaa_session(
		&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
	) -> Result<()> {
		let mut userdevicesessionid = user_id.as_bytes().to_vec();
		userdevicesessionid.push(0xFF);
		userdevicesessionid.extend_from_slice(device_id.as_bytes());
		userdevicesessionid.push(0xFF);
		userdevicesessionid.extend_from_slice(session.as_bytes());

		if let Some(uiaainfo) = uiaainfo {
			self.userdevicesessionid_uiaainfo.insert(
				&userdevicesessionid,
				&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
			)?;
		} else {
			self.userdevicesessionid_uiaainfo
				.remove(&userdevicesessionid)?;
		}

		Ok(())
	}

	/// Load the persisted UiaaInfo for a session; Forbidden when absent.
	pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
		let mut userdevicesessionid = user_id.as_bytes().to_vec();
		userdevicesessionid.push(0xFF);
		userdevicesessionid.extend_from_slice(device_id.as_bytes());
		userdevicesessionid.push(0xFF);
		userdevicesessionid.extend_from_slice(session.as_bytes());

		serde_json::from_slice(
			&self
				.userdevicesessionid_uiaainfo
				.get(&userdevicesessionid)?
				.ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?,
		)
		.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
	}
}

View file

@ -1,174 +1,243 @@
mod data;
use std::{
collections::BTreeMap,
sync::{Arc, RwLock},
};
use std::sync::Arc;
use conduit::{error, utils, utils::hash, Error, Result, Server};
use data::Data;
use conduit::{
err, error, implement, utils,
utils::{hash, string::EMPTY},
Error, Result, Server,
};
use database::{Deserialized, Map};
use ruma::{
api::client::{
error::ErrorKind,
uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier},
},
CanonicalJsonValue, DeviceId, UserId,
CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
};
use crate::{globals, users, Dep};
pub const SESSION_ID_LENGTH: usize = 32;
pub struct Service {
server: Arc<Server>,
userdevicesessionid_uiaarequest: RwLock<RequestMap>,
db: Data,
services: Services,
pub db: Data,
}
/// Handles to sibling services this service depends on.
struct Services {
	server: Arc<Server>,
	globals: Dep<globals::Service>,
	users: Dep<users::Service>,
}
/// Database columns owned by this service.
struct Data {
	// Persisted UiaaInfo per session; key: user_id 0xFF device_id 0xFF session.
	userdevicesessionid_uiaainfo: Arc<Map>,
}
// In-memory map of original UIAA request bodies, keyed by (user, device, session id).
type RequestMap = BTreeMap<RequestKey, CanonicalJsonValue>;
type RequestKey = (OwnedUserId, OwnedDeviceId, String);

// Length of randomly generated UIAA session identifiers.
pub const SESSION_ID_LENGTH: usize = 32;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
server: args.server.clone(),
userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()),
db: Data {
userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(),
},
services: Services {
server: args.server.clone(),
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(args.db),
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Creates a new Uiaa session. Make sure the session token is unique.
pub fn create(
&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue,
) -> Result<()> {
self.db.set_uiaa_request(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"), /* TODO: better session error handling (why
* is it optional in ruma?) */
json_body,
)?;
self.db.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"),
Some(uiaainfo),
)
/// Creates a new Uiaa session. Make sure the session token is unique.
#[implement(Service)]
pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) {
	// TODO: better session error handling (why is uiaainfo.session optional in
	// ruma?)
	// Bind once instead of unwrapping the Option (and re-evaluating the
	// expect) separately for each call below.
	let session = uiaainfo.session.as_ref().expect("session should be set");
	self.set_uiaa_request(user_id, device_id, session, json_body);
	self.update_uiaa_session(user_id, device_id, session, Some(uiaainfo));
}
#[implement(Service)]
pub async fn try_auth(
&self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo,
) -> Result<(bool, UiaaInfo)> {
let mut uiaainfo = if let Some(session) = auth.session() {
self.get_uiaa_session(user_id, device_id, session).await?
} else {
uiaainfo.clone()
};
if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
}
pub fn try_auth(
&self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo,
) -> Result<(bool, UiaaInfo)> {
let mut uiaainfo = auth.session().map_or_else(
|| Ok(uiaainfo.clone()),
|session| self.db.get_uiaa_session(user_id, device_id, session),
)?;
match auth {
// Find out what the user completed
AuthData::Password(Password {
identifier,
password,
#[cfg(feature = "element_hacks")]
user,
..
}) => {
#[cfg(feature = "element_hacks")]
let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier {
username
} else if let Some(username) = user {
username
} else {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
};
if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
}
#[cfg(not(feature = "element_hacks"))]
let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier
else {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
};
match auth {
// Find out what the user completed
AuthData::Password(Password {
identifier,
password,
#[cfg(feature = "element_hacks")]
user,
..
}) => {
#[cfg(feature = "element_hacks")]
let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier {
username
} else if let Some(username) = user {
username
} else {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
};
let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
#[cfg(not(feature = "element_hacks"))]
let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier
else {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
};
let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
// Check if password is correct
if let Some(hash) = self.services.users.password_hash(&user_id)? {
let hash_matches = hash::verify_password(password, &hash).is_ok();
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {
kind: ErrorKind::forbidden(),
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
}
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push(AuthType::Password);
},
AuthData::RegistrationToken(t) => {
if Some(t.token.trim()) == self.server.config.registration_token.as_deref() {
uiaainfo.completed.push(AuthType::RegistrationToken);
} else {
// Check if password is correct
if let Ok(hash) = self.services.users.password_hash(&user_id).await {
let hash_matches = hash::verify_password(password, &hash).is_ok();
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {
kind: ErrorKind::forbidden(),
message: "Invalid registration token.".to_owned(),
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
},
AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy);
},
k => error!("type not supported: {:?}", k),
}
// Check if a flow now succeeds
let mut completed = false;
'flows: for flow in &mut uiaainfo.flows {
for stage in &flow.stages {
if !uiaainfo.completed.contains(stage) {
continue 'flows;
}
}
// We didn't break, so this flow succeeded!
completed = true;
}
if !completed {
self.db.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
Some(&uiaainfo),
)?;
return Ok((false, uiaainfo));
}
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push(AuthType::Password);
},
AuthData::RegistrationToken(t) => {
if Some(t.token.trim()) == self.services.server.config.registration_token.as_deref() {
uiaainfo.completed.push(AuthType::RegistrationToken);
} else {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {
kind: ErrorKind::forbidden(),
message: "Invalid registration token.".to_owned(),
});
return Ok((false, uiaainfo));
}
},
AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy);
},
k => error!("type not supported: {:?}", k),
}
// UIAA was successful! Remove this session and return true
self.db.update_uiaa_session(
// Check if a flow now succeeds
let mut completed = false;
'flows: for flow in &mut uiaainfo.flows {
for stage in &flow.stages {
if !uiaainfo.completed.contains(stage) {
continue 'flows;
}
}
// We didn't break, so this flow succeeded!
completed = true;
}
if !completed {
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
)?;
Ok((true, uiaainfo))
Some(&uiaainfo),
);
return Ok((false, uiaainfo));
}
#[must_use]
pub fn get_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str,
) -> Option<CanonicalJsonValue> {
self.db.get_uiaa_request(user_id, device_id, session)
// UIAA was successful! Remove this session and return true
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
);
Ok((true, uiaainfo))
}
#[implement(Service)]
/// Remember the original request body for a UIAA session (in memory only).
fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) {
	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());
	let mut map = self
		.userdevicesessionid_uiaarequest
		.write()
		.expect("locked for writing");
	map.insert(key, request.to_owned());
}
#[implement(Service)]
/// Fetch the stored original request body for a UIAA session, if any.
///
/// A missing device id is looked up under the empty-string device id.
pub fn get_uiaa_request(
	&self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str,
) -> Option<CanonicalJsonValue> {
	let device_id = device_id.unwrap_or_else(|| EMPTY.into());
	let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned());

	let map = self
		.userdevicesessionid_uiaarequest
		.read()
		.expect("locked for reading");

	map.get(&key).cloned()
}
#[implement(Service)]
/// Persist (Some) or delete (None) the UiaaInfo for a session.
fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) {
	// Key layout: user_id 0xFF device_id 0xFF session
	let mut key = user_id.as_bytes().to_vec();
	key.push(0xFF);
	key.extend_from_slice(device_id.as_bytes());
	key.push(0xFF);
	key.extend_from_slice(session.as_bytes());

	if let Some(uiaainfo) = uiaainfo {
		let json = serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works");
		self.db.userdevicesessionid_uiaainfo.insert(&key, &json);
	} else {
		self.db.userdevicesessionid_uiaainfo.remove(&key);
	}
}
#[implement(Service)]
/// Fetch and deserialize the persisted UiaaInfo for this session.
///
/// Errs with Forbidden when the session does not exist.
async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
	let key = (user_id, device_id, session);
	self.db
		.userdevicesessionid_uiaainfo
		.qry(&key)
		.await
		.deserialized_json()
		.map_err(|_| err!(Request(Forbidden("UIAA session does not exist."))))
}

View file

@ -1,19 +1,22 @@
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use conduit::{debug, err, info, utils, warn, Error, Result};
use database::Map;
use conduit::{debug, info, warn, Result};
use database::{Deserialized, Map};
use ruma::events::room::message::RoomMessageEventContent;
use serde::Deserialize;
use tokio::{sync::Notify, time::interval};
use tokio::{
sync::Notify,
time::{interval, MissedTickBehavior},
};
use crate::{admin, client, globals, Dep};
pub struct Service {
services: Services,
db: Arc<Map>,
interrupt: Notify,
interval: Duration,
interrupt: Notify,
db: Arc<Map>,
services: Services,
}
struct Services {
@ -22,12 +25,12 @@ struct Services {
globals: Dep<globals::Service>,
}
#[derive(Deserialize)]
#[derive(Debug, Deserialize)]
/// Response schema of the remote check-for-updates endpoint.
struct CheckForUpdatesResponse {
	updates: Vec<CheckForUpdatesResponseEntry>,
}
#[derive(Deserialize)]
#[derive(Debug, Deserialize)]
struct CheckForUpdatesResponseEntry {
id: u64,
date: String,
@ -42,33 +45,38 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL),
interrupt: Notify::new(),
db: args.db["global"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
admin: args.depend::<admin::Service>("admin"),
client: args.depend::<client::Service>("client"),
},
db: args.db["global"].clone(),
interrupt: Notify::new(),
interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL),
}))
}
#[tracing::instrument(skip_all, name = "updates", level = "trace")]
async fn worker(self: Arc<Self>) -> Result<()> {
if !self.services.globals.allow_check_for_updates() {
debug!("Disabling update check");
return Ok(());
}
let mut i = interval(self.interval);
i.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
tokio::select! {
() = self.interrupt.notified() => return Ok(()),
() = self.interrupt.notified() => break,
_ = i.tick() => (),
}
if let Err(e) = self.handle_updates().await {
if let Err(e) = self.check().await {
warn!(%e, "Failed to check for updates");
}
}
Ok(())
}
fn interrupt(&self) { self.interrupt.notify_waiters(); }
@ -77,52 +85,52 @@ impl crate::Service for Service {
}
impl Service {
#[tracing::instrument(skip_all)]
async fn handle_updates(&self) -> Result<()> {
#[tracing::instrument(skip_all, level = "trace")]
async fn check(&self) -> Result<()> {
let response = self
.services
.client
.default
.get(CHECK_FOR_UPDATES_URL)
.send()
.await?
.text()
.await?;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?)
.map_err(|e| err!("Bad check for updates response: {e}"))?;
let mut last_update_id = self.last_check_for_updates_id()?;
for update in response.updates {
last_update_id = last_update_id.max(update.id);
if update.id > self.last_check_for_updates_id()? {
info!("{:#}", update.message);
self.services
.admin
.send_message(RoomMessageEventContent::text_markdown(format!(
"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",
update.date, update.message
)))
.await;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response)?;
for update in &response.updates {
if update.id > self.last_check_for_updates_id().await {
self.handle(update).await;
self.update_check_for_updates_id(update.id);
}
}
self.update_check_for_updates_id(last_update_id)?;
Ok(())
}
/// Log one update notice and forward it to the admin room (best effort).
async fn handle(&self, update: &CheckForUpdatesResponseEntry) {
	info!("{} {:#}", update.date, update.message);

	let notice = format!(
		"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",
		update.date, update.message
	);

	// Failure to post into the admin room is deliberately ignored.
	self.services
		.admin
		.send_message(RoomMessageEventContent::text_markdown(notice))
		.await
		.ok();
}
#[inline]
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
pub fn update_check_for_updates_id(&self, id: u64) {
self.db
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(())
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes());
}
pub fn last_check_for_updates_id(&self) -> Result<u64> {
pub async fn last_check_for_updates_id(&self) -> u64 {
self.db
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
})
.qry(LAST_CHECK_FOR_UPDATES_COUNT)
.await
.deserialized()
.unwrap_or(0_u64)
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff