From b94045a4685d274bc273cb8f5ec224c1d24c48c1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 26 May 2024 21:29:19 +0000 Subject: [PATCH] dissolve key_value/* Signed-off-by: Jason Volk --- src/api/client_server/sync.rs | 2 +- src/service/account_data/data.rs | 130 ++- src/service/appservice/data.rs | 54 +- src/service/globals/data.rs | 297 +++++- src/service/globals/mod.rs | 10 +- src/service/key_backups/data.rs | 314 +++++- src/service/key_backups/mod.rs | 4 +- src/service/key_value/account_data.rs | 137 --- src/service/key_value/appservice.rs | 55 -- src/service/key_value/globals.rs | 299 ------ src/service/key_value/key_backups.rs | 317 ------- src/service/key_value/media.rs | 235 ----- src/service/key_value/mod.rs | 14 - src/service/key_value/presence.rs | 127 --- src/service/key_value/pusher.rs | 65 -- src/service/key_value/rooms/alias.rs | 75 -- src/service/key_value/rooms/auth_chain.rs | 59 -- src/service/key_value/rooms/directory.rs | 23 - src/service/key_value/rooms/lazy_load.rs | 53 -- src/service/key_value/rooms/metadata.rs | 76 -- src/service/key_value/rooms/mod.rs | 21 - src/service/key_value/rooms/outlier.rs | 28 - src/service/key_value/rooms/pdu_metadata.rs | 79 -- src/service/key_value/rooms/read_receipt.rs | 120 --- src/service/key_value/rooms/search.rs | 64 -- src/service/key_value/rooms/short.rs | 164 ---- src/service/key_value/rooms/state.rs | 73 -- src/service/key_value/rooms/state_accessor.rs | 165 ---- src/service/key_value/rooms/state_cache.rs | 626 ------------ .../key_value/rooms/state_compressor.rs | 60 -- src/service/key_value/rooms/threads.rs | 75 -- src/service/key_value/rooms/timeline.rs | 295 ------ src/service/key_value/rooms/user.rs | 137 --- src/service/key_value/sending.rs | 191 ---- src/service/key_value/transaction_ids.rs | 32 - src/service/key_value/uiaa.rs | 68 -- src/service/key_value/users.rs | 898 ------------------ src/service/media/data.rs | 242 ++++- src/service/media/mod.rs | 4 +- src/service/mod.rs | 1 - src/service/presence/data.rs | 126 ++- src/service/presence/mod.rs | 2 +- src/service/pusher/data.rs | 63 +- src/service/pusher/mod.rs | 4 +- src/service/rooms/alias/data.rs | 76 +- src/service/rooms/alias/mod.rs | 2 +- src/service/rooms/auth_chain/data.rs | 60 +- src/service/rooms/auth_chain/mod.rs | 2 +- src/service/rooms/directory/data.rs | 22 +- src/service/rooms/directory/mod.rs | 2 +- src/service/rooms/lazy_loading/data.rs | 52 +- src/service/rooms/lazy_loading/mod.rs | 2 +- src/service/rooms/metadata/data.rs | 75 +- src/service/rooms/metadata/mod.rs | 2 +- src/service/rooms/mod.rs | 21 - src/service/rooms/outlier/data.rs | 27 +- src/service/rooms/outlier/mod.rs | 2 +- src/service/rooms/pdu_metadata/data.rs | 78 +- src/service/rooms/pdu_metadata/mod.rs | 2 +- src/service/rooms/read_receipt/data.rs | 121 ++- src/service/rooms/read_receipt/mod.rs | 2 +- src/service/rooms/search/data.rs | 61 +- src/service/rooms/search/mod.rs | 2 +- src/service/rooms/short/data.rs | 161 +++- src/service/rooms/short/mod.rs | 2 +- src/service/rooms/state/data.rs | 69 +- src/service/rooms/state/mod.rs | 2 +- src/service/rooms/state_accessor/data.rs | 161 +++- src/service/rooms/state_accessor/mod.rs | 2 +- src/service/rooms/state_cache/data.rs | 616 +++++++++++- src/service/rooms/state_cache/mod.rs | 2 +- src/service/rooms/state_compressor/data.rs | 61 +- src/service/rooms/state_compressor/mod.rs | 4 +- src/service/rooms/threads/data.rs | 72 +- src/service/rooms/threads/mod.rs | 2 +- src/service/rooms/timeline/data.rs | 296 +++++- src/service/rooms/timeline/mod.rs | 6 +- src/service/rooms/user/data.rs | 136 ++- src/service/rooms/user/mod.rs | 2 +- src/service/sending/data.rs | 189 +++- src/service/sending/mod.rs | 12 +- src/service/sending/sender.rs | 6 +- src/service/transaction_ids/data.rs | 33 +- src/service/transaction_ids/mod.rs | 4 +- src/service/uiaa/data.rs | 71 +- src/service/uiaa/mod.rs | 4 +- src/service/users/data.rs | 897 ++++++++++++++++- src/service/users/mod.rs | 2 +- 88 files changed, 4556 insertions(+), 4751 deletions(-) delete mode 100644 src/service/key_value/account_data.rs delete mode 100644 src/service/key_value/appservice.rs delete mode 100644 src/service/key_value/globals.rs delete mode 100644 src/service/key_value/key_backups.rs delete mode 100644 src/service/key_value/media.rs delete mode 100644 src/service/key_value/mod.rs delete mode 100644 src/service/key_value/presence.rs delete mode 100644 src/service/key_value/pusher.rs delete mode 100644 src/service/key_value/rooms/alias.rs delete mode 100644 src/service/key_value/rooms/auth_chain.rs delete mode 100644 src/service/key_value/rooms/directory.rs delete mode 100644 src/service/key_value/rooms/lazy_load.rs delete mode 100644 src/service/key_value/rooms/metadata.rs delete mode 100644 src/service/key_value/rooms/mod.rs delete mode 100644 src/service/key_value/rooms/outlier.rs delete mode 100644 src/service/key_value/rooms/pdu_metadata.rs delete mode 100644 src/service/key_value/rooms/read_receipt.rs delete mode 100644 src/service/key_value/rooms/search.rs delete mode 100644 src/service/key_value/rooms/short.rs delete mode 100644 src/service/key_value/rooms/state.rs delete mode 100644 src/service/key_value/rooms/state_accessor.rs delete mode 100644 src/service/key_value/rooms/state_cache.rs delete mode 100644 src/service/key_value/rooms/state_compressor.rs delete mode 100644 src/service/key_value/rooms/threads.rs delete mode 100644 src/service/key_value/rooms/timeline.rs delete mode 100644 src/service/key_value/rooms/user.rs delete mode 100644 src/service/key_value/sending.rs delete mode 100644 src/service/key_value/transaction_ids.rs delete mode 100644 src/service/key_value/uiaa.rs delete mode 100644 src/service/key_value/users.rs diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 09e24f26..5925fc53 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -142,7 +142,7 @@ pub(crate) async fn sync_events_route( .collect::>(); // Coalesce database writes for the remainder of this scope. - let _cork = services().globals.db.cork_and_flush()?; + let _cork = services().globals.cork_and_flush()?; for room_id in all_joined_rooms { let room_id = room_id?; diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index 492c500c..565990d2 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,12 +1,14 @@ use std::collections::HashMap; use ruma::{ + api::client::error::ErrorKind, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; +use tracing::warn; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { /// Places one event in the account data of the user and removes the @@ -26,3 +28,129 @@ pub trait Data: Send + Sync { &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, ) -> Result>>; } + +impl Data for KeyValueDatabase { + /// Places one event in the account data of the user and removes the + /// previous entry. + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, + data: &serde_json::Value, + ) -> Result<()> { + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); + + let mut roomuserdataid = prefix.clone(); + roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.push(0xFF); + roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); + + let mut key = prefix; + key.extend_from_slice(event_type.to_string().as_bytes()); + + if data.get("type").is_none() || data.get("content").is_none() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Account data doesn't have all required fields.", + )); + } + + self.roomuserdataid_accountdata.insert( + &roomuserdataid, + &serde_json::to_vec(&data).expect("to_vec always works on json values"), + )?; + + let prev = self.roomusertype_roomuserdataid.get(&key)?; + + self.roomusertype_roomuserdataid + .insert(&key, &roomuserdataid)?; + + // Remove old entry + if let Some(prev) = prev { + self.roomuserdataid_accountdata.remove(&prev)?; + } + + Ok(()) + } + + /// Searches the account data for a specific kind. + #[tracing::instrument(skip(self, room_id, user_id, kind))] + fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, + ) -> Result>> { + let mut key = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(kind.to_string().as_bytes()); + + self.roomusertype_roomuserdataid + .get(&key)? + .and_then(|roomuserdataid| { + self.roomuserdataid_accountdata + .get(&roomuserdataid) + .transpose() + }) + .transpose()? + .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) + .transpose() + } + + /// Returns all changes to the account data that happened after `since`. + #[tracing::instrument(skip(self, room_id, user_id, since))] + fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, + ) -> Result>> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); + + for r in self + .roomuserdataid_accountdata + .iter_from(&first_possible, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + Ok::<_, Error>(( + RoomAccountDataEventType::from( + utils::string_from_bytes( + k.rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?, + ) + .map_err(|e| { + warn!("RoomUserData ID in database is invalid: {}", e); + Error::bad_database("RoomUserData ID in db is invalid.") + })?, + ), + serde_json::from_slice::>(&v) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + )) + }) { + let (kind, data) = r?; + userdata.insert(kind, data); + } + + Ok(userdata) + } +} diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 52c8b34d..e81cb7ac 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,6 +1,6 @@ use ruma::api::appservice::Registration; -use crate::Result; +use crate::{utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { /// Registers an appservice and returns the ID to the caller @@ -19,3 +19,55 @@ pub trait Data: Send + Sync { fn all(&self) -> Result>; } + +impl Data for KeyValueDatabase { + /// Registers an appservice and returns the ID to the caller + fn register_appservice(&self, yaml: Registration) -> Result { + let id = yaml.id.as_str(); + self.id_appserviceregistrations + .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + + Ok(id.to_owned()) + } + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + fn unregister_appservice(&self, service_name: &str) -> Result<()> { + self.id_appserviceregistrations + .remove(service_name.as_bytes())?; + Ok(()) + } + + fn get_registration(&self, id: &str) -> Result> { + self.id_appserviceregistrations + .get(id.as_bytes())? + .map(|bytes| { + serde_yaml::from_slice(&bytes) + .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) + }) + .transpose() + } + + fn iter_ids<'a>(&'a self) -> Result> + 'a>> { + Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { + utils::string_from_bytes(&id) + .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) + }))) + } + + fn all(&self) -> Result> { + self.iter_ids()? + .filter_map(Result::ok) + .map(move |id| { + Ok(( + id.clone(), + self.get_registration(&id)? + .expect("iter_ids only returns appservices that exist"), + )) + }) + .collect() + } +} diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 59ed4534..f1fa621d 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,13 +1,19 @@ -use std::{collections::BTreeMap, error::Error}; +use std::collections::{BTreeMap, HashMap}; use async_trait::async_trait; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use lru_cache::LruCache; use ruma::{ api::federation::discovery::{ServerSigningKeys, VerifyKey}, signatures::Ed25519KeyPair, - DeviceId, OwnedServerSigningKeyId, ServerName, UserId, + DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, }; +use tracing::trace; -use crate::{database::Cork, Result}; +use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result}; + +const COUNTER: &[u8] = b"c"; +const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; #[async_trait] pub trait Data: Send + Sync { @@ -41,7 +47,290 @@ pub trait Data: Send + Sync { fn signing_keys_for(&self, origin: &ServerName) -> Result>; fn database_version(&self) -> Result; fn bump_database_version(&self, new_version: u64) -> Result<()>; - fn backup(&self) -> Result<(), Box> { unimplemented!() } + fn backup(&self) -> Result<(), Box> { unimplemented!() } fn backup_list(&self) -> Result { Ok(String::new()) } fn file_list(&self) -> Result { Ok(String::new()) } } + +#[async_trait] +impl Data for KeyValueDatabase { + fn next_count(&self) -> Result { + utils::u64_from_bytes(&self.global.increment(COUNTER)?) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + } + + fn current_count(&self) -> Result { + self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes.")) + }) + } + + fn last_check_for_updates_id(&self) -> Result { + self.global + .get(LAST_CHECK_FOR_UPDATES_COUNT)? + .map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) + }) + } + + fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + self.global + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; + + Ok(()) + } + + #[allow(unused_qualifications)] // async traits + #[tracing::instrument(skip(self))] + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xFF); + + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + userdeviceid_prefix.push(0xFF); + + let mut futures = FuturesUnordered::new(); + + // Return when *any* user changed their key + // TODO: only send for user they share a room with + futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); + + futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); + + // Events for rooms we are in + for room_id in services() + .rooms + .state_cache + .rooms_joined(user_id) + .filter_map(Result::ok) + { + let short_roomid = services() + .rooms + .short + .get_shortroomid(&room_id) + .ok() + .flatten() + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xFF); + + // PDUs + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + futures.push(Box::pin(async move { + let _result = services().rooms.typing.wait_for_update(&room_id).await; + })); + + futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); + + // Key changes + futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); + + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); + } + + let mut globaluserdata_prefix = vec![0xFF]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); + + // More key changes (used when user is not joined to any rooms) + futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); + + // One time keys + futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); + + futures.push(Box::pin(services().globals.rotate.watch())); + + // Wait until one of them finds something + trace!(futures = futures.len(), "watch started"); + futures.next().await; + trace!(futures = futures.len(), "watch finished"); + + Ok(()) + } + + fn cleanup(&self) -> Result<()> { self.db.cleanup() } + + fn flush(&self) -> Result<()> { self.db.flush() } + + fn cork(&self) -> Result { Ok(Cork::new(&self.db, false, false)) } + + fn cork_and_flush(&self) -> Result { Ok(Cork::new(&self.db, true, false)) } + + fn cork_and_sync(&self) -> Result { Ok(Cork::new(&self.db, true, true)) } + + fn memory_usage(&self) -> String { + let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); + let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); + let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); + let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); + + let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity(); + let max_our_real_users_cache = self.our_real_users_cache.read().unwrap().capacity(); + let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity(); + let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity(); + + format!( + "\ +auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache} +our_real_users_cache: {our_real_users_cache} / {max_our_real_users_cache} +appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache} +lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n +{}", + self.db.memory_usage().unwrap_or_default() + ) + } + + fn clear_caches(&self, amount: u32) { + if amount > 1 { + let c = &mut *self.auth_chain_cache.lock().unwrap(); + *c = LruCache::new(c.capacity()); + } + if amount > 2 { + let c = &mut *self.our_real_users_cache.write().unwrap(); + *c = HashMap::new(); + } + if amount > 3 { + let c = &mut *self.appservice_in_room_cache.write().unwrap(); + *c = HashMap::new(); + } + if amount > 4 { + let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); + *c = HashMap::new(); + } + } + + fn load_keypair(&self) -> Result { + let keypair_bytes = self.global.get(b"keypair")?.map_or_else( + || { + let keypair = utils::generate_keypair(); + self.global.insert(b"keypair", &keypair)?; + Ok::<_, Error>(keypair) + }, + Ok, + )?; + + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); + + utils::string_from_bytes( + // 1. version + parts + .next() + .expect("splitn always returns at least one element"), + ) + .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) + .and_then(|version| { + // 2. key + parts + .next() + .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) + .map(|key| (version, key)) + }) + .and_then(|(version, key)| { + Ed25519KeyPair::from_der(key, version) + .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + }) + } + + fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + + fn add_signing_key( + &self, origin: &ServerName, new_keys: ServerSigningKeys, + ) -> Result> { + // Not atomic, but this is not critical + let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + + let mut keys = signingkeys + .and_then(|keys| serde_json::from_slice(&keys).ok()) + .unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + let ServerSigningKeys { + verify_keys, + old_verify_keys, + .. + } = new_keys; + + keys.verify_keys.extend(verify_keys); + keys.old_verify_keys.extend(old_verify_keys); + + self.server_signingkeys.insert( + origin.as_bytes(), + &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + )?; + + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + + Ok(tree) + } + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. + fn signing_keys_for(&self, origin: &ServerName) -> Result> { + let signingkeys = self + .server_signingkeys + .get(origin.as_bytes())? + .and_then(|bytes| serde_json::from_slice(&bytes).ok()) + .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + tree + }); + + Ok(signingkeys) + } + + fn database_version(&self) -> Result { + self.global.get(b"version")?.map_or(Ok(0), |version| { + utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) + }) + } + + fn bump_database_version(&self, new_version: u64) -> Result<()> { + self.global.insert(b"version", &new_version.to_be_bytes())?; + Ok(()) + } + + fn backup(&self) -> Result<(), Box> { self.db.backup() } + + fn backup_list(&self) -> Result { self.db.backup_list() } + + fn file_list(&self) -> Result { self.db.file_list() } +} diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 6995ab71..8f035fcd 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -9,7 +9,7 @@ use std::{ use argon2::Argon2; use base64::{engine::general_purpose, Engine as _}; -pub use data::Data; +use data::Data; use hickory_resolver::TokioAsyncResolver; use ipaddress::IPAddress; use regex::RegexSet; @@ -29,10 +29,10 @@ use tokio::{ use tracing::{error, trace}; use url::Url; -use crate::{services, Config, Result}; +use crate::{database::Cork, services, Config, Result}; mod client; -pub mod data; +mod data; pub(crate) mod emerg_access; pub(crate) mod migrations; mod resolver; @@ -200,6 +200,10 @@ impl Service { #[allow(dead_code)] pub fn flush(&self) -> Result<()> { self.db.flush() } + pub fn cork(&self) -> Result { self.db.cork() } + + pub fn cork_and_flush(&self) -> Result { self.db.cork_and_flush() } + pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } pub fn max_request_size(&self) -> u32 { self.config.max_request_size } diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index ac595a6b..491fd0e4 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,14 +1,17 @@ use std::collections::BTreeMap; use ruma::{ - api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + api::client::{ + backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + error::ErrorKind, + }, serde::Raw, OwnedRoomId, RoomId, UserId, }; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { +pub(crate) trait Data: Send + Sync { fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result; fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; @@ -45,3 +48,308 @@ pub trait Data: Send + Sync { fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + let version = services().globals.next_count()?.to_string(); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.backupid_algorithm.insert( + &key, + &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + )?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + Ok(version) + } + + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.backupid_algorithm.remove(&key)?; + self.backupid_etag.remove(&key)?; + + key.push(0xFF); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } + + fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + if self.backupid_algorithm.get(&key)?.is_none() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); + } + + self.backupid_algorithm + .insert(&key, backup_metadata.json().get().as_bytes())?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + Ok(version.to_owned()) + } + + fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.backupid_algorithm + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|(key, _)| { + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + }) + .transpose() + } + + fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.backupid_algorithm + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|(key, value)| { + let version = utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + + Ok(( + version, + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, + )) + }) + .transpose() + } + + fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + self.backupid_algorithm + .get(&key)? + .map_or(Ok(None), |bytes| { + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) + }) + } + + fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + if self.backupid_algorithm.get(&key)?.is_none() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); + } + + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); + + self.backupkeyid_backup + .insert(&key, key_data.json().get().as_bytes())?; + + Ok(()) + } + + fn count_keys(&self, user_id: &UserId, version: &str) -> Result { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(version.as_bytes()); + + Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) + } + + fn get_etag(&self, user_id: &UserId, version: &str) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + + Ok(utils::u64_from_bytes( + &self + .backupid_etag + .get(&key)? + .ok_or_else(|| Error::bad_database("Backup has no etag."))?, + ) + .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? + .to_string()) + } + + fn get_all(&self, user_id: &UserId, version: &str) -> Result> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(version.as_bytes()); + prefix.push(0xFF); + + let mut rooms = BTreeMap::::new(); + + for result in self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xFF); + + let session_id = utils::string_from_bytes( + parts + .next() + .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; + + let room_id = RoomId::parse( + utils::string_from_bytes( + parts + .next() + .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; + + let key_data = serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; + + Ok::<_, Error>((room_id, session_id, key_data)) + }) { + let (room_id, session_id, key_data) = result?; + rooms + .entry(room_id) + .or_insert_with(|| RoomKeyBackup { + sessions: BTreeMap::new(), + }) + .sessions + .insert(session_id, key_data); + } + + Ok(rooms) + } + + fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, + ) -> Result>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(version.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); + + Ok(self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xFF); + + let session_id = utils::string_from_bytes( + parts + .next() + .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; + + let key_data = serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; + + Ok::<_, Error>((session_id, key_data)) + }) + .filter_map(Result::ok) + .collect()) + } + + fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); + + self.backupkeyid_backup + .get(&key)? + .map(|value| { + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) + }) + .transpose() + } + + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } + + fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } + + fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); + + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } + + Ok(()) + } +} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index abab604c..9b88e293 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,7 +1,7 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -pub use data::Data; +use data::Data; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, @@ -11,7 +11,7 @@ use ruma::{ use crate::Result; pub struct Service { - pub db: Arc, + pub(super) db: Arc, } impl Service { diff --git a/src/service/key_value/account_data.rs b/src/service/key_value/account_data.rs deleted file mode 100644 index 981f1b8c..00000000 --- a/src/service/key_value/account_data.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::collections::HashMap; - -use ruma::{ - api::client::error::ErrorKind, - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; -use tracing::warn; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::account_data::Data for KeyValueDatabase { - /// Places one event in the account data of the user and removes the - /// previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); - - let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); - - if data.get("type").is_none() || data.get("content").is_none() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Account data doesn't have all required fields.", - )); - } - - self.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - )?; - - let prev = self.roomusertype_roomuserdataid.get(&key)?; - - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - - // Remove old entry - if let Some(prev) = prev { - self.roomuserdataid_accountdata.remove(&prev)?; - } - - Ok(()) - } - - /// Searches the account data for a specific kind. - #[tracing::instrument(skip(self, room_id, user_id, kind))] - fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, - ) -> Result>> { - let mut key = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(kind.to_string().as_bytes()); - - self.roomusertype_roomuserdataid - .get(&key)? - .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() - }) - .transpose()? - .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) - .transpose() - } - - /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip(self, room_id, user_id, since))] - fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>> { - let mut userdata = HashMap::new(); - - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - // Skip the data that's exactly at since, because we sent that last time - let mut first_possible = prefix.clone(); - first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); - - for r in self - .roomuserdataid_accountdata - .iter_from(&first_possible, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(k, v)| { - Ok::<_, Error>(( - RoomAccountDataEventType::from( - utils::string_from_bytes( - k.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?, - ) - .map_err(|e| { - warn!("RoomUserData ID in database is invalid: {}", e); - Error::bad_database("RoomUserData ID in db is invalid.") - })?, - ), - serde_json::from_slice::>(&v) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - )) - }) { - let (kind, data) = r?; - userdata.insert(kind, data); - } - - Ok(userdata) - } -} diff --git a/src/service/key_value/appservice.rs b/src/service/key_value/appservice.rs deleted file mode 100644 index f030d5e7..00000000 --- a/src/service/key_value/appservice.rs +++ /dev/null @@ -1,55 +0,0 @@ -use ruma::api::appservice::Registration; - -use crate::{utils, Error, KeyValueDatabase, Result}; - -impl crate::appservice::Data for KeyValueDatabase { - /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: Registration) -> Result { - let id = yaml.id.as_str(); - self.id_appserviceregistrations - .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; - - Ok(id.to_owned()) - } - - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations - .remove(service_name.as_bytes())?; - Ok(()) - } - - fn get_registration(&self, id: &str) -> Result> { - self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes) - .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) - }) - .transpose() - } - - fn iter_ids<'a>(&'a self) -> Result> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { - utils::string_from_bytes(&id) - .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) - }))) - } - - fn all(&self) -> Result> { - self.iter_ids()? - .filter_map(Result::ok) - .map(move |id| { - Ok(( - id.clone(), - self.get_registration(&id)? - .expect("iter_ids only returns appservices that exist"), - )) - }) - .collect() - } -} diff --git a/src/service/key_value/globals.rs b/src/service/key_value/globals.rs deleted file mode 100644 index e56f9feb..00000000 --- a/src/service/key_value/globals.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::collections::{BTreeMap, HashMap}; - -use async_trait::async_trait; -use futures_util::{stream::FuturesUnordered, StreamExt}; -use lru_cache::LruCache; -use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - signatures::Ed25519KeyPair, - DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, -}; -use tracing::trace; - -use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result}; - -const COUNTER: &[u8] = b"c"; -const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; - -#[async_trait] -impl crate::globals::Data for KeyValueDatabase { - fn next_count(&self) -> Result { - utils::u64_from_bytes(&self.global.increment(COUNTER)?) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) - } - - fn current_count(&self) -> Result { - self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes.")) - }) - } - - fn last_check_for_updates_id(&self) -> Result { - self.global - .get(LAST_CHECK_FOR_UPDATES_COUNT)? - .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) - } - - fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.global - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) - } - - #[allow(unused_qualifications)] // async traits - #[tracing::instrument(skip(self))] - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xFF); - - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xFF); - - let mut futures = FuturesUnordered::new(); - - // Return when *any* user changed their key - // TODO: only send for user they share a room with - futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); - - futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); - - // Events for rooms we are in - for room_id in services() - .rooms - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - let short_roomid = services() - .rooms - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xFF); - - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push(Box::pin(async move { - let _result = services().rooms.typing.wait_for_update(&room_id).await; - })); - - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - - // Key changes - futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); - - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - } - - let mut globaluserdata_prefix = vec![0xFF]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); - - // More key changes (used when user is not joined to any rooms) - futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); - - // One time keys - futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - - futures.push(Box::pin(services().globals.rotate.watch())); - - // Wait until one of them finds something - trace!(futures = futures.len(), "watch started"); - futures.next().await; - trace!(futures = futures.len(), "watch finished"); - - Ok(()) - } - - fn cleanup(&self) -> Result<()> { self.db.cleanup() } - - fn flush(&self) -> Result<()> { self.db.flush() } - - fn cork(&self) -> Result { Ok(Cork::new(&self.db, false, false)) } - - fn cork_and_flush(&self) -> Result { Ok(Cork::new(&self.db, true, false)) } - - fn cork_and_sync(&self) -> Result { Ok(Cork::new(&self.db, true, true)) } - - fn memory_usage(&self) -> String { - let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); - let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); - let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); - let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); - - let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity(); - let max_our_real_users_cache = self.our_real_users_cache.read().unwrap().capacity(); - let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity(); - let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity(); - - format!( - "\ -auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache} -our_real_users_cache: {our_real_users_cache} / {max_our_real_users_cache} -appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache} -lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n -{}", - self.db.memory_usage().unwrap_or_default() - ) - } - - fn clear_caches(&self, amount: u32) { - if amount > 1 { - let c = &mut *self.auth_chain_cache.lock().unwrap(); - *c = LruCache::new(c.capacity()); - } - if amount > 2 { - let c = &mut *self.our_real_users_cache.write().unwrap(); - *c = HashMap::new(); - } - if amount > 3 { - let c = &mut *self.appservice_in_room_cache.write().unwrap(); - *c = HashMap::new(); - } - if amount > 4 { - let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); - *c = HashMap::new(); - } - } - - fn load_keypair(&self) -> Result { - let keypair_bytes = self.global.get(b"keypair")?.map_or_else( - || { - let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair)?; - Ok::<_, Error>(keypair) - }, - Ok, - )?; - - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); - - utils::string_from_bytes( - // 1. version - parts - .next() - .expect("splitn always returns at least one element"), - ) - .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) - .and_then(|version| { - // 2. key - parts - .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) - .map(|key| (version, key)) - }) - .and_then(|(version, key)| { - Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) - }) - } - - fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } - - fn add_signing_key( - &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result> { - // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; - - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); - - let ServerSigningKeys { - verify_keys, - old_verify_keys, - .. - } = new_keys; - - keys.verify_keys.extend(verify_keys); - keys.old_verify_keys.extend(old_verify_keys); - - self.server_signingkeys.insert( - origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; - - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - - Ok(tree) - } - - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. - fn signing_keys_for(&self, origin: &ServerName) -> Result> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()) - .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - tree - }); - - Ok(signingkeys) - } - - fn database_version(&self) -> Result { - self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) - }) - } - - fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes())?; - Ok(()) - } - - fn backup(&self) -> Result<(), Box> { self.db.backup() } - - fn backup_list(&self) -> Result { self.db.backup_list() } - - fn file_list(&self) -> Result { self.db.file_list() } -} diff --git a/src/service/key_value/key_backups.rs b/src/service/key_value/key_backups.rs deleted file mode 100644 index 82bbdd48..00000000 --- a/src/service/key_value/key_backups.rs +++ /dev/null @@ -1,317 +0,0 @@ -use std::collections::BTreeMap; - -use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::key_backups::Data for KeyValueDatabase { - fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - let version = services().globals.next_count()?.to_string(); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.insert( - &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - )?; - self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; - Ok(version) - } - - fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.remove(&key)?; - self.backupid_etag.remove(&key)?; - - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; - Ok(version.to_owned()) - } - - fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, _)| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) - .transpose() - } - - fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, value)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - - Ok(( - version, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, - )) - }) - .transpose() - } - - fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm - .get(&key)? - .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) - } - - fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; - - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes())?; - - Ok(()) - } - - fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - - Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) - } - - fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - Ok(utils::u64_from_bytes( - &self - .backupid_etag - .get(&key)? - .ok_or_else(|| Error::bad_database("Backup has no etag."))?, - ) - .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? - .to_string()) - } - - fn get_all(&self, user_id: &UserId, version: &str) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - - let mut rooms = BTreeMap::::new(); - - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let room_id = RoomId::parse( - utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((room_id, session_id, key_data)) - }) { - let (room_id, session_id, key_data) = result?; - rooms - .entry(room_id) - .or_insert_with(|| RoomKeyBackup { - sessions: BTreeMap::new(), - }) - .sessions - .insert(session_id, key_data); - } - - Ok(rooms) - } - - fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - Ok(self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((session_id, key_data)) - }) - .filter_map(Result::ok) - .collect()) - } - - fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .get(&key)? - .map(|value| { - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) - }) - .transpose() - } - - fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } -} diff --git a/src/service/key_value/media.rs b/src/service/key_value/media.rs deleted file mode 100644 index 1fccb6a5..00000000 --- a/src/service/key_value/media.rs +++ /dev/null @@ -1,235 +0,0 @@ -use conduit::debug_info; -use ruma::api::client::error::ErrorKind; -use tracing::debug; - -use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result}; - -impl crate::media::Data for KeyValueDatabase { - fn create_file_metadata( - &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, - content_type: Option<&str>, - ) -> Result> { - let mut key = mxc.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&width.to_be_bytes()); - key.extend_from_slice(&height.to_be_bytes()); - key.push(0xFF); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xFF); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - self.mediaid_file.insert(&key, &[])?; - - if let Some(user) = sender_user { - let key = mxc.as_bytes().to_vec(); - let user = user.as_bytes().to_vec(); - self.mediaid_user.insert(&key, &user)?; - } - - Ok(key) - } - - fn delete_file_mxc(&self, mxc: String) -> Result<()> { - debug!("MXC URI: {:?}", mxc); - - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xFF); - - debug!("MXC db prefix: {prefix:?}"); - - for (key, _) in self.mediaid_file.scan_prefix(prefix) { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(&key)?; - } - - for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) { - if key == mxc.as_bytes().to_vec() { - let user = string_from_bytes(&value).unwrap_or_default(); - - debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}"); - self.mediaid_user.remove(&key)?; - } - } - - Ok(()) - } - - /// Searches for all files with the given MXC - fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>> { - debug!("MXC URI: {:?}", mxc); - - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xFF); - - let keys: Vec> = self - .mediaid_file - .scan_prefix(prefix) - .map(|(key, _)| key) - .collect(); - - if keys.is_empty() { - return Err(Error::bad_database( - "Failed to find any keys in database with the provided MXC.", - )); - } - - debug!("Got the following keys: {:?}", keys); - - Ok(keys) - } - - fn search_file_metadata( - &self, mxc: String, width: u32, height: u32, - ) -> Result<(Option, Option, Vec)> { - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(&width.to_be_bytes()); - prefix.extend_from_slice(&height.to_be_bytes()); - prefix.push(0xFF); - - let (key, _) = self - .mediaid_file - .scan_prefix(prefix) - .next() - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; - - let mut parts = key.rsplit(|&b| b == 0xFF); - - let content_type = parts - .next() - .map(|bytes| { - string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode.")) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - string_from_bytes(content_disposition_bytes) - .map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?, - ) - }; - Ok((content_disposition, content_type, key)) - } - - /// Gets all the media keys in our database (this includes all the metadata - /// associated with it such as width, height, content-type, etc) - fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } - - fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } - - fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> { - let mut value = Vec::::new(); - value.extend_from_slice(×tamp.as_secs().to_be_bytes()); - value.push(0xFF); - value.extend_from_slice( - data.title - .as_ref() - .map(String::as_bytes) - .unwrap_or_default(), - ); - value.push(0xFF); - value.extend_from_slice( - data.description - .as_ref() - .map(String::as_bytes) - .unwrap_or_default(), - ); - value.push(0xFF); - value.extend_from_slice( - data.image - .as_ref() - .map(String::as_bytes) - .unwrap_or_default(), - ); - value.push(0xFF); - value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes()); - value.push(0xFF); - value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes()); - value.push(0xFF); - value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); - - self.url_previews.insert(url.as_bytes(), &value) - } - - fn get_url_preview(&self, url: &str) -> Option { - let values = self.url_previews.get(url.as_bytes()).ok()??; - - let mut values = values.split(|&b| b == 0xFF); - - let _ts = values.next(); - /* if we ever decide to use timestamp, this is here. - match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) { - Some(0) => None, - x => x, - };*/ - - let title = match values - .next() - .and_then(|b| String::from_utf8(b.to_vec()).ok()) - { - Some(s) if s.is_empty() => None, - x => x, - }; - let description = match values - .next() - .and_then(|b| String::from_utf8(b.to_vec()).ok()) - { - Some(s) if s.is_empty() => None, - x => x, - }; - let image = match values - .next() - .and_then(|b| String::from_utf8(b.to_vec()).ok()) - { - Some(s) if s.is_empty() => None, - x => x, - }; - let image_size = match values - .next() - .map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default())) - { - Some(0) => None, - x => x, - }; - let image_width = match values - .next() - .map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) - { - Some(0) => None, - x => x, - }; - let image_height = match values - .next() - .map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) - { - Some(0) => None, - x => x, - }; - - Some(UrlPreviewData { - title, - description, - image, - image_size, - image_width, - image_height, - }) - } -} diff --git a/src/service/key_value/mod.rs b/src/service/key_value/mod.rs deleted file mode 100644 index 4391cac5..00000000 --- a/src/service/key_value/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -mod account_data; -//mod admin; -mod appservice; -mod globals; -mod key_backups; -mod media; -//mod pdu; -mod presence; -mod pusher; -mod rooms; -mod sending; -mod transaction_ids; -mod uiaa; -mod users; diff --git a/src/service/key_value/presence.rs b/src/service/key_value/presence.rs deleted file mode 100644 index 9defd06d..00000000 --- a/src/service/key_value/presence.rs +++ /dev/null @@ -1,127 +0,0 @@ -use conduit::debug_info; -use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; - -use crate::{ - presence::Presence, - services, - utils::{self, user_id_from_bytes}, - Error, KeyValueDatabase, Result, -}; - -impl crate::presence::Data for KeyValueDatabase { - fn get_presence(&self, user_id: &UserId) -> Result> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - - let key = presenceid_key(count, user_id); - self.presenceid_presence - .get(&key)? - .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?)) - }) - .transpose() - } else { - Ok(None) - } - } - - fn set_presence( - &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, - last_active_ago: Option, status_msg: Option, - ) -> Result<()> { - let last_presence = self.get_presence(user_id)?; - let state_changed = match last_presence { - None => true, - Some(ref presence) => presence.1.content.presence != *presence_state, - }; - - let now = utils::millis_since_unix_epoch(); - let last_last_active_ts = match last_presence { - None => 0, - Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), - }; - - let last_active_ts = match last_active_ago { - None => now, - Some(last_active_ago) => now.saturating_sub(last_active_ago.into()), - }; - - // tighten for state flicker? - if !state_changed && last_active_ts <= last_last_active_ts { - debug_info!( - "presence spam {:?} last_active_ts:{:?} <= {:?}", - user_id, - last_active_ts, - last_last_active_ts - ); - return Ok(()); - } - - let status_msg = if status_msg.as_ref().is_some_and(String::is_empty) { - None - } else { - status_msg - }; - - let presence = Presence::new( - presence_state.to_owned(), - currently_active.unwrap_or(false), - last_active_ts, - status_msg, - ); - let count = services().globals.next_count()?; - let key = presenceid_key(count, user_id); - - self.presenceid_presence - .insert(&key, &presence.to_json_bytes()?)?; - - self.userid_presenceid - .insert(user_id.as_bytes(), &count.to_be_bytes())?; - - if let Some((last_count, _)) = last_presence { - let key = presenceid_key(last_count, user_id); - self.presenceid_presence.remove(&key)?; - } - - Ok(()) - } - - fn remove_presence(&self, user_id: &UserId) -> Result<()> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - let key = presenceid_key(count, user_id); - self.presenceid_presence.remove(&key)?; - self.userid_presenceid.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { - Box::new( - self.presenceid_presence - .iter() - .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec)> { - let (count, user_id) = presenceid_parse(&key)?; - Ok((user_id, count, presence_bytes)) - }) - .filter(move |(_, count, _)| *count > since), - ) - } -} - -#[inline] -fn presenceid_key(count: u64, user_id: &UserId) -> Vec { - [count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat() -} - -#[inline] -fn presenceid_parse(key: &[u8]) -> Result<(u64, OwnedUserId)> { - let (count, user_id) = key.split_at(8); - let user_id = user_id_from_bytes(user_id)?; - let count = utils::u64_from_bytes(count).unwrap(); - - Ok((count, user_id)) -} diff --git a/src/service/key_value/pusher.rs b/src/service/key_value/pusher.rs deleted file mode 100644 index 876b531c..00000000 --- a/src/service/key_value/pusher.rs +++ /dev/null @@ -1,65 +0,0 @@ -use ruma::{ - api::client::push::{set_pusher, Pusher}, - UserId, -}; - -use crate::{utils, Error, KeyValueDatabase, Result}; - -impl crate::pusher::Data for KeyValueDatabase { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { - match &pusher { - set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.senderkey_pusher - .insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?; - Ok(()) - }, - set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.senderkey_pusher.remove(&key).map_err(Into::into) - }, - } - } - - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xFF); - senderkey.extend_from_slice(pushkey.as_bytes()); - - self.senderkey_pusher - .get(&senderkey)? - .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .transpose() - } - - fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .collect() - } - - fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xFF); - let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; - - Ok(push_key_string) - })) - } -} diff --git a/src/service/key_value/rooms/alias.rs b/src/service/key_value/rooms/alias.rs deleted file mode 100644 index 402e59fd..00000000 --- a/src/service/key_value/rooms/alias.rs +++ /dev/null @@ -1,75 +0,0 @@ -use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::alias::Data for KeyValueDatabase { - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xFF); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; - Ok(()) - } - - fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id; - prefix.push(0xFF); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - self.alias_roomid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); - } - Ok(()) - } - - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } - - fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - })) - } - - fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - Box::new( - self.alias_roomid - .iter() - .map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - - Ok((room_id, room_alias_localpart)) - }), - ) - } -} diff --git a/src/service/key_value/rooms/auth_chain.rs b/src/service/key_value/rooms/auth_chain.rs deleted file mode 100644 index f01ff4aa..00000000 --- a/src/service/key_value/rooms/auth_chain.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::{mem::size_of, sync::Arc}; - -use crate::{utils, KeyValueDatabase, Result}; - -impl crate::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { - // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); - } - - // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - let chain = self - .shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) - .collect::>() - }); - - if let Some(chain) = chain { - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(vec![key[0]], Arc::clone(&chain)); - - return Ok(Some(chain)); - } - } - - Ok(None) - } - - fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { - // Only persist single events in db - if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - )?; - } - - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(key, auth_chain); - - Ok(()) - } -} diff --git a/src/service/key_value/rooms/directory.rs b/src/service/key_value/rooms/directory.rs deleted file mode 100644 index 9265d2c8..00000000 --- a/src/service/key_value/rooms/directory.rs +++ /dev/null @@ -1,23 +0,0 @@ -use ruma::{OwnedRoomId, RoomId}; - -use crate::{utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::directory::Data for KeyValueDatabase { - fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } - - fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) } - - fn is_public_room(&self, room_id: &RoomId) -> Result { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) - } - - fn public_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - })) - } -} diff --git a/src/service/key_value/rooms/lazy_load.rs b/src/service/key_value/rooms/lazy_load.rs deleted file mode 100644 index 700505cd..00000000 --- a/src/service/key_value/rooms/lazy_load.rs +++ /dev/null @@ -1,53 +0,0 @@ -use ruma::{DeviceId, RoomId, UserId}; - -use crate::{KeyValueDatabase, Result}; - -impl crate::rooms::lazy_loading::Data for KeyValueDatabase { - fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) - } - - fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for ll_id in confirmed_user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } - - Ok(()) - } - - fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } - - Ok(()) - } -} diff --git a/src/service/key_value/rooms/metadata.rs b/src/service/key_value/rooms/metadata.rs deleted file mode 100644 index ab8c1a78..00000000 --- a/src/service/key_value/rooms/metadata.rs +++ /dev/null @@ -1,76 +0,0 @@ -use ruma::{OwnedRoomId, RoomId}; -use tracing::error; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::metadata::Data for KeyValueDatabase { - fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match services().rooms.short.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; - - // Look for PDUs in that room. - Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } - - fn iter_ids<'a>(&'a self) -> Box> + 'a> { - Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - })) - } - - fn is_disabled(&self, room_id: &RoomId) -> Result { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } - - fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - if disabled { - self.disabledroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.disabledroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - fn is_banned(&self, room_id: &RoomId) -> Result { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) } - - fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { - if banned { - self.bannedroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.bannedroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.bannedroomids.iter().map( - |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|e| { - error!("Invalid room_id bytes in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids.") - })? - .try_into() - .map_err(|e| { - error!("Invalid room_id in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids") - })?; - - Ok(room_id) - }, - )) - } -} diff --git a/src/service/key_value/rooms/mod.rs b/src/service/key_value/rooms/mod.rs deleted file mode 100644 index d69cf141..00000000 --- a/src/service/key_value/rooms/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -mod alias; -mod auth_chain; -mod directory; -mod lazy_load; -mod metadata; -mod outlier; -mod pdu_metadata; -mod read_receipt; -mod search; -mod short; -mod state; -mod state_accessor; -mod state_cache; -mod state_compressor; -mod threads; -mod timeline; -mod user; - -use crate::KeyValueDatabase; - -impl crate::rooms::Data for KeyValueDatabase {} diff --git a/src/service/key_value/rooms/outlier.rs b/src/service/key_value/rooms/outlier.rs deleted file mode 100644 index 701e4cb2..00000000 --- a/src/service/key_value/rooms/outlier.rs +++ /dev/null @@ -1,28 +0,0 @@ -use ruma::{CanonicalJsonObject, EventId}; - -use crate::{Error, KeyValueDatabase, PduEvent, Result}; - -impl crate::rooms::outlier::Data for KeyValueDatabase { - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) - } -} diff --git a/src/service/key_value/rooms/pdu_metadata.rs b/src/service/key_value/rooms/pdu_metadata.rs deleted file mode 100644 index 225ed1cc..00000000 --- a/src/service/key_value/rooms/pdu_metadata.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::{mem, sync::Arc}; - -use ruma::{EventId, RoomId, UserId}; - -use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; - -impl crate::rooms::pdu_metadata::Data for KeyValueDatabase { - fn add_relation(&self, from: u64, to: u64) -> Result<()> { - let mut key = to.to_be_bytes().to_vec(); - key.extend_from_slice(&from.to_be_bytes()); - self.tofrom_relation.insert(&key, &[])?; - Ok(()) - } - - fn relations_until<'a>( - &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, - ) -> Result> + 'a>> { - let prefix = target.to_be_bytes().to_vec(); - let mut current = prefix.clone(); - - let count_raw = match until { - PduCount::Normal(x) => x.saturating_sub(1), - PduCount::Backfilled(x) => { - current.extend_from_slice(&0_u64.to_be_bytes()); - u64::MAX.saturating_sub(x).saturating_sub(1) - }, - }; - current.extend_from_slice(&count_raw.to_be_bytes()); - - Ok(Box::new( - self.tofrom_relation - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(tofrom, _data)| { - let from = utils::u64_from_bytes(&tofrom[(mem::size_of::())..]) - .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; - - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - - let mut pdu = services() - .rooms - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((PduCount::Normal(from), pdu)) - }), - )) - } - - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - for prev in event_ids { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; - } - - Ok(()) - } - - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) - } - - fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) - } - - fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) - } -} diff --git a/src/service/key_value/rooms/read_receipt.rs b/src/service/key_value/rooms/read_receipt.rs deleted file mode 100644 index 6cd913e7..00000000 --- a/src/service/key_value/rooms/read_receipt.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::mem; - -use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId}; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::read_receipt::Data for KeyValueDatabase { - fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } - - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - room_latest_id.push(0xFF); - room_latest_id.extend_from_slice(user_id.as_bytes()); - - self.readreceiptid_readreceipt.insert( - &room_latest_id, - &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) - } - - fn readreceipts_since<'a>( - &'a self, room_id: &RoomId, since: u64, - ) -> Box)>> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - let prefix2 = prefix.clone(); - - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since - - Box::new( - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::()]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, - ) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - - let mut json = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - json.remove("room_id"); - - Ok(( - user_id, - count, - Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")), - )) - }), - ) - } - - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; - - self.roomuserid_lastprivatereadupdate - .insert(&key, &services().globals.next_count()?.to_be_bytes()) - } - - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some( - utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, - )) - }) - } - - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) - } -} diff --git a/src/service/key_value/rooms/search.rs b/src/service/key_value/rooms/search.rs deleted file mode 100644 index ab826172..00000000 --- a/src/service/key_value/rooms/search.rs +++ /dev/null @@ -1,64 +0,0 @@ -use ruma::RoomId; - -use crate::{services, utils, KeyValueDatabase, Result}; - -type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; - -impl crate::rooms::search::Data for KeyValueDatabase { - fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - let mut batch = message_body - .split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) - .map(str::to_lowercase) - .map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here - (key, Vec::new()) - }); - - self.tokenids.insert_batch(&mut batch) - } - - fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let words: Vec<_> = search_string - .split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .map(str::to_lowercase) - .collect(); - - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.tokenids - .iter_from(&last_possible_id, true) // Newest pdus first - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(key, _)| key[prefix3.len()..].to_vec()) - }); - - let Some(common_elements) = utils::common_elements(iterators, |a, b| { - // We compare b with a because we reversed the iterator earlier - b.cmp(a) - }) else { - return Ok(None); - }; - - Ok(Some((Box::new(common_elements), words))) - } -} diff --git a/src/service/key_value/rooms/short.rs b/src/service/key_value/rooms/short.rs deleted file mode 100644 index 69d85da4..00000000 --- a/src/service/key_value/rooms/short.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::sync::Arc; - -use ruma::{events::StateEventType, EventId, RoomId}; -use tracing::warn; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::short::Data for KeyValueDatabase { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { - utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? - } else { - let shorteventid = services().globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - }; - - Ok(short) - } - - fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { - let mut ret: Vec = Vec::with_capacity(event_ids.len()); - let keys = event_ids - .iter() - .map(|id| id.as_bytes()) - .collect::>(); - for (i, short) in self - .eventid_shorteventid - .multi_get(&keys)? - .iter() - .enumerate() - { - match short { - Some(short) => ret.push( - utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, - ), - None => { - let short = services().globals.next_count()?; - self.eventid_shorteventid - .insert(keys[i], &short.to_be_bytes())?; - self.shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i])?; - - debug_assert!(ret.len() == i, "position of result must match input"); - ret.push(short); - }, - } - } - - Ok(ret) - } - - fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = self - .statekey_shortstatekey - .get(&statekey_vec)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; - - Ok(short) - } - - fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? - } else { - let shortstatekey = services().globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; - shortstatekey - }; - - Ok(short) - } - - fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - - let event_id = EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - - Ok(event_id) - } - - fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xFF); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { - warn!("Event type in shortstatekey_statekey is invalid: {}", e); - Error::bad_database("Event type in shortstatekey_statekey is invalid.") - })?); - - let state_key = utils::string_from_bytes(statekey_bytes) - .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; - - let result = (event_type, state_key); - - Ok(result) - } - - /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { - ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ) - } else { - let shortstatehash = services().globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - }) - } - - fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) - .transpose() - } - - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { - utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? - } else { - let short = services().globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - }) - } -} diff --git a/src/service/key_value/rooms/state.rs b/src/service/key_value/rooms/state.rs deleted file mode 100644 index f7637c57..00000000 --- a/src/service/key_value/rooms/state.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::{collections::HashSet, sync::Arc}; - -use ruma::{EventId, OwnedEventId, RoomId}; -use tokio::sync::MutexGuard; - -use crate::{utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::state::Data for KeyValueDatabase { - fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) - } - - fn set_room_state( - &self, - room_id: &RoomId, - new_shortstatehash: u64, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - Ok(()) - } - - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) - } - - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - - for event_id in event_ids { - let mut key = prefix.clone(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; - } - - Ok(()) - } -} diff --git a/src/service/key_value/rooms/state_accessor.rs b/src/service/key_value/rooms/state_accessor.rs deleted file mode 100644 index c36fd1cf..00000000 --- a/src/service/key_value/rooms/state_accessor.rs +++ /dev/null @@ -1,165 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use async_trait::async_trait; -use ruma::{events::StateEventType, EventId, RoomId}; - -use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; - -#[async_trait] -impl crate::rooms::state_accessor::Data for KeyValueDatabase { - #[allow(unused_qualifications)] // async traits - async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - let mut result = HashMap::new(); - let mut i: u8 = 0; - for compressed in full_state.iter() { - let parsed = services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed)?; - result.insert(parsed.0, parsed.1); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - Ok(result) - } - - #[allow(unused_qualifications)] // async traits - async fn state_full(&self, shortstatehash: u64) -> Result>> { - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - let mut result = HashMap::new(); - let mut i: u8 = 0; - for compressed in full_state.iter() { - let (_, eventid) = services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed)?; - if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); - } - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - Ok(result) - } - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - let Some(shortstatekey) = services() - .rooms - .short - .get_shortstatekey(event_type, state_key)? - else { - return Ok(None); - }; - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - Ok(full_state - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) - } - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) - } - - /// Returns the state hash for this pdu. - fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") - }) - }) - .transpose() - }) - } - - /// Returns the full room state. - #[allow(unused_qualifications)] // async traits - async fn room_state_full(&self, room_id: &RoomId) -> Result>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } - } - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn room_state_get_id( - &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn room_state_get( - &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } -} diff --git a/src/service/key_value/rooms/state_cache.rs b/src/service/key_value/rooms/state_cache.rs deleted file mode 100644 index 795da576..00000000 --- a/src/service/key_value/rooms/state_cache.rs +++ /dev/null @@ -1,626 +0,0 @@ -use std::{collections::HashSet, sync::Arc}; - -use itertools::Itertools; -use ruma::{ - events::{AnyStrippedStateEvent, AnySyncStateEvent}, - serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, -}; -use tracing::error; - -use crate::{ - appservice::RegistrationInfo, - services, user_is_local, - utils::{self}, - Error, KeyValueDatabase, Result, -}; - -type StrippedStateEventIter<'a> = Box>)>> + 'a>; - -type AnySyncStateEventIter<'a> = Box>)>> + 'a>; - -impl crate::rooms::state_cache::Data for KeyValueDatabase { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]) - } - - fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let roomid = room_id.as_bytes().to_vec(); - let mut roomid_prefix = room_id.as_bytes().to_vec(); - roomid_prefix.push(0xFF); - - let mut roomuser_id = roomid_prefix.clone(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - if self - .roomuserid_joined - .scan_prefix(roomid_prefix.clone()) - .count() == 0 - && self - .roomuserid_invitecount - .scan_prefix(roomid_prefix) - .count() == 0 - { - self.roomid_inviteviaservers.remove(&roomid)?; - } - - Ok(()) - } - - fn mark_as_invited( - &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, - invite_via: Option>, - ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - if let Some(servers) = invite_via { - let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); - #[allow(clippy::redundant_clone)] // this is a necessary clone? - prev_servers.append(servers.clone().as_mut()); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - } - - Ok(()) - } - - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let roomid = room_id.as_bytes().to_vec(); - let mut roomid_prefix = room_id.as_bytes().to_vec(); - roomid_prefix.push(0xFF); - - let mut roomuser_id = roomid_prefix.clone(); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate.insert( - &userroom_id, - &serde_json::to_vec(&Vec::>::new()).unwrap(), - )?; // TODO - self.roomuserid_leftcount - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - - if self - .roomuserid_joined - .scan_prefix(roomid_prefix.clone()) - .count() == 0 - && self - .roomuserid_invitecount - .scan_prefix(roomid_prefix) - .count() == 0 - { - self.roomid_inviteviaservers.remove(&roomid)?; - } - - Ok(()) - } - - fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - let mut real_users = HashSet::new(); - - for joined in self.room_members(room_id).filter_map(Result::ok) { - joined_servers.insert(joined.server_name().to_owned()); - if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) { - real_users.insert(joined); - } - joinedcount = joinedcount.saturating_add(1); - } - - for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { - invitedcount = invitedcount.saturating_add(1); - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - self.our_real_users_cache - .write() - .unwrap() - .insert(room_id.to_owned(), Arc::new(real_users)); - - for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id))] - fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { - let maybe = self - .our_real_users_cache - .read() - .unwrap() - .get(room_id) - .cloned(); - if let Some(users) = maybe { - Ok(users) - } else { - self.update_joined_count(room_id)?; - Ok(Arc::clone( - self.our_real_users_cache - .read() - .unwrap() - .get(room_id) - .unwrap(), - )) - } - } - - #[tracing::instrument(skip(self, room_id, appservice))] - fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.registration.id)) - .copied(); - - if let Some(b) = maybe { - Ok(b) - } else { - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - services().globals.server_name(), - ) - .ok(); - - let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self - .room_members(room_id) - .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); - - Ok(in_room) - } - } - - /// Makes a user forget a room. - #[tracing::instrument(skip(self))] - fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] - fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - })) - } - - #[tracing::instrument(skip(self))] - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - let mut key = server.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we - /// know). - #[tracing::instrument(skip(self))] - fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a> { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - })) - } - - /// Returns an iterator over all joined members of a room. - #[tracing::instrument(skip(self))] - fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - })) - } - - /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self))] - fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self))] - fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] - fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) - } - - /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] - fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self))] - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, - )) - }) - } - - #[tracing::instrument(skip(self))] - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) - .transpose() - } - - /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] - fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { - Box::new( - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }), - ) - } - - /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] - fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self))] - fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok(state) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() - } - - /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] - fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self))] - fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - - self.roomid_inviteviaservers - .get(&key)? - .map(|servers| { - let state = serde_json::from_slice(&servers).map_err(|e| { - error!("Invalid state in userroomid_leftstate: {e}"); - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; - - Ok(state) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); - prev_servers.append(servers.to_owned().as_mut()); - - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - - Ok(()) - } -} diff --git a/src/service/key_value/rooms/state_compressor.rs b/src/service/key_value/rooms/state_compressor.rs deleted file mode 100644 index bc0a2c33..00000000 --- a/src/service/key_value/rooms/state_compressor.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; - -use crate::{rooms::state_compressor::data::StateDiff, utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::state_compressor::Data for KeyValueDatabase { - fn get_statediff(&self, shortstatehash: u64) -> Result { - let value = self - .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); - let parent = if parent != 0 { - Some(parent) - } else { - None - }; - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i += size_of::(); - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i += 2 * size_of::(); - } - - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) - } - - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { - let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } - - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); - } - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value) - } -} diff --git a/src/service/key_value/rooms/threads.rs b/src/service/key_value/rooms/threads.rs deleted file mode 100644 index 9f0aad3a..00000000 --- a/src/service/key_value/rooms/threads.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::mem; - -use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; - -use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; - -type PduEventIterResult<'a> = Result> + 'a>>; - -impl crate::rooms::threads::Data for KeyValueDatabase { - fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> PduEventIterResult<'a> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(until - 1).to_be_bytes()); - - Ok(Box::new( - self.threadid_userids - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(mem::size_of::())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = services() - .rooms - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((count, pdu)) - }), - )) - } - - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { - let users = participants - .iter() - .map(|user| user.as_bytes()) - .collect::>() - .join(&[0xFF][..]); - - self.threadid_userids.insert(root_id, &users)?; - - Ok(()) - } - - fn get_participants(&self, root_id: &[u8]) -> Result>> { - if let Some(users) = self.threadid_userids.get(root_id)? { - Ok(Some( - users - .split(|b| *b == 0xFF) - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) - }) - .filter_map(Result::ok) - .collect(), - )) - } else { - Ok(None) - } - } -} diff --git a/src/service/key_value/rooms/timeline.rs b/src/service/key_value/rooms/timeline.rs deleted file mode 100644 index 7f22354c..00000000 --- a/src/service/key_value/rooms/timeline.rs +++ /dev/null @@ -1,295 +0,0 @@ -use std::{collections::hash_map, mem::size_of, sync::Arc}; - -use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; -use tracing::error; - -use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; - -impl crate::rooms::timeline::Data for KeyValueDatabase { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - match self - .lasttimelinecount_cache - .lock() - .unwrap() - .entry(room_id.to_owned()) - { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max())? - .find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) { - Ok(*v.insert(last_count.0)) - } else { - Ok(PduCount::Normal(0)) - } - }, - hash_map::Entry::Occupied(o) => Ok(*o.get()), - } - } - - /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| pdu_count(&pdu_id)) - .transpose() - } - - /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.get_non_outlier_pdu_json(event_id)?.map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - ) - } - - /// Returns the json of a pdu. - fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - } - - /// Returns the pdu's id. - fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } - - /// Returns the pdu. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - } - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(pdu) = self - .get_non_outlier_pdu(event_id)? - .map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - )? - .map(Arc::new) - { - Ok(Some(pdu)) - } else { - Ok(None) - } - } - - /// Returns the pdu. - /// - /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } - - /// Returns the pdu as a `BTreeMap`. - fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } - - fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; - - self.lasttimelinecount_cache - .lock() - .unwrap() - .insert(pdu.room_id.clone(), PduCount::Normal(count)); - - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - Ok(()) - } - - fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; - - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(event_id.as_bytes())?; - - Ok(()) - } - - /// Removes a pdu and creates a new one with the same id. - fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); - } - - Ok(()) - } - - /// Returns an iterator over all events and their tokens in a room that - /// happened before the event with id `until` in reverse-chronological - /// order. - fn pdus_until<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a>> { - let (prefix, current) = count_to_id(room_id, until, 1, true)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) - } - - fn pdus_after<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a>> { - let (prefix, current) = count_to_id(room_id, from, 1, false)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) - } - - fn increment_notification_counts( - &self, room_id: &RoomId, notifies: Vec, highlights: Vec, - ) -> Result<()> { - let mut notifies_batch = Vec::new(); - let mut highlights_batch = Vec::new(); - for user in notifies { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - notifies_batch.push(userroom_id); - } - for user in highlights { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - highlights_batch.push(userroom_id); - } - - self.userroomid_notificationcount - .increment_batch(&mut notifies_batch.into_iter())?; - self.userroomid_highlightcount - .increment_batch(&mut highlights_batch.into_iter())?; - Ok(()) - } -} - -/// Returns the `count` of this pdu's id. -fn pdu_count(pdu_id: &[u8]) -> Result { - let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; - let second_last_u64 = - utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()]); - - if matches!(second_last_u64, Ok(0)) { - Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) - } else { - Ok(PduCount::Normal(last_u64)) - } -} - -fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec, Vec)> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? - .to_be_bytes() - .to_vec(); - let mut pdu_id = prefix.clone(); - // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x.saturating_sub(offset) - } else { - x.saturating_add(offset) - } - }, - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX.saturating_sub(x); - if subtract { - num.saturating_sub(offset) - } else { - num.saturating_add(offset) - } - }, - }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - - Ok((prefix, pdu_id)) -} diff --git a/src/service/key_value/rooms/user.rs b/src/service/key_value/rooms/user.rs deleted file mode 100644 index a49dc815..00000000 --- a/src/service/key_value/rooms/user.rs +++ /dev/null @@ -1,137 +0,0 @@ -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; - -use crate::{services, utils, Error, KeyValueDatabase, Result}; - -impl crate::rooms::user::Data for KeyValueDatabase { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - - self.roomuserid_lastnotificationread - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - } - - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_highlightcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - } - - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastnotificationread - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) - } - - fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { - let shortroomid = services() - .rooms - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) - } - - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = services() - .rooms - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) - }) - .transpose() - } - - fn get_shared_rooms<'a>( - &'a self, users: Vec, - ) -> Result> + 'a>> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xFF) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - .saturating_add(1); // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(Result::ok) - }); - - // We use the default compare function because keys are sorted correctly (not - // reversed) - Ok(Box::new( - utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, - ) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), - )) - } -} diff --git a/src/service/key_value/sending.rs b/src/service/key_value/sending.rs deleted file mode 100644 index 544d37c9..00000000 --- a/src/service/key_value/sending.rs +++ /dev/null @@ -1,191 +0,0 @@ -use ruma::{ServerName, UserId}; - -use crate::{ - sending::{Destination, SendingEvent}, - services, utils, Error, KeyValueDatabase, Result, -}; - -impl crate::sending::Data for KeyValueDatabase { - fn active_requests<'a>(&'a self) -> Box, Destination, SendingEvent)>> + 'a> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) - } - - fn active_requests_for<'a>( - &'a self, destination: &Destination, - ) -> Box, SendingEvent)>> + 'a> { - let prefix = destination.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) - } - - fn delete_active_request(&self, key: Vec) -> Result<()> { self.servercurrentevent_data.remove(&key) } - - fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { - let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { - self.servercurrentevent_data.remove(&key)?; - } - - Ok(()) - } - - fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { - let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } - - for (key, _) in self.servernameevent_data.scan_prefix(prefix) { - self.servernameevent_data.remove(&key).unwrap(); - } - - Ok(()) - } - - fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { - let mut batch = Vec::new(); - let mut keys = Vec::new(); - for (destination, event) in requests { - let mut key = destination.get_prefix(); - if let SendingEvent::Pdu(value) = &event { - key.extend_from_slice(value); - } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - } - let value = if let SendingEvent::Edu(value) = &event { - &**value - } else { - &[] - }; - batch.push((key.clone(), value.to_owned())); - keys.push(key); - } - self.servernameevent_data - .insert_batch(&mut batch.into_iter())?; - Ok(keys) - } - - fn queued_requests<'a>( - &'a self, destination: &Destination, - ) -> Box)>> + 'a> { - let prefix = destination.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); - } - - fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { - for (e, key) in events { - if key.is_empty() { - continue; - } - - let value = if let SendingEvent::Edu(value) = &e { - &**value - } else { - &[] - }; - self.servercurrentevent_data.insert(key, value)?; - self.servernameevent_data.remove(key)?; - } - - Ok(()) - } - - fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { - self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) - } - - fn get_latest_educount(&self, server_name: &ServerName) -> Result { - self.servername_educount - .get(server_name.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) - } -} - -#[tracing::instrument(skip(key))] -fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent)> { - // Appservices start with a plus - Ok::<_, Error>(if key.starts_with(b"+") { - let mut parts = key[1..].splitn(2, |&b| b == 0xFF); - - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - - let server = utils::string_from_bytes(server) - .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; - - ( - Destination::Appservice(server), - if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) - } else { - SendingEvent::Edu(value) - }, - ) - } else if key.starts_with(b"$") { - let mut parts = key[1..].splitn(3, |&b| b == 0xFF); - - let user = parts.next().expect("splitn always returns one element"); - let user_string = utils::string_from_bytes(user) - .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; - let user_id = - UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; - - let pushkey = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let pushkey_string = utils::string_from_bytes(pushkey) - .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; - - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - - ( - Destination::Push(user_id, pushkey_string), - if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) - } else { - // I'm pretty sure this should never be called - SendingEvent::Edu(value) - }, - ) - } else { - let mut parts = key.splitn(2, |&b| b == 0xFF); - - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - - let server = utils::string_from_bytes(server) - .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; - - ( - Destination::Normal( - ServerName::parse(server) - .map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?, - ), - if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) - } else { - SendingEvent::Edu(value) - }, - ) - }) -} diff --git a/src/service/key_value/transaction_ids.rs b/src/service/key_value/transaction_ids.rs deleted file mode 100644 index 2dfcdfb1..00000000 --- a/src/service/key_value/transaction_ids.rs +++ /dev/null @@ -1,32 +0,0 @@ -use ruma::{DeviceId, TransactionId, UserId}; - -use crate::{KeyValueDatabase, Result}; - -impl crate::transaction_ids::Data for KeyValueDatabase { - fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - self.userdevicetxnid_response.insert(&key, data)?; - - Ok(()) - } - - fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - // If there's no entry, this is a new transaction - self.userdevicetxnid_response.get(&key) - } -} diff --git a/src/service/key_value/uiaa.rs b/src/service/key_value/uiaa.rs deleted file mode 100644 index 801f08c1..00000000 --- a/src/service/key_value/uiaa.rs +++ /dev/null @@ -1,68 +0,0 @@ -use ruma::{ - api::client::{error::ErrorKind, uiaa::UiaaInfo}, - CanonicalJsonValue, DeviceId, UserId, -}; - -use crate::{Error, KeyValueDatabase, Result}; - -impl crate::uiaa::Data for KeyValueDatabase { - fn set_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - - fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(ToOwned::to_owned) - } - - fn update_uiaa_session( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } -} diff --git a/src/service/key_value/users.rs b/src/service/key_value/users.rs deleted file mode 100644 index c8adb39f..00000000 --- a/src/service/key_value/users.rs +++ /dev/null @@ -1,898 +0,0 @@ -use std::{collections::BTreeMap, mem::size_of}; - -use argon2::{password_hash::SaltString, PasswordHasher}; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, - OwnedMxcUri, OwnedUserId, UInt, UserId, -}; -use tracing::warn; - -use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result}; - -impl crate::users::Data for KeyValueDatabase { - /// Check if a user has an account on this homeserver. - fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } - - /// Check if account is deactivated - fn is_deactivated(&self, user_id: &UserId) -> Result { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? - .is_empty()) - } - - /// Returns the number of users registered on this server. - fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } - - /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xFF); - let user_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; - let device_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; - - Ok(Some(( - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?, - utils::string_from_bytes(device_bytes) - .map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?, - ))) - }) - } - - /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a> { - Box::new(self.userid_password.iter().map(|(bytes, _)| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) - })) - } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. - fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } - - /// Returns the password hash for the given user. - fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } - - /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = calculate_password_hash(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } - - /// Returns the displayname of a user on this homeserver. - fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Displayname in db is invalid."))?, - )) - }) - } - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the `avatar_url` of a user. - fn avatar_url(&self, user_id: &UserId) -> Result> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| { - warn!("Avatar URL in db is invalid: {}", e); - Error::bad_database("Avatar URL in db is invalid.") - })?; - let mxc_uri: OwnedMxcUri = s_bytes.into(); - Ok(mxc_uri) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the blurhash of a user. - fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - - Ok(s) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Adds a new device to a user. - fn create_device( - &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, - ) -> Result<()> { - // This method should never be called for nonexistent users. We shouldn't assert - // though... - if !self.exists(user_id)? { - warn!("Called create_device for non-existent user {} in database", user_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); - } - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: None, // TODO - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } - - /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xFF); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } - - /// Returns an iterator over all device ids of this user. - fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? - .into()) - }), - ) - } - - /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // should not be None, but we shouldn't assert either lol... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&key)?.is_none() { - warn!( - "Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \ - database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - key.push(0xFF); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) - }) - } - - fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, - )) - }) - .transpose() - } - - fn count_one_time_keys( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in self - .onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? - .algorithm(), - ) - }) { - *counts.entry(algorithm?).or_default() += uint!(1); - } - - Ok(counts) - } - - fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id)?; - - Ok(()) - } - - fn add_cross_signing_keys( - &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, - user_signing_key: &Option>, notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? - .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - if notify { - self.mark_device_key_update(user_id)?; - } - - Ok(()) - } - - fn sign_key( - &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = serde_json::from_slice( - &self - .keyid_key - .get(&key)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, - ) - .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? - .as_object_mut() - .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? - .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - self.mark_device_key_update(target_id)?; - - Ok(()) - } - - fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a> { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - Box::new( - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) - }), - ) - } - - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = services().globals.next_count()?.to_be_bytes(); - for room_id in services() - .rooms - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - // Don't send key updates to unencrypted rooms - if services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?, - )) - }) - } - - fn parse_master_key( - &self, user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) - } - - fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; - clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), - ))) - }) - } - - fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?, - )) - }) - }) - } - - fn add_to_device_event( - &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, - ); - } - - Ok(events) - } - - fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) - .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, - )) - }) - .filter_map(Result::ok) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ - metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } - - /// Get device metadata. - fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } - - fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) - .map(Some) - }) - } - - fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) - }), - ) - } - - /// Creates a new sync filter. Returns the filter id. - fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter - .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; - - Ok(filter_id) - } - - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db.")) - } else { - Ok(None) - } - } -} - -/// Will only return with Some(username) if the password was not empty and the -/// username could be successfully parsed. -/// If `utils::string_from_bytes`(...) returns an error that username will be -/// skipped and the error will be logged. -fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); - None - }, - } - } -} - -/// Calculate a new hash for the given password -fn calculate_password_hash(password: &str) -> Result { - let salt = SaltString::generate(rand::thread_rng()); - services() - .globals - .argon - .hash_password(password.as_bytes(), &salt) - .map(|it| it.to_string()) -} diff --git a/src/service/media/data.rs b/src/service/media/data.rs index c464e672..f2a83a53 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,6 +1,10 @@ -use crate::Result; +use conduit::debug_info; +use ruma::api::client::error::ErrorKind; +use tracing::debug; -pub trait Data: Send + Sync { +use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result}; + +pub(crate) trait Data: Send + Sync { fn create_file_metadata( &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, @@ -21,7 +25,237 @@ pub trait Data: Send + Sync { #[allow(dead_code)] fn remove_url_preview(&self, url: &str) -> Result<()>; - fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>; + fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()>; - fn get_url_preview(&self, url: &str) -> Option; + fn get_url_preview(&self, url: &str) -> Option; +} + +impl Data for KeyValueDatabase { + fn create_file_metadata( + &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, + content_type: Option<&str>, + ) -> Result> { + let mut key = mxc.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&width.to_be_bytes()); + key.extend_from_slice(&height.to_be_bytes()); + key.push(0xFF); + key.extend_from_slice( + content_disposition + .as_ref() + .map(|f| f.as_bytes()) + .unwrap_or_default(), + ); + key.push(0xFF); + key.extend_from_slice( + content_type + .as_ref() + .map(|c| c.as_bytes()) + .unwrap_or_default(), + ); + + self.mediaid_file.insert(&key, &[])?; + + if let Some(user) = sender_user { + let key = mxc.as_bytes().to_vec(); + let user = user.as_bytes().to_vec(); + self.mediaid_user.insert(&key, &user)?; + } + + Ok(key) + } + + fn delete_file_mxc(&self, mxc: String) -> Result<()> { + debug!("MXC URI: {:?}", mxc); + + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xFF); + + debug!("MXC db prefix: {prefix:?}"); + + for (key, _) in self.mediaid_file.scan_prefix(prefix) { + debug!("Deleting key: {:?}", key); + self.mediaid_file.remove(&key)?; + } + + for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) { + if key == mxc.as_bytes().to_vec() { + let user = string_from_bytes(&value).unwrap_or_default(); + + debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}"); + self.mediaid_user.remove(&key)?; + } + } + + Ok(()) + } + + /// Searches for all files with the given MXC + fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>> { + debug!("MXC URI: {:?}", mxc); + + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xFF); + + let keys: Vec> = self + .mediaid_file + .scan_prefix(prefix) + .map(|(key, _)| key) + .collect(); + + if keys.is_empty() { + return Err(Error::bad_database( + "Failed to find any keys in database with the provided MXC.", + )); + } + + debug!("Got the following keys: {:?}", keys); + + Ok(keys) + } + + fn search_file_metadata( + &self, mxc: String, width: u32, height: u32, + ) -> Result<(Option, Option, Vec)> { + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(&width.to_be_bytes()); + prefix.extend_from_slice(&height.to_be_bytes()); + prefix.push(0xFF); + + let (key, _) = self + .mediaid_file + .scan_prefix(prefix) + .next() + .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; + + let mut parts = key.rsplit(|&b| b == 0xFF); + + let content_type = parts + .next() + .map(|bytes| { + string_from_bytes(bytes) + .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode.")) + }) + .transpose()?; + + let content_disposition_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; + + let content_disposition = if content_disposition_bytes.is_empty() { + None + } else { + Some( + string_from_bytes(content_disposition_bytes) + .map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?, + ) + }; + Ok((content_disposition, content_type, key)) + } + + /// Gets all the media keys in our database (this includes all the metadata + /// associated with it such as width, height, content-type, etc) + fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } + + fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } + + fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> { + let mut value = Vec::::new(); + value.extend_from_slice(×tamp.as_secs().to_be_bytes()); + value.push(0xFF); + value.extend_from_slice( + data.title + .as_ref() + .map(String::as_bytes) + .unwrap_or_default(), + ); + value.push(0xFF); + value.extend_from_slice( + data.description + .as_ref() + .map(String::as_bytes) + .unwrap_or_default(), + ); + value.push(0xFF); + value.extend_from_slice( + data.image + .as_ref() + .map(String::as_bytes) + .unwrap_or_default(), + ); + value.push(0xFF); + value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); + + self.url_previews.insert(url.as_bytes(), &value) + } + + fn get_url_preview(&self, url: &str) -> Option { + let values = self.url_previews.get(url.as_bytes()).ok()??; + + let mut values = values.split(|&b| b == 0xFF); + + let _ts = values.next(); + /* if we ever decide to use timestamp, this is here. + match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) { + Some(0) => None, + x => x, + };*/ + + let title = match values + .next() + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + { + Some(s) if s.is_empty() => None, + x => x, + }; + let description = match values + .next() + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + { + Some(s) if s.is_empty() => None, + x => x, + }; + let image = match values + .next() + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + { + Some(s) if s.is_empty() => None, + x => x, + }; + let image_size = match values + .next() + .map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default())) + { + Some(0) => None, + x => x, + }; + let image_width = match values + .next() + .map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) + { + Some(0) => None, + x => x, + }; + let image_height = match values + .next() + .map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) + { + Some(0) => None, + x => x, + }; + + Some(UrlPreviewData { + title, + description, + image, + image_size, + image_width, + image_height, + }) + } } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 02a4792a..5521011d 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,7 +1,7 @@ mod data; use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime}; -pub use data::Data; +use data::Data; use image::imageops::FilterType; use ruma::{OwnedMxcUri, OwnedUserId}; use serde::Serialize; @@ -39,7 +39,7 @@ pub struct UrlPreviewData { } pub struct Service { - pub db: Arc, + pub(super) db: Arc, pub url_preview_mutex: RwLock>>>, } diff --git a/src/service/mod.rs b/src/service/mod.rs index 386e8662..028ed959 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod key_value; pub mod pdu; pub mod services; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index 6f0f58f8..ad8e55f8 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,6 +1,12 @@ +use conduit::debug_info; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; -use crate::Result; +use crate::{ + presence::Presence, + services, + utils::{self, user_id_from_bytes}, + Error, KeyValueDatabase, Result, +}; pub trait Data: Send + Sync { /// Returns the latest presence event for the given user. @@ -19,3 +25,121 @@ pub trait Data: Send + Sync { /// with id `since`. fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a>; } + +impl Data for KeyValueDatabase { + fn get_presence(&self, user_id: &UserId) -> Result> { + if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { + let count = utils::u64_from_bytes(&count_bytes) + .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; + + let key = presenceid_key(count, user_id); + self.presenceid_presence + .get(&key)? + .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { + Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?)) + }) + .transpose() + } else { + Ok(None) + } + } + + fn set_presence( + &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, + last_active_ago: Option, status_msg: Option, + ) -> Result<()> { + let last_presence = self.get_presence(user_id)?; + let state_changed = match last_presence { + None => true, + Some(ref presence) => presence.1.content.presence != *presence_state, + }; + + let now = utils::millis_since_unix_epoch(); + let last_last_active_ts = match last_presence { + None => 0, + Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), + }; + + let last_active_ts = match last_active_ago { + None => now, + Some(last_active_ago) => now.saturating_sub(last_active_ago.into()), + }; + + // tighten for state flicker? + if !state_changed && last_active_ts <= last_last_active_ts { + debug_info!( + "presence spam {:?} last_active_ts:{:?} <= {:?}", + user_id, + last_active_ts, + last_last_active_ts + ); + return Ok(()); + } + + let status_msg = if status_msg.as_ref().is_some_and(String::is_empty) { + None + } else { + status_msg + }; + + let presence = Presence::new( + presence_state.to_owned(), + currently_active.unwrap_or(false), + last_active_ts, + status_msg, + ); + let count = services().globals.next_count()?; + let key = presenceid_key(count, user_id); + + self.presenceid_presence + .insert(&key, &presence.to_json_bytes()?)?; + + self.userid_presenceid + .insert(user_id.as_bytes(), &count.to_be_bytes())?; + + if let Some((last_count, _)) = last_presence { + let key = presenceid_key(last_count, user_id); + self.presenceid_presence.remove(&key)?; + } + + Ok(()) + } + + fn remove_presence(&self, user_id: &UserId) -> Result<()> { + if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { + let count = utils::u64_from_bytes(&count_bytes) + .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; + let key = presenceid_key(count, user_id); + self.presenceid_presence.remove(&key)?; + self.userid_presenceid.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { + Box::new( + self.presenceid_presence + .iter() + .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec)> { + let (count, user_id) = presenceid_parse(&key)?; + Ok((user_id, count, presence_bytes)) + }) + .filter(move |(_, count, _)| *count > since), + ) + } +} + +#[inline] +fn presenceid_key(count: u64, user_id: &UserId) -> Vec { + [count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat() +} + +#[inline] +fn presenceid_parse(key: &[u8]) -> Result<(u64, OwnedUserId)> { + let (count, user_id) = key.split_at(8); + let user_id = user_id_from_bytes(user_id)?; + let count = utils::u64_from_bytes(count).unwrap(); + + Ok((count, user_id)) +} diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 58be2ddf..de5030f6 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -2,7 +2,7 @@ mod data; use std::{sync::Arc, time::Duration}; -pub use data::Data; +use data::Data; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index b58cd3fc..d8c75f4d 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -3,9 +3,9 @@ use ruma::{ UserId, }; -use crate::Result; +use crate::{utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { +pub(crate) trait Data: Send + Sync { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; @@ -14,3 +14,62 @@ pub trait Data: Send + Sync { fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a>; } + +impl Data for KeyValueDatabase { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + match &pusher { + set_pusher::v3::PusherAction::Post(data) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); + self.senderkey_pusher + .insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?; + Ok(()) + }, + set_pusher::v3::PusherAction::Delete(ids) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(ids.pushkey.as_bytes()); + self.senderkey_pusher.remove(&key).map_err(Into::into) + }, + } + } + + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + let mut senderkey = sender.as_bytes().to_vec(); + senderkey.push(0xFF); + senderkey.extend_from_slice(pushkey.as_bytes()); + + self.senderkey_pusher + .get(&senderkey)? + .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) + .transpose() + } + + fn get_pushers(&self, sender: &UserId) -> Result> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xFF); + + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) + .collect() + } + + fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { + let mut parts = k.splitn(2, |&b| b == 0xFF); + let _senderkey = parts.next(); + let push_key = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; + let push_key_string = utils::string_from_bytes(push_key) + .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; + + Ok(push_key_string) + })) + } +} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 19a570c4..261d69dd 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -2,7 +2,7 @@ mod data; use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -pub use data::Data; +use data::Data; use ipaddress::IPAddress; use ruma::{ api::{ @@ -25,7 +25,7 @@ use tracing::{info, trace, warn}; use crate::{debug_info, services, Error, PduEvent, Result}; pub struct Service { - pub db: Arc, + pub(super) db: Arc, } impl Service { diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 095d6e66..536f8228 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,6 +1,6 @@ -use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; +use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { /// Creates or updates the alias to the given room id. @@ -20,3 +20,75 @@ pub trait Data: Send + Sync { /// Returns all local aliases on the server fn all_local_aliases<'a>(&'a self) -> Box> + 'a>; } + +impl Data for KeyValueDatabase { + fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { + self.alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes())?; + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xFF); + aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; + Ok(()) + } + + fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { + if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { + let mut prefix = room_id; + prefix.push(0xFF); + + for (key, _) in self.aliasid_alias.scan_prefix(prefix) { + self.aliasid_alias.remove(&key)?; + } + self.alias_roomid.remove(alias.alias().as_bytes())?; + } else { + return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); + } + Ok(()) + } + + fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { + self.alias_roomid + .get(alias.alias().as_bytes())? + .map(|bytes| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) + }) + .transpose() + } + + fn local_aliases_for_room<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? + .try_into() + .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) + })) + } + + fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { + Box::new( + self.alias_roomid + .iter() + .map(|(room_alias_bytes, room_id_bytes)| { + let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; + + let room_id = utils::string_from_bytes(&room_id_bytes) + .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? + .try_into() + .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; + + Ok((room_id, room_alias_localpart)) + }), + ) + } +} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 6e8b386a..f8d10b45 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::Result; diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index f77d2d90..c3f046fc 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,8 +1,64 @@ -use std::sync::Arc; +use std::{mem::size_of, sync::Arc}; -use crate::Result; +use crate::{utils, KeyValueDatabase, Result}; pub trait Data: Send + Sync { fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>; fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc<[u64]>) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + return Ok(Some(Arc::clone(result))); + } + + // We only save auth chains for single events in the db + if key.len() == 1 { + // Check DB cache + let chain = self + .shorteventid_authchain + .get(&key[0].to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::()) + .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) + .collect::>() + }); + + if let Some(chain) = chain { + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); + + return Ok(Some(chain)); + } + } + + Ok(None) + } + + fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { + // Only persist single events in db + if key.len() == 1 { + self.shorteventid_authchain.insert( + &key[0].to_be_bytes(), + &auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(), + )?; + } + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(key, auth_chain); + + Ok(()) + } +} diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 4c9152b0..511b762e 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc, }; -pub use data::Data; +use data::Data; use ruma::{api::client::error::ErrorKind, EventId, RoomId}; use tracing::{debug, error, trace, warn}; diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index 691b8604..67afbb6c 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,6 +1,6 @@ use ruma::{OwnedRoomId, RoomId}; -use crate::Result; +use crate::{utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { /// Adds the room to the public room directory @@ -15,3 +15,23 @@ pub trait Data: Send + Sync { /// Returns the unsorted public room directory fn public_rooms<'a>(&'a self) -> Box> + 'a>; } + +impl Data for KeyValueDatabase { + fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } + + fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) } + + fn is_public_room(&self, room_id: &RoomId) -> Result { + Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) + } + + fn public_rooms<'a>(&'a self) -> Box> + 'a> { + Box::new(self.publicroomids.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) + })) + } +} diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index ab69d003..85909c74 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{OwnedRoomId, RoomId}; use crate::Result; diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 890a2f98..7c237901 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,6 +1,6 @@ use ruma::{DeviceId, RoomId, UserId}; -use crate::Result; +use crate::{KeyValueDatabase, Result}; pub trait Data: Send + Sync { fn lazy_load_was_sent_before( @@ -14,3 +14,53 @@ pub trait Data: Send + Sync { fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, + ) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(ll_user.as_bytes()); + Ok(self.lazyloadedids.get(&key)?.is_some()) + } + + fn lazy_load_confirm_delivery( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, + confirmed_user_ids: &mut dyn Iterator, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); + + for ll_id in confirmed_user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(ll_id.as_bytes()); + self.lazyloadedids.insert(&key, &[])?; + } + + Ok(()) + } + + fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); + + for (key, _) in self.lazyloadedids.scan_prefix(prefix) { + self.lazyloadedids.remove(&key)?; + } + + Ok(()) + } +} diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 565a186d..283a03c1 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc, }; -pub use data::Data; +use data::Data; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index d702b203..a6bb701e 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,6 +1,7 @@ use ruma::{OwnedRoomId, RoomId}; +use tracing::error; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { fn exists(&self, room_id: &RoomId) -> Result; @@ -11,3 +12,75 @@ pub trait Data: Send + Sync { fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()>; fn list_banned_rooms<'a>(&'a self) -> Box> + 'a>; } + +impl Data for KeyValueDatabase { + fn exists(&self, room_id: &RoomId) -> Result { + let prefix = match services().rooms.short.get_shortroomid(room_id)? { + Some(b) => b.to_be_bytes().to_vec(), + None => return Ok(false), + }; + + // Look for PDUs in that room. + Ok(self + .pduid_pdu + .iter_from(&prefix, false) + .next() + .filter(|(k, _)| k.starts_with(&prefix)) + .is_some()) + } + + fn iter_ids<'a>(&'a self) -> Box> + 'a> { + Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) + })) + } + + fn is_disabled(&self, room_id: &RoomId) -> Result { + Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) + } + + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + if disabled { + self.disabledroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.disabledroomids.remove(room_id.as_bytes())?; + } + + Ok(()) + } + + fn is_banned(&self, room_id: &RoomId) -> Result { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) } + + fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { + if banned { + self.bannedroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.bannedroomids.remove(room_id.as_bytes())?; + } + + Ok(()) + } + + fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { + Box::new(self.bannedroomids.iter().map( + |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { + let room_id = utils::string_from_bytes(&room_id_bytes) + .map_err(|e| { + error!("Invalid room_id bytes in bannedroomids: {e}"); + Error::bad_database("Invalid room_id in bannedroomids.") + })? + .try_into() + .map_err(|e| { + error!("Invalid room_id in bannedroomids: {e}"); + Error::bad_database("Invalid room_id in bannedroomids") + })?; + + Ok(room_id) + }, + )) + } +} diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index e14d539d..2a6f7724 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{OwnedRoomId, RoomId}; use crate::Result; diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index baf3f7b5..bef56a25 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -19,27 +19,6 @@ pub mod timeline; pub mod typing; pub mod user; -pub trait Data: - alias::Data - + auth_chain::Data - + directory::Data - + lazy_loading::Data - + metadata::Data - + outlier::Data - + pdu_metadata::Data - + read_receipt::Data - + search::Data - + short::Data - + state::Data - + state_accessor::Data - + state_cache::Data - + state_compressor::Data - + timeline::Data - + threads::Data - + user::Data -{ -} - pub struct Service { pub alias: alias::Service, pub auth_chain: auth_chain::Service, diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index 18eb3190..a278161c 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -1,9 +1,34 @@ use ruma::{CanonicalJsonObject, EventId}; -use crate::{PduEvent, Result}; +use crate::{Error, KeyValueDatabase, PduEvent, Result}; pub trait Data: Send + Sync { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + } + + fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + } + + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + self.eventid_outlierpdu.insert( + event_id.as_bytes(), + &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), + ) + } +} diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 9ec4010c..3e8b4ed9 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index ccc14edd..b5bea331 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,8 +1,8 @@ -use std::sync::Arc; +use std::{mem, sync::Arc}; use ruma::{EventId, RoomId, UserId}; -use crate::{PduCount, PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; pub trait Data: Send + Sync { fn add_relation(&self, from: u64, to: u64) -> Result<()>; @@ -15,3 +15,77 @@ pub trait Data: Send + Sync { fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; fn is_event_soft_failed(&self, event_id: &EventId) -> Result; } + +impl Data for KeyValueDatabase { + fn add_relation(&self, from: u64, to: u64) -> Result<()> { + let mut key = to.to_be_bytes().to_vec(); + key.extend_from_slice(&from.to_be_bytes()); + self.tofrom_relation.insert(&key, &[])?; + Ok(()) + } + + fn relations_until<'a>( + &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, + ) -> Result> + 'a>> { + let prefix = target.to_be_bytes().to_vec(); + let mut current = prefix.clone(); + + let count_raw = match until { + PduCount::Normal(x) => x.saturating_sub(1), + PduCount::Backfilled(x) => { + current.extend_from_slice(&0_u64.to_be_bytes()); + u64::MAX.saturating_sub(x).saturating_sub(1) + }, + }; + current.extend_from_slice(&count_raw.to_be_bytes()); + + Ok(Box::new( + self.tofrom_relation + .iter_from(¤t, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(tofrom, _data)| { + let from = utils::u64_from_bytes(&tofrom[(mem::size_of::())..]) + .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; + + let mut pduid = shortroomid.to_be_bytes().to_vec(); + pduid.extend_from_slice(&from.to_be_bytes()); + + let mut pdu = services() + .rooms + .timeline + .get_pdu_from_id(&pduid)? + .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((PduCount::Normal(from), pdu)) + }), + )) + } + + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + for prev in event_ids { + let mut key = room_id.as_bytes().to_vec(); + key.extend_from_slice(prev.as_bytes()); + self.referencedevents.insert(&key, &[])?; + } + + Ok(()) + } + + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.extend_from_slice(event_id.as_bytes()); + Ok(self.referencedevents.get(&key)?.is_some()) + } + + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + self.softfailedeventids.insert(event_id.as_bytes(), &[]) + } + + fn is_event_soft_failed(&self, event_id: &EventId) -> Result { + self.softfailedeventids + .get(event_id.as_bytes()) + .map(|o| o.is_some()) + } +} diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 7e0da835..25da29dd 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 4fe7be59..00d03e8c 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,10 +1,12 @@ +use std::mem; + use ruma::{ events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, serde::Raw, - OwnedUserId, RoomId, UserId, + CanonicalJsonObject, OwnedUserId, RoomId, UserId, }; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; type AnySyncEphemeralRoomEventIter<'a> = Box)>> + 'a>; @@ -32,3 +34,118 @@ pub trait Data: Send + Sync { #[allow(dead_code)] fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result; } + +impl Data for KeyValueDatabase { + fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + // Remove old entry + if let Some((old, _)) = self + .readreceiptid_readreceipt + .iter_from(&last_possible_key, true) + .take_while(|(key, _)| key.starts_with(&prefix)) + .find(|(key, _)| { + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element") + == user_id.as_bytes() + }) { + // This is the old room_latest + self.readreceiptid_readreceipt.remove(&old)?; + } + + let mut room_latest_id = prefix; + room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.push(0xFF); + room_latest_id.extend_from_slice(user_id.as_bytes()); + + self.readreceiptid_readreceipt.insert( + &room_latest_id, + &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), + )?; + + Ok(()) + } + + fn readreceipts_since<'a>( + &'a self, room_id: &RoomId, since: u64, + ) -> Box)>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + let prefix2 = prefix.clone(); + + let mut first_possible_edu = prefix.clone(); + first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since + + Box::new( + self.readreceiptid_readreceipt + .iter_from(&first_possible_edu, false) + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::()]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id = UserId::parse( + utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, + ) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; + + let mut json = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; + json.remove("room_id"); + + Ok(( + user_id, + count, + Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")), + )) + }), + ) + } + + fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_privateread + .insert(&key, &count.to_be_bytes())?; + + self.roomuserid_lastprivatereadupdate + .insert(&key, &services().globals.next_count()?.to_be_bytes()) + } + + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_privateread + .get(&key)? + .map_or(Ok(None), |v| { + Ok(Some( + utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, + )) + }) + } + + fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + Ok(self + .roomuserid_lastprivatereadupdate + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) + }) + .transpose()? + .unwrap_or(0)) + } +} diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index a5b9c325..9afc1fd2 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; use crate::{services, Result}; diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 96439adf..091f190e 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,6 +1,6 @@ use ruma::RoomId; -use crate::Result; +use crate::{services, utils, KeyValueDatabase, Result}; type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; @@ -9,3 +9,62 @@ pub trait Data: Send + Sync { fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>; } + +impl Data for KeyValueDatabase { + fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + let mut batch = message_body + .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .filter(|word| word.len() <= 50) + .map(str::to_lowercase) + .map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xFF); + key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + (key, Vec::new()) + }); + + self.tokenids.insert_batch(&mut batch) + } + + fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let words: Vec<_> = search_string + .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .map(str::to_lowercase) + .collect(); + + let iterators = words.clone().into_iter().map(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xFF); + let prefix3 = prefix2.clone(); + + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.tokenids + .iter_from(&last_possible_id, true) // Newest pdus first + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(key, _)| key[prefix3.len()..].to_vec()) + }); + + let Some(common_elements) = utils::common_elements(iterators, |a, b| { + // We compare b with a because we reversed the iterator earlier + b.cmp(a) + }) else { + return Ok(None); + }; + + Ok(Some((Box::new(common_elements), words))) + } +} diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 569761a3..80ac45ae 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::RoomId; use crate::Result; diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index d0e2085f..11ac39b5 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use ruma::{events::StateEventType, EventId, RoomId}; +use tracing::warn; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; @@ -24,3 +25,161 @@ pub trait Data: Send + Sync { fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; } + +impl Data for KeyValueDatabase { + fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { + utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? + } else { + let shorteventid = services().globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + }; + + Ok(short) + } + + fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { + let mut ret: Vec = Vec::with_capacity(event_ids.len()); + let keys = event_ids + .iter() + .map(|id| id.as_bytes()) + .collect::>(); + for (i, short) in self + .eventid_shorteventid + .multi_get(&keys)? + .iter() + .enumerate() + { + match short { + Some(short) => ret.push( + utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + ), + None => { + let short = services().globals.next_count()?; + self.eventid_shorteventid + .insert(keys[i], &short.to_be_bytes())?; + self.shorteventid_eventid + .insert(&short.to_be_bytes(), keys[i])?; + + debug_assert!(ret.len() == i, "position of result must match input"); + ret.push(short); + }, + } + } + + Ok(ret) + } + + fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { + let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); + statekey_vec.push(0xFF); + statekey_vec.extend_from_slice(state_key.as_bytes()); + + let short = self + .statekey_shortstatekey + .get(&statekey_vec)? + .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose()?; + + Ok(short) + } + + fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); + statekey_vec.push(0xFF); + statekey_vec.extend_from_slice(state_key.as_bytes()); + + let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { + utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? + } else { + let shortstatekey = services().globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; + shortstatekey + }; + + Ok(short) + } + + fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + let event_id = EventId::parse_arc( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + + Ok(event_id) + } + + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + let bytes = self + .shortstatekey_statekey + .get(&shortstatekey.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; + + let mut parts = bytes.splitn(2, |&b| b == 0xFF); + let eventtype_bytes = parts.next().expect("split always returns one entry"); + let statekey_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + + let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { + warn!("Event type in shortstatekey_statekey is invalid: {}", e); + Error::bad_database("Event type in shortstatekey_statekey is invalid.") + })?); + + let state_key = utils::string_from_bytes(statekey_bytes) + .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; + + let result = (event_type, state_key); + + Ok(result) + } + + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? { + ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ) + } else { + let shortstatehash = services().globals.next_count()?; + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + }) + } + + fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + self.roomid_shortroomid + .get(room_id.as_bytes())? + .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) + .transpose() + } + + fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { + utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? + } else { + let short = services().globals.next_count()?; + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes())?; + short + }) + } +} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 657de66a..2e994c3c 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,7 +1,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{events::StateEventType, EventId, RoomId}; use crate::Result; diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index f486f1f8..f0fef086 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc}; use ruma::{EventId, OwnedEventId, RoomId}; use tokio::sync::MutexGuard; -use crate::Result; +use crate::{utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { /// Returns the last state hash key added to the db for the given room. @@ -31,3 +31,70 @@ pub trait Data: Send + Sync { _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + self.roomid_shortstatehash + .get(room_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") + })?)) + }) + } + + fn set_room_state( + &self, + room_id: &RoomId, + new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.roomid_shortstatehash + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + Ok(()) + } + + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + self.shorteventid_shortstatehash + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + Ok(()) + } + + fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + self.roomid_pduleaves + .scan_prefix(prefix) + .map(|(_, bytes)| { + EventId::parse_arc( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + }) + .collect() + } + + fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: Vec, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { + self.roomid_pduleaves.remove(&key)?; + } + + for event_id in event_ids { + let mut key = prefix.clone(); + key.extend_from_slice(event_id.as_bytes()); + self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + } + + Ok(()) + } +} diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 8031d566..5a05b162 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc, }; -pub use data::Data; +use data::Data; use ruma::{ api::client::error::ErrorKind, events::{ diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 5fd58864..a7b39e8c 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; #[async_trait] pub trait Data: Send + Sync { @@ -46,3 +46,162 @@ pub trait Data: Send + Sync { &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>>; } + +#[async_trait] +impl Data for KeyValueDatabase { + #[allow(unused_qualifications)] // async traits + async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + let mut result = HashMap::new(); + let mut i: u8 = 0; + for compressed in full_state.iter() { + let parsed = services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); + + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) + } + + #[allow(unused_qualifications)] // async traits + async fn state_full(&self, shortstatehash: u64) -> Result>> { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + let mut result = HashMap::new(); + let mut i: u8 = 0; + for compressed in full_state.iter() { + let (_, eventid) = services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed)?; + if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { + result.insert( + ( + pdu.kind.to_string().into(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? + .clone(), + ), + pdu, + ); + } + + i = i.wrapping_add(1); + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn state_get_id( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + let Some(shortstatekey) = services() + .rooms + .short + .get_shortstatekey(event_type, state_key)? + else { + return Ok(None); + }; + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| { + services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed) + .ok() + .map(|(_, id)| id) + })) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn state_get( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) + } + + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + self.eventid_shorteventid + .get(event_id.as_bytes())? + .map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") + }) + }) + .transpose() + }) + } + + /// Returns the full room state. + #[allow(unused_qualifications)] // async traits + async fn room_state_full(&self, room_id: &RoomId) -> Result>> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + self.state_full(current_shortstatehash).await + } else { + Ok(HashMap::new()) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn room_state_get_id( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn room_state_get( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + self.state_get(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } +} diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index d2e51361..a05ba719 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, Mutex}, }; -pub use data::Data; +use data::Data; use lru_cache::LruCache; use ruma::{ events::{ diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 70fcd6d1..08e93d89 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,15 +1,21 @@ use std::{collections::HashSet, sync::Arc}; +use itertools::Itertools; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; +use tracing::error; -use crate::{service::appservice::RegistrationInfo, Result}; +use crate::{ + appservice::RegistrationInfo, + services, user_is_local, + utils::{self}, + Error, KeyValueDatabase, Result, +}; type StrippedStateEventIter<'a> = Box>)>> + 'a>; - type AnySyncStateEventIter<'a> = Box>)>> + 'a>; pub trait Data: Send + Sync { @@ -91,3 +97,609 @@ pub trait Data: Send + Sync { #[allow(dead_code)] fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + self.roomuseroncejoinedids.insert(&userroom_id, &[]) + } + + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let roomid = room_id.as_bytes().to_vec(); + let mut roomid_prefix = room_id.as_bytes().to_vec(); + roomid_prefix.push(0xFF); + + let mut roomuser_id = roomid_prefix.clone(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_joined.insert(&userroom_id, &[])?; + self.roomuserid_joined.insert(&roomuser_id, &[])?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + if self + .roomuserid_joined + .scan_prefix(roomid_prefix.clone()) + .count() == 0 + && self + .roomuserid_invitecount + .scan_prefix(roomid_prefix) + .count() == 0 + { + self.roomid_inviteviaservers.remove(&roomid)?; + } + + Ok(()) + } + + fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, + ) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), + )?; + self.roomuserid_invitecount + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + if let Some(servers) = invite_via { + let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + #[allow(clippy::redundant_clone)] // this is a necessary clone? + prev_servers.append(servers.clone().as_mut()); + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers)?; + } + + Ok(()) + } + + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let roomid = room_id.as_bytes().to_vec(); + let mut roomid_prefix = room_id.as_bytes().to_vec(); + roomid_prefix.push(0xFF); + + let mut roomuser_id = roomid_prefix.clone(); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate.insert( + &userroom_id, + &serde_json::to_vec(&Vec::>::new()).unwrap(), + )?; // TODO + self.roomuserid_leftcount + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + + if self + .roomuserid_joined + .scan_prefix(roomid_prefix.clone()) + .count() == 0 + && self + .roomuserid_invitecount + .scan_prefix(roomid_prefix) + .count() == 0 + { + self.roomid_inviteviaservers.remove(&roomid)?; + } + + Ok(()) + } + + fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + let mut real_users = HashSet::new(); + + for joined in self.room_members(room_id).filter_map(Result::ok) { + joined_servers.insert(joined.server_name().to_owned()); + if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) { + real_users.insert(joined); + } + joinedcount = joinedcount.saturating_add(1); + } + + for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { + invitedcount = invitedcount.saturating_add(1); + } + + self.roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; + + self.roomid_invitedcount + .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; + + self.our_real_users_cache + .write() + .unwrap() + .insert(room_id.to_owned(), Arc::new(real_users)); + + for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { + if !joined_servers.remove(&old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } + } + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } + + self.appservice_in_room_cache + .write() + .unwrap() + .remove(room_id); + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { + let maybe = self + .our_real_users_cache + .read() + .unwrap() + .get(room_id) + .cloned(); + if let Some(users) = maybe { + Ok(users) + } else { + self.update_joined_count(room_id)?; + Ok(Arc::clone( + self.our_real_users_cache + .read() + .unwrap() + .get(room_id) + .unwrap(), + )) + } + } + + #[tracing::instrument(skip(self, room_id, appservice))] + fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { + let maybe = self + .appservice_in_room_cache + .read() + .unwrap() + .get(room_id) + .and_then(|map| map.get(&appservice.registration.id)) + .copied(); + + if let Some(b) = maybe { + Ok(b) + } else { + let bridge_user_id = UserId::parse_with_server_name( + appservice.registration.sender_localpart.as_str(), + services().globals.server_name(), + ) + .ok(); + + let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) + || self + .room_members(room_id) + .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); + + self.appservice_in_room_cache + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default() + .insert(appservice.registration.id.clone(), in_room); + + Ok(in_room) + } + } + + /// Makes a user forget a room. + #[tracing::instrument(skip(self))] + fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + /// Returns an iterator of all servers participating in this room. + #[tracing::instrument(skip(self))] + fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { + ServerName::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) + })) + } + + #[tracing::instrument(skip(self))] + fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + let mut key = server.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + + self.serverroomids.get(&key).map(|o| o.is_some()) + } + + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). + #[tracing::instrument(skip(self))] + fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a> { + let mut prefix = server.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) + })) + } + + /// Returns an iterator over all joined members of a room. + #[tracing::instrument(skip(self))] + fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) + })) + } + + /// Returns the number of users which are currently in a room + #[tracing::instrument(skip(self))] + fn room_joined_count(&self, room_id: &RoomId) -> Result> { + self.roomid_joinedcount + .get(room_id.as_bytes())? + .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) + .transpose() + } + + /// Returns the number of users which are currently invited to a room + #[tracing::instrument(skip(self))] + fn room_invited_count(&self, room_id: &RoomId) -> Result> { + self.roomid_invitedcount + .get(room_id.as_bytes())? + .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) + .transpose() + } + + /// Returns an iterator over all User IDs who ever joined a room. + #[tracing::instrument(skip(self))] + fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new( + self.roomuseroncejoinedids + .scan_prefix(prefix) + .map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) + }), + ) + } + + /// Returns an iterator over all invited members of a room. + #[tracing::instrument(skip(self))] + fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new( + self.roomuserid_invitecount + .scan_prefix(prefix) + .map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) + }), + ) + } + + #[tracing::instrument(skip(self))] + fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_invitecount + .get(&key)? + .map_or(Ok(None), |bytes| { + Ok(Some( + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, + )) + }) + } + + #[tracing::instrument(skip(self))] + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_leftcount + .get(&key)? + .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) + .transpose() + } + + /// Returns an iterator over all rooms this user joined. + #[tracing::instrument(skip(self))] + fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + Box::new( + self.userroomid_joined + .scan_prefix(user_id.as_bytes().to_vec()) + .map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) + }), + ) + } + + /// Returns an iterator over all rooms a user was invited to. + #[tracing::instrument(skip(self))] + fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new( + self.userroomid_invitestate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + + Ok((room_id, state)) + }), + ) + } + + #[tracing::instrument(skip(self))] + fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + + Ok(state) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok(state) + }) + .transpose() + } + + /// Returns an iterator over all rooms a user left. + #[tracing::instrument(skip(self))] + fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new( + self.userroomid_leftstate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok((room_id, state)) + }), + ) + } + + #[tracing::instrument(skip(self))] + fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + + self.roomid_inviteviaservers + .get(&key)? + .map(|servers| { + let state = serde_json::from_slice(&servers).map_err(|e| { + error!("Invalid state in userroomid_leftstate: {e}"); + Error::bad_database("Invalid state in userroomid_leftstate.") + })?; + + Ok(state) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { + let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + prev_servers.append(servers.to_owned().as_mut()); + + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers)?; + + Ok(()) + } +} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 976d858b..c9ac278c 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, sync::Arc}; -pub use data::Data; +use data::Data; use itertools::Itertools; use ruma::{ events::{ diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index eddc8716..4612ebc6 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,7 +1,7 @@ -use std::{collections::HashSet, sync::Arc}; +use std::{collections::HashSet, mem::size_of, sync::Arc}; use super::CompressedStateEvent; -use crate::Result; +use crate::{utils, Error, KeyValueDatabase, Result}; pub struct StateDiff { pub parent: Option, @@ -13,3 +13,60 @@ pub trait Data: Send + Sync { fn get_statediff(&self, shortstatehash: u64) -> Result; fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn get_statediff(&self, shortstatehash: u64) -> Result { + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + let parent = if parent != 0 { + Some(parent) + } else { + None + }; + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::(); + } + + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) + } + + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { + let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value) + } +} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index a3622b7b..3f025b47 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,11 +1,11 @@ -pub mod data; +mod data; use std::{ collections::HashSet, mem::size_of, sync::{Arc, Mutex}, }; -pub use data::Data; +use data::Data; use lru_cache::LruCache; use ruma::{EventId, RoomId}; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index b18f4b79..3974ca02 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,6 +1,8 @@ +use std::mem; + use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; -use crate::{PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; type PduEventIterResult<'a> = Result> + 'a>>; @@ -12,3 +14,71 @@ pub trait Data: Send + Sync { fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>; fn get_participants(&self, root_id: &[u8]) -> Result>>; } + +impl Data for KeyValueDatabase { + fn threads_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, + ) -> PduEventIterResult<'a> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(until - 1).to_be_bytes()); + + Ok(Box::new( + self.threadid_userids + .iter_from(¤t, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pduid, _users)| { + let count = utils::u64_from_bytes(&pduid[(mem::size_of::())..]) + .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; + let mut pdu = services() + .rooms + .timeline + .get_pdu_from_id(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((count, pdu)) + }), + )) + } + + fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { + let users = participants + .iter() + .map(|user| user.as_bytes()) + .collect::>() + .join(&[0xFF][..]); + + self.threadid_userids.insert(root_id, &users)?; + + Ok(()) + } + + fn get_participants(&self, root_id: &[u8]) -> Result>> { + if let Some(users) = self.threadid_userids.get(root_id)? { + Ok(Some( + users + .split(|b| *b == 0xFF) + .map(|bytes| { + UserId::parse( + utils::string_from_bytes(bytes) + .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, + ) + .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) + }) + .filter_map(Result::ok) + .collect(), + )) + } else { + Ok(None) + } + } +} diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 05833a91..a7d5c434 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,7 +2,7 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -pub use data::Data; +use data::Data; use ruma::{ api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, events::relation::BundledThread, diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index a036b455..9fb1eea4 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,9 +1,10 @@ -use std::sync::Arc; +use std::{collections::hash_map, mem::size_of, sync::Arc}; -use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; +use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; +use tracing::error; use super::PduCount; -use crate::{PduEvent, Result}; +use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; pub trait Data: Send + Sync { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; @@ -66,3 +67,292 @@ pub trait Data: Send + Sync { &self, room_id: &RoomId, notifies: Vec, highlights: Vec, ) -> Result<()>; } + +impl Data for KeyValueDatabase { + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + match self + .lasttimelinecount_cache + .lock() + .unwrap() + .entry(room_id.to_owned()) + { + hash_map::Entry::Vacant(v) => { + if let Some(last_count) = self + .pdus_until(sender_user, room_id, PduCount::max())? + .find_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) { + Ok(*v.insert(last_count.0)) + } else { + Ok(PduCount::Normal(0)) + } + }, + hash_map::Entry::Occupied(o) => Ok(*o.get()), + } + } + + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pdu_id| pdu_count(&pdu_id)) + .transpose() + } + + /// Returns the json of a pdu. + fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.get_non_outlier_pdu_json(event_id)?.map_or_else( + || { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + }, + |x| Ok(Some(x)), + ) + } + + /// Returns the json of a pdu. + fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + } + + /// Returns the pdu's id. + fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } + + /// Returns the pdu. + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(pdu) = self + .get_non_outlier_pdu(event_id)? + .map_or_else( + || { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + }, + |x| Ok(Some(x)), + )? + .map(Arc::new) + { + Ok(Some(pdu)) + } else { + Ok(None) + } + } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the pdu as a `BTreeMap`. + fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + )?; + + self.lasttimelinecount_cache + .lock() + .unwrap() + .insert(pdu.room_id.clone(), PduCount::Normal(count)); + + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; + + Ok(()) + } + + fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + )?; + + self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; + self.eventid_outlierpdu.remove(event_id.as_bytes())?; + + Ok(()) + } + + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { + if self.pduid_pdu.get(pdu_id)?.is_some() { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), + )?; + } else { + return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); + } + + Ok(()) + } + + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. + fn pdus_until<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, + ) -> Result> + 'a>> { + let (prefix, current) = count_to_id(room_id, until, 1, true)?; + + let user_id = user_id.to_owned(); + + Ok(Box::new( + self.pduid_pdu + .iter_from(¤t, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + pdu.add_age()?; + let count = pdu_count(&pdu_id)?; + Ok((count, pdu)) + }), + )) + } + + fn pdus_after<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, + ) -> Result> + 'a>> { + let (prefix, current) = count_to_id(room_id, from, 1, false)?; + + let user_id = user_id.to_owned(); + + Ok(Box::new( + self.pduid_pdu + .iter_from(¤t, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + pdu.add_age()?; + let count = pdu_count(&pdu_id)?; + Ok((count, pdu)) + }), + )) + } + + fn increment_notification_counts( + &self, room_id: &RoomId, notifies: Vec, highlights: Vec, + ) -> Result<()> { + let mut notifies_batch = Vec::new(); + let mut highlights_batch = Vec::new(); + for user in notifies { + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + notifies_batch.push(userroom_id); + } + for user in highlights { + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + highlights_batch.push(userroom_id); + } + + self.userroomid_notificationcount + .increment_batch(&mut notifies_batch.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights_batch.into_iter())?; + Ok(()) + } +} + +/// Returns the `count` of this pdu's id. +fn pdu_count(pdu_id: &[u8]) -> Result { + let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; + let second_last_u64 = + utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()]); + + if matches!(second_last_u64, Ok(0)) { + Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) + } else { + Ok(PduCount::Normal(last_u64)) + } +} + +fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec, Vec)> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? + .to_be_bytes() + .to_vec(); + let mut pdu_id = prefix.clone(); + // +1 so we don't send the base event + let count_raw = match count { + PduCount::Normal(x) => { + if subtract { + x.saturating_sub(offset) + } else { + x.saturating_add(offset) + } + }, + PduCount::Backfilled(x) => { + pdu_id.extend_from_slice(&0_u64.to_be_bytes()); + let num = u64::MAX.saturating_sub(x); + if subtract { + num.saturating_sub(offset) + } else { + num.saturating_add(offset) + } + }, + }; + pdu_id.extend_from_slice(&count_raw.to_be_bytes()); + + Ok((prefix, pdu_id)) +} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 82266437..4d91375f 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,11 +1,11 @@ -pub mod data; +mod data; use std::{ collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; -pub use data::Data; +use data::Data; use rand::prelude::SliceRandom; use ruma::{ api::{client::error::ErrorKind, federation}, @@ -195,7 +195,7 @@ impl Service { state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result> { // Coalesce database writes for the remainder of this scope. - let _cork = services().globals.db.cork_and_flush()?; + let _cork = services().globals.cork_and_flush()?; let shortroomid = services() .rooms diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 2fd1c29e..3e8587da 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,6 +1,6 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; @@ -20,3 +20,137 @@ pub trait Data: Send + Sync { &'a self, users: Vec, ) -> Result> + 'a>>; } + +impl Data for KeyValueDatabase { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + self.userroomid_notificationcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + + self.roomuserid_lastnotificationread + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + + Ok(()) + } + + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .get(&userroom_id)? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) + } + + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_highlightcount + .get(&userroom_id)? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) + } + + fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + Ok(self + .roomuserid_lastnotificationread + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) + }) + .transpose()? + .unwrap_or(0)) + } + + fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + let shortroomid = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .insert(&key, &shortstatehash.to_be_bytes()) + } + + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + let shortroomid = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) + }) + .transpose() + } + + fn get_shared_rooms<'a>( + &'a self, users: Vec, + ) -> Result> + 'a>> { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xFF) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .0 + .saturating_add(1); // +1 because the room id starts AFTER the separator + + let room_id = key[roomid_index..].to_vec(); + + Ok::<_, Error>(room_id) + }) + .filter_map(Result::ok) + }); + + // We use the default compare function because keys are sorted correctly (not + // reversed) + Ok(Box::new( + utils::common_elements(iterators, Ord::cmp) + .expect("users is not empty") + .map(|bytes| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, + ) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + }), + )) + } +} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 5f4d4708..e589a444 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::Result; diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 41479021..9057c603 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,7 +1,7 @@ -use ruma::ServerName; +use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; -use crate::Result; +use crate::{services, utils, Error, KeyValueDatabase, Result}; type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; @@ -23,3 +23,188 @@ pub trait Data: Send + Sync { fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; fn get_latest_educount(&self, server_name: &ServerName) -> Result; } + +impl Data for KeyValueDatabase { + fn active_requests<'a>(&'a self) -> Box, Destination, SendingEvent)>> + 'a> { + Box::new( + self.servercurrentevent_data + .iter() + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), + ) + } + + fn active_requests_for<'a>( + &'a self, destination: &Destination, + ) -> Box, SendingEvent)>> + 'a> { + let prefix = destination.get_prefix(); + Box::new( + self.servercurrentevent_data + .scan_prefix(prefix) + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), + ) + } + + fn delete_active_request(&self, key: Vec) -> Result<()> { self.servercurrentevent_data.remove(&key) } + + fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { + let prefix = destination.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { + self.servercurrentevent_data.remove(&key)?; + } + + Ok(()) + } + + fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { + let prefix = destination.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { + self.servercurrentevent_data.remove(&key).unwrap(); + } + + for (key, _) in self.servernameevent_data.scan_prefix(prefix) { + self.servernameevent_data.remove(&key).unwrap(); + } + + Ok(()) + } + + fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { + let mut batch = Vec::new(); + let mut keys = Vec::new(); + for (destination, event) in requests { + let mut key = destination.get_prefix(); + if let SendingEvent::Pdu(value) = &event { + key.extend_from_slice(value); + } else { + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + } + let value = if let SendingEvent::Edu(value) = &event { + &**value + } else { + &[] + }; + batch.push((key.clone(), value.to_owned())); + keys.push(key); + } + self.servernameevent_data + .insert_batch(&mut batch.into_iter())?; + Ok(keys) + } + + fn queued_requests<'a>( + &'a self, destination: &Destination, + ) -> Box)>> + 'a> { + let prefix = destination.get_prefix(); + return Box::new( + self.servernameevent_data + .scan_prefix(prefix) + .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), + ); + } + + fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { + for (e, key) in events { + if key.is_empty() { + continue; + } + + let value = if let SendingEvent::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value)?; + self.servernameevent_data.remove(key)?; + } + + Ok(()) + } + + fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + self.servername_educount + .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + } + + fn get_latest_educount(&self, server_name: &ServerName) -> Result { + self.servername_educount + .get(server_name.as_bytes())? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + }) + } +} + +#[tracing::instrument(skip(key))] +fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent)> { + // Appservices start with a plus + Ok::<_, Error>(if key.starts_with(b"+") { + let mut parts = key[1..].splitn(2, |&b| b == 0xFF); + + let server = parts.next().expect("splitn always returns one element"); + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + let server = utils::string_from_bytes(server) + .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; + + ( + Destination::Appservice(server), + if value.is_empty() { + SendingEvent::Pdu(event.to_vec()) + } else { + SendingEvent::Edu(value) + }, + ) + } else if key.starts_with(b"$") { + let mut parts = key[1..].splitn(3, |&b| b == 0xFF); + + let user = parts.next().expect("splitn always returns one element"); + let user_string = utils::string_from_bytes(user) + .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; + let user_id = + UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; + + let pushkey = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let pushkey_string = utils::string_from_bytes(pushkey) + .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; + + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + ( + Destination::Push(user_id, pushkey_string), + if value.is_empty() { + SendingEvent::Pdu(event.to_vec()) + } else { + // I'm pretty sure this should never be called + SendingEvent::Edu(value) + }, + ) + } else { + let mut parts = key.splitn(2, |&b| b == 0xFF); + + let server = parts.next().expect("splitn always returns one element"); + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + let server = utils::string_from_bytes(server) + .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; + + ( + Destination::Normal( + ServerName::parse(server) + .map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?, + ), + if value.is_empty() { + SendingEvent::Pdu(event.to_vec()) + } else { + SendingEvent::Edu(value) + }, + ) + }) +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index b4a6fdeb..a9f64f7b 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,6 +1,6 @@ use std::{fmt::Debug, sync::Arc}; -pub use data::Data; +use data::Data; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, @@ -81,7 +81,7 @@ impl Service { pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); - let _cork = services().globals.db.cork()?; + let _cork = services().globals.cork()?; let keys = self.db.queue_requests(&[(&dest, event.clone())])?; self.dispatch(Msg { dest, @@ -94,7 +94,7 @@ impl Service { pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); - let _cork = services().globals.db.cork()?; + let _cork = services().globals.cork()?; let keys = self.db.queue_requests(&[(&dest, event.clone())])?; self.dispatch(Msg { dest, @@ -121,7 +121,7 @@ impl Service { .into_iter() .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) .collect::>(); - let _cork = services().globals.db.cork()?; + let _cork = services().globals.cork()?; let keys = self.db.queue_requests( &requests .iter() @@ -143,7 +143,7 @@ impl Service { pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); - let _cork = services().globals.db.cork()?; + let _cork = services().globals.cork()?; let keys = self.db.queue_requests(&[(&dest, event.clone())])?; self.dispatch(Msg { dest, @@ -170,7 +170,7 @@ impl Service { .into_iter() .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) .collect::>(); - let _cork = services().globals.db.cork()?; + let _cork = services().globals.cork()?; let keys = self.db.queue_requests( &requests .iter() diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 3d1d89da..8bb93105 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -100,7 +100,7 @@ impl Service { fn handle_response_ok( &self, dest: &Destination, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { - let _cork = services().globals.db.cork(); + let _cork = services().globals.cork(); self.db .delete_all_active_requests_for(dest) .expect("all active requests deleted"); @@ -173,7 +173,7 @@ impl Service { return Ok(None); } - let _cork = services().globals.db.cork(); + let _cork = services().globals.cork(); let mut events = Vec::new(); // Must retry any previous transaction for this remote. @@ -187,7 +187,7 @@ impl Service { } // Compose the next transaction - let _cork = services().globals.db.cork(); + let _cork = services().globals.cork(); if !new_events.is_empty() { self.db.mark_as_active(&new_events)?; for (e, _) in new_events { diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index 2aed1981..8ea3b8fd 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,8 +1,8 @@ use ruma::{DeviceId, TransactionId, UserId}; -use crate::Result; +use crate::{KeyValueDatabase, Result}; -pub trait Data: Send + Sync { +pub(crate) trait Data: Send + Sync { fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()>; @@ -11,3 +11,32 @@ pub trait Data: Send + Sync { &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, ) -> Result>>; } + +impl Data for KeyValueDatabase { + fn add_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); + + self.userdevicetxnid_response.insert(&key, data)?; + + Ok(()) + } + + fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); + + // If there's no entry, this is a new transaction + self.userdevicetxnid_response.get(&key) + } +} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index ba9869e7..e986f0ac 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -2,13 +2,13 @@ mod data; use std::sync::Arc; -pub use data::Data; +use data::Data; use ruma::{DeviceId, TransactionId, UserId}; use crate::Result; pub struct Service { - pub db: Arc, + pub(super) db: Arc, } impl Service { diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index 3a157068..d6d745bc 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,8 +1,11 @@ -use ruma::{api::client::uiaa::UiaaInfo, CanonicalJsonValue, DeviceId, UserId}; +use ruma::{ + api::client::{error::ErrorKind, uiaa::UiaaInfo}, + CanonicalJsonValue, DeviceId, UserId, +}; -use crate::Result; +use crate::{Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { +pub(crate) trait Data: Send + Sync { fn set_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, ) -> Result<()>; @@ -15,3 +18,65 @@ pub trait Data: Send + Sync { fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result; } + +impl Data for KeyValueDatabase { + fn set_uiaa_request( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, + ) -> Result<()> { + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); + + Ok(()) + } + + fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option { + self.userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .map(ToOwned::to_owned) + } + + fn update_uiaa_session( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + )?; + } else { + self.userdevicesessionid_uiaainfo + .remove(&userdevicesessionid)?; + } + + Ok(()) + } + + fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + serde_json::from_slice( + &self + .userdevicesessionid_uiaainfo + .get(&userdevicesessionid)? + .ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?, + ) + .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + } +} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index cd131c52..63293867 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use argon2::{PasswordHash, PasswordVerifier}; use conduit::{utils, Error, Result}; -pub use data::Data; +use data::Data; use ruma::{ api::client::{ error::ErrorKind, @@ -19,7 +19,7 @@ use crate::services; pub const SESSION_ID_LENGTH: usize = 32; pub struct Service { - pub db: Arc, + pub(super) db: Arc, } impl Service { diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 04074e85..1254a988 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,14 +1,17 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, mem::size_of}; +use argon2::{password_hash::SaltString, PasswordHasher}; use ruma::{ - api::client::{device::Device, filter::FilterDefinition}, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, + events::{AnyToDeviceEvent, StateEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, + uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedUserId, UInt, UserId, }; +use tracing::warn; -use crate::Result; +use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result}; pub trait Data: Send + Sync { /// Check if a user has an account on this homeserver. @@ -144,3 +147,887 @@ pub trait Data: Send + Sync { fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result>; } + +impl Data for KeyValueDatabase { + /// Check if a user has an account on this homeserver. + fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } + + /// Check if account is deactivated + fn is_deactivated(&self, user_id: &UserId) -> Result { + Ok(self + .userid_password + .get(user_id.as_bytes())? + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? + .is_empty()) + } + + /// Returns the number of users registered on this server. + fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } + + /// Find out which user an access token belongs to. + fn find_from_token(&self, token: &str) -> Result> { + self.token_userdeviceid + .get(token.as_bytes())? + .map_or(Ok(None), |bytes| { + let mut parts = bytes.split(|&b| b == 0xFF); + let user_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; + let device_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; + + Ok(Some(( + UserId::parse( + utils::string_from_bytes(user_bytes) + .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?, + utils::string_from_bytes(device_bytes) + .map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?, + ))) + }) + } + + /// Returns an iterator over all users on this homeserver. + fn iter<'a>(&'a self) -> Box> + 'a> { + Box::new(self.userid_password.iter().map(|(bytes, _)| { + UserId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) + })) + } + + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is + /// greater then zero. + fn list_local_users(&self) -> Result> { + let users: Vec = self + .userid_password + .iter() + .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) + .collect(); + Ok(users) + } + + /// Returns the password hash for the given user. + fn password_hash(&self, user_id: &UserId) -> Result> { + self.userid_password + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Password hash in db is not valid string.") + })?)) + }) + } + + /// Hash and set the user's password to the Argon2 hash + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + if let Some(password) = password { + if let Ok(hash) = calculate_password_hash(password) { + self.userid_password + .insert(user_id.as_bytes(), hash.as_bytes())?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Password does not meet the requirements.", + )) + } + } else { + self.userid_password.insert(user_id.as_bytes(), b"")?; + Ok(()) + } + } + + /// Returns the displayname of a user on this homeserver. + fn displayname(&self, user_id: &UserId) -> Result> { + self.userid_displayname + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Displayname in db is invalid."))?, + )) + }) + } + + /// Sets a new displayname or removes it if displayname is None. You still + /// need to nofify all rooms of this change. + fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + if let Some(displayname) = displayname { + self.userid_displayname + .insert(user_id.as_bytes(), displayname.as_bytes())?; + } else { + self.userid_displayname.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the `avatar_url` of a user. + fn avatar_url(&self, user_id: &UserId) -> Result> { + self.userid_avatarurl + .get(user_id.as_bytes())? + .map(|bytes| { + let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| { + warn!("Avatar URL in db is invalid: {}", e); + Error::bad_database("Avatar URL in db is invalid.") + })?; + let mxc_uri: OwnedMxcUri = s_bytes.into(); + Ok(mxc_uri) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + if let Some(avatar_url) = avatar_url { + self.userid_avatarurl + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; + } else { + self.userid_avatarurl.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the blurhash of a user. + fn blurhash(&self, user_id: &UserId) -> Result> { + self.userid_blurhash + .get(user_id.as_bytes())? + .map(|bytes| { + let s = utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + + Ok(s) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + if let Some(blurhash) = blurhash { + self.userid_blurhash + .insert(user_id.as_bytes(), blurhash.as_bytes())?; + } else { + self.userid_blurhash.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Adds a new device to a user. + fn create_device( + &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, + ) -> Result<()> { + // This method should never be called for nonexistent users. We shouldn't assert + // though... + if !self.exists(user_id)? { + warn!("Called create_device for non-existent user {} in database", user_id); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); + } + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(&Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: None, // TODO + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }) + .expect("Device::to_string never fails."), + )?; + + self.set_token(user_id, device_id, token)?; + + Ok(()) + } + + /// Removes a device from a user. + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Remove tokens + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.userdeviceid_token.remove(&userdeviceid)?; + self.token_userdeviceid.remove(&old_token)?; + } + + // Remove todevice events + let mut prefix = userdeviceid.clone(); + prefix.push(0xFF); + + for (key, _) in self.todeviceid_events.scan_prefix(prefix) { + self.todeviceid_events.remove(&key)?; + } + + // TODO: Remove onetimekeys + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.remove(&userdeviceid)?; + + Ok(()) + } + + /// Returns an iterator over all device ids of this user. + fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + // All devices have metadata + Box::new( + self.userdeviceid_metadata + .scan_prefix(prefix) + .map(|(bytes, _)| { + Ok(utils::string_from_bytes( + bytes + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? + .into()) + }), + ) + } + + /// Replaces the access token of one device. + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // should not be None, but we shouldn't assert either lol... + if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { + warn!( + "Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", + user_id, device_id + ); + return Err(Error::bad_database( + "User does not exist or device ID has no metadata in database.", + )); + } + + // Remove old token + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.token_userdeviceid.remove(&old_token)?; + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to user device combination + self.userdeviceid_token + .insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid + .insert(token.as_bytes(), &userdeviceid)?; + + Ok(()) + } + + fn add_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + + // All devices have metadata + // Only existing devices should be able to call this, but we shouldn't assert + // either... + if self.userdeviceid_metadata.get(&key)?.is_none() { + warn!( + "Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \ + database", + user_id, device_id + ); + return Err(Error::bad_database( + "User does not exist or device ID has no metadata in database.", + )); + } + + key.push(0xFF); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), + ); + + self.onetimekeyid_onetimekeys.insert( + &key, + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + )?; + + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + + Ok(()) + } + + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self.userid_lastonetimekeyupdate + .get(user_id.as_bytes())? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) + }) + } + + fn take_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, + ) -> Result)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + + self.onetimekeyid_onetimekeys + .scan_prefix(prefix) + .next() + .map(|(key, value)| { + self.onetimekeyid_onetimekeys.remove(&key)?; + + Ok(( + serde_json::from_slice( + key.rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, + )) + }) + .transpose() + } + + fn count_one_time_keys( + &self, user_id: &UserId, device_id: &DeviceId, + ) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + let mut counts = BTreeMap::new(); + + for algorithm in self + .onetimekeyid_onetimekeys + .scan_prefix(userdeviceid) + .map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::( + bytes + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? + .algorithm(), + ) + }) { + *counts.entry(algorithm?).or_default() += uint!(1); + } + + Ok(counts) + } + + fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.insert( + &userdeviceid, + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + )?; + + self.mark_device_key_update(user_id)?; + + Ok(()) + } + + fn add_cross_signing_keys( + &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, + user_signing_key: &Option>, notify: bool, + ) -> Result<()> { + // TODO: Check signatures + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; + + self.keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes())?; + + self.userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key)?; + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? + .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.keyid_key + .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; + + self.userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key)?; + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? + .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; + + if user_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained more than one key.", + )); + } + + let mut user_signing_key_key = prefix; + user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + + self.keyid_key + .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; + + self.userid_usersigningkeyid + .insert(user_id.as_bytes(), &user_signing_key_key)?; + } + + if notify { + self.mark_device_key_update(user_id)?; + } + + Ok(()) + } + + fn sign_key( + &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, + ) -> Result<()> { + let mut key = target_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(key_id.as_bytes()); + + let mut cross_signing_key: serde_json::Value = serde_json::from_slice( + &self + .keyid_key + .get(&key)? + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, + ) + .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? + .as_object_mut() + .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? + .entry(sender_id.to_string()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? + .insert(signature.0, signature.1.into()); + + self.keyid_key.insert( + &key, + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + )?; + + self.mark_device_key_update(target_id)?; + + Ok(()) + } + + fn keys_changed<'a>( + &'a self, user_or_room_id: &str, from: u64, to: Option, + ) -> Box> + 'a> { + let mut prefix = user_or_room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let mut start = prefix.clone(); + start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes()); + + let to = to.unwrap_or(u64::MAX); + + Box::new( + self.keychangeid_userid + .iter_from(&start, false) + .take_while(move |(k, _)| { + k.starts_with(&prefix) + && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { + if let Ok(c) = utils::u64_from_bytes(current) { + c <= to + } else { + warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + false + } + } else { + warn!("BadDatabase: Could not parse keychangeid_userid"); + false + } + }) + .map(|(_, bytes)| { + UserId::parse( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) + }), + ) + } + + fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + let count = services().globals.next_count()?.to_be_bytes(); + for room_id in services() + .rooms + .state_cache + .rooms_joined(user_id) + .filter_map(Result::ok) + { + // Don't send key updates to unencrypted rooms + if services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? + .is_none() + { + continue; + } + + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); + + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + + Ok(()) + } + + fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some( + serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?, + )) + }) + } + + fn parse_master_key( + &self, user_id: &UserId, master_key: &Raw, + ) -> Result<(Vec, CrossSigningKey)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let master_key = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let mut master_key_ids = master_key.keys.values(); + let master_key_id = master_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + Ok((master_key_key, master_key)) + } + + fn get_key( + &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { + let mut cross_signing_key = serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; + + Ok(Some(Raw::from_json( + serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), + ))) + }) + } + + fn get_master_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.userid_masterkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) + } + + fn get_self_signing_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.userid_selfsigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) + } + + fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + self.userid_usersigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some( + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?, + )) + }) + }) + } + + fn add_to_device_event( + &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, + content: serde_json::Value, + ) -> Result<()> { + let mut key = target_user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(target_device_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + + let mut json = serde_json::Map::new(); + json.insert("type".to_owned(), event_type.to_owned().into()); + json.insert("sender".to_owned(), sender.to_string().into()); + json.insert("content".to_owned(), content); + + let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); + + self.todeviceid_events.insert(&key, &value)?; + + Ok(()) + } + + fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + let mut events = Vec::new(); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + + for (_, value) in self.todeviceid_events.scan_prefix(prefix) { + events.push( + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, + ); + } + + Ok(events) + } + + fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + + let mut last = prefix.clone(); + last.extend_from_slice(&until.to_be_bytes()); + + for (key, _) in self + .todeviceid_events + .iter_from(&last, true) // this includes last + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(key, _)| { + Ok::<_, Error>(( + key.clone(), + utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) + .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, + )) + }) + .filter_map(Result::ok) + .take_while(|&(_, count)| count <= until) + { + self.todeviceid_events.remove(&key)?; + } + + Ok(()) + } + + fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Only existing devices should be able to call this, but we shouldn't assert + // either... + if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { + warn!( + "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ + metadata in database", + user_id, device_id + ); + return Err(Error::bad_database( + "User does not exist or device ID has no metadata in database.", + )); + } + + self.userid_devicelistversion + .increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), + )?; + + Ok(()) + } + + /// Get device metadata. + fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userdeviceid_metadata + .get(&userdeviceid)? + .map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("Metadata in userdeviceid_metadata is invalid.") + })?)) + }) + } + + fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + self.userid_devicelistversion + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) + .map(Some) + }) + } + + fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + + Box::new( + self.userdeviceid_metadata + .scan_prefix(key) + .map(|(_, bytes)| { + serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) + }), + ) + } + + /// Creates a new sync filter. Returns the filter id. + fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { + let filter_id = utils::random_string(4); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(filter_id.as_bytes()); + + self.userfilterid_filter + .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; + + Ok(filter_id) + } + + fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(filter_id.as_bytes()); + + let raw = self.userfilterid_filter.get(&key)?; + + if let Some(raw) = raw { + serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db.")) + } else { + Ok(None) + } + } +} + +/// Will only return with Some(username) if the password was not empty and the +/// username could be successfully parsed. +/// If `utils::string_from_bytes`(...) returns an error that username will be +/// skipped and the error will be logged. +fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { + // A valid password is not empty + if password.is_empty() { + None + } else { + match utils::string_from_bytes(username) { + Ok(u) => Some(u), + Err(e) => { + warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); + None + }, + } + } +} + +/// Calculate a new hash for the given password +fn calculate_password_hash(password: &str) -> Result { + let salt = SaltString::generate(rand::thread_rng()); + services() + .globals + .argon + .hash_password(password.as_bytes(), &salt) + .map(|it| it.to_string()) +} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index fde2ed89..ec17e796 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -5,7 +5,7 @@ use std::{ sync::{Arc, Mutex}, }; -pub use data::Data; +use data::Data; use ruma::{ api::client::{ device::Device,