Hot-Reloading Refactor

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-09 15:59:08 -07:00 committed by June 🍓🦴
parent ae1a4fd283
commit 6c1434c165
212 changed files with 5679 additions and 4206 deletions

81
src/database/Cargo.toml Normal file
View file

@ -0,0 +1,81 @@
[package]
name = "conduit_database"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"rocksdb",
"io_uring",
"jemalloc",
"zstd_compression",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sqlite = [
"dep:rusqlite",
"dep:parking_lot",
"dep:thread_local",
]
rocksdb = [
"dep:rust-rocksdb",
]
jemalloc = [
"dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator",
"rust-rocksdb/jemalloc",
]
jemalloc_prof = [
"tikv-jemalloc-sys/profiling",
]
io_uring = [
"rust-rocksdb/io-uring",
]
zstd_compression = [
"rust-rocksdb/zstd",
]
[dependencies]
chrono.workspace = true
conduit-core.workspace = true
futures-util.workspace = true
log.workspace = true
lru-cache.workspace = true
num_cpus.workspace = true
parking_lot.optional = true
parking_lot.workspace = true
ruma.workspace = true
rusqlite.optional = true
rusqlite.workspace = true
rust-rocksdb.optional = true
rust-rocksdb.workspace = true
thread_local.optional = true
thread_local.workspace = true
tikv-jemallocator.optional = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl.optional = true
tikv-jemalloc-ctl.workspace = true
tikv-jemalloc-sys.optional = true
tikv-jemalloc-sys.workspace = true
tokio.workspace = true
tracing.workspace = true
zstd.optional = true
zstd.workspace = true
[lints]
workspace = true

View file

@ -2,14 +2,14 @@ use std::sync::Arc;
use super::KeyValueDatabaseEngine;
pub(crate) struct Cork {
pub struct Cork {
db: Arc<dyn KeyValueDatabaseEngine>,
flush: bool,
sync: bool,
}
impl Cork {
pub(crate) fn new(db: &Arc<dyn KeyValueDatabaseEngine>, flush: bool, sync: bool) -> Self {
pub fn new(db: &Arc<dyn KeyValueDatabaseEngine>, flush: bool, sync: bool) -> Self {
db.cork().unwrap();
Cork {
db: db.clone(),

View file

@ -1,137 +0,0 @@
use std::collections::HashMap;
use ruma::{
api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
serde::Raw,
RoomId, UserId,
};
use tracing::warn;
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the
/// previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value,
) -> Result<()> {
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
let mut key = prefix;
key.extend_from_slice(event_type.to_string().as_bytes());
if data.get("type").is_none() || data.get("content").is_none() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Account data doesn't have all required fields.",
));
}
self.roomuserdataid_accountdata.insert(
&roomuserdataid,
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
)?;
let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
// Remove old entry
if let Some(prev) = prev {
self.roomuserdataid_accountdata.remove(&prev)?;
}
Ok(())
}
/// Searches the account data for a specific kind.
#[tracing::instrument(skip(self, room_id, user_id, kind))]
fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(kind.to_string().as_bytes());
self.roomusertype_roomuserdataid
.get(&key)?
.and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
})
.transpose()?
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose()
}
/// Returns all changes to the account data that happened after `since`.
#[tracing::instrument(skip(self, room_id, user_id, since))]
fn changes_since(
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes());
for r in self
.roomuserdataid_accountdata
.iter_from(&first_possible, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| {
Ok::<_, Error>((
RoomAccountDataEventType::from(
utils::string_from_bytes(
k.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
)
.map_err(|e| {
warn!("RoomUserData ID in database is invalid: {}", e);
Error::bad_database("RoomUserData ID in db is invalid.")
})?,
),
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
))
}) {
let (kind, data) = r?;
userdata.insert(kind, data);
}
Ok(userdata)
}
}

View file

@ -1,55 +0,0 @@
use ruma::api::appservice::Registration;
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::appservice::Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str();
self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
Ok(id.to_owned())
}
/// Remove an appservice registration
///
/// # Arguments
///
/// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
Ok(())
}
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.id_appserviceregistrations
.get(id.as_bytes())?
.map(|bytes| {
serde_yaml::from_slice(&bytes)
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
})
.transpose()
}
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
utils::string_from_bytes(&id)
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
})))
}
fn all(&self) -> Result<Vec<(String, Registration)>> {
self.iter_ids()?
.filter_map(Result::ok)
.map(move |id| {
Ok((
id.clone(),
self.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"),
))
})
.collect()
}
}

View file

@ -1,298 +0,0 @@
use std::collections::{BTreeMap, HashMap};
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
use ruma::{
api::federation::discovery::{ServerSigningKeys, VerifyKey},
signatures::Ed25519KeyPair,
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
};
use crate::{
database::{Cork, KeyValueDatabase},
service, services, utils, Error, Result,
};
const COUNTER: &[u8] = b"c";
const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait]
impl service::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
}
fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
})
}
fn last_check_for_updates_id(&self) -> Result<u64> {
self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
})
}
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(())
}
#[allow(unused_qualifications)] // async traits
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xFF);
let mut userdeviceid_prefix = userid_prefix.clone();
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xFF);
let mut futures = FuturesUnordered::new();
// Return when *any* user changed their key
// TODO: only send for user they share a room with
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = services()
.rooms
.short
.get_shortroomid(&room_id)
.ok()
.flatten()
.expect("room exists")
.to_be_bytes()
.to_vec();
let roomid_bytes = room_id.as_bytes().to_vec();
let mut roomid_prefix = roomid_bytes.clone();
roomid_prefix.push(0xFF);
// PDUs
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
// EDUs
futures.push(Box::pin(async move {
let _result = services().rooms.typing.wait_for_update(&room_id).await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
// Key changes
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
// Room account data
let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
}
let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms)
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
// One time keys
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(services().globals.rotate.watch()));
// Wait until one of them finds something
futures.next().await;
Ok(())
}
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
fn flush(&self) -> Result<()> { self.db.flush() }
fn cork(&self) -> Result<Cork> { Ok(Cork::new(&self.db, false, false)) }
fn cork_and_flush(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, false)) }
fn cork_and_sync(&self) -> Result<Cork> { Ok(Cork::new(&self.db, true, true)) }
fn memory_usage(&self) -> String {
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity();
let max_our_real_users_cache = self.our_real_users_cache.read().unwrap().capacity();
let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity();
format!(
"\
auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
our_real_users_cache: {our_real_users_cache} / {max_our_real_users_cache}
appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
{}",
self.db.memory_usage().unwrap_or_default()
)
}
fn clear_caches(&self, amount: u32) {
if amount > 1 {
let c = &mut *self.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity());
}
if amount > 2 {
let c = &mut *self.our_real_users_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 3 {
let c = &mut *self.appservice_in_room_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 4 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new();
}
}
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
self.global.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
Ok,
)?;
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
utils::string_from_bytes(
// 1. version
parts
.next()
.expect("splitn always returns at least one element"),
)
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| {
// 2. key
parts
.next()
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
.map(|key| (version, key))
})
.and_then(|(version, key)| {
Ed25519KeyPair::from_der(key, version)
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
})
}
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
let ServerSigningKeys {
verify_keys,
old_verify_keys,
..
} = new_keys;
keys.verify_keys.extend(verify_keys);
keys.old_verify_keys.extend(old_verify_keys);
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree)
}
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map_or_else(BTreeMap::new, |keys: ServerSigningKeys| {
let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree
});
Ok(signingkeys)
}
fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() }
fn backup_list(&self) -> Result<String> { self.db.backup_list() }
fn file_list(&self) -> Result<String> { self.db.file_list() }
}

View file

@ -1,317 +0,0 @@
use std::collections::BTreeMap;
use ruma::{
api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
error::ErrorKind,
},
serde::Raw,
OwnedRoomId, RoomId, UserId,
};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::key_backups::Data for KeyValueDatabase {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version)
}
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.remove(&key)?;
self.backupid_etag.remove(&key)?;
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned())
}
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, _)| {
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
})
.transpose()
}
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm
.iter_from(&last_possible_key, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.next()
.map(|(key, value)| {
let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
Ok((
version,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
))
})
.transpose()
}
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
})
}
fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
if self.backupid_algorithm.get(&key)?.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
}
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes())?;
Ok(())
}
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
}
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes(
&self
.backupid_etag
.get(&key)?
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
)
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
.to_string())
}
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let room_id = RoomId::parse(
utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((room_id, session_id, key_data))
}) {
let (room_id, session_id, key_data) = result?;
rooms
.entry(room_id)
.or_insert_with(|| RoomKeyBackup {
sessions: BTreeMap::new(),
})
.sessions
.insert(session_id, key_data);
}
Ok(rooms)
}
fn get_room(
&self, user_id: &UserId, version: &str, room_id: &RoomId,
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(version.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
Ok(self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes(
parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
)
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let key_data = serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
Ok::<_, Error>((session_id, key_data))
})
.filter_map(Result::ok)
.collect())
}
fn get_session(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
) -> Result<Option<Raw<KeyBackupData>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup
.get(&key)?
.map(|value| {
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
})
.transpose()
}
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(version.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(session_id.as_bytes());
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
self.backupkeyid_backup.remove(&outdated_key)?;
}
Ok(())
}
}

View file

@ -1,247 +0,0 @@
use ruma::api::client::error::ErrorKind;
use tracing::debug;
use crate::{
database::KeyValueDatabase,
service::{self, media::UrlPreviewData},
utils::string_from_bytes,
Error, Result,
};
impl service::media::Data for KeyValueDatabase {
fn create_file_metadata(
&self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?;
if let Some(user) = sender_user {
let key = mxc.as_bytes().to_vec();
let user = user.as_bytes().to_vec();
self.mediaid_user.insert(&key, &user)?;
}
Ok(key)
}
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
debug!("MXC db prefix: {prefix:?}");
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
debug!("Deleting key: {:?}", key);
self.mediaid_file.remove(&key)?;
}
for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) {
if key == mxc.as_bytes().to_vec() {
let user = string_from_bytes(&value).unwrap_or_default();
debug!("Deleting key \"{key:?}\" which was uploaded by user {user}");
self.mediaid_user.remove(&key)?;
}
}
Ok(())
}
/// Searches for all files with the given MXC
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
debug!("MXC URI: {:?}", mxc);
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
let mut keys: Vec<Vec<u8>> = vec![];
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
keys.push(key);
}
if keys.is_empty() {
return Err(Error::bad_database(
"Failed to find any keys in database with the provided MXC.",
));
}
debug!("Got the following keys: {:?}", keys);
Ok(keys)
}
fn search_file_metadata(
&self, mxc: String, width: u32, height: u32,
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes());
prefix.extend_from_slice(&height.to_be_bytes());
prefix.push(0xFF);
let (key, _) = self
.mediaid_file
.scan_prefix(prefix)
.next()
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts
.next()
.map(|bytes| {
string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
})
.transpose()?;
let content_disposition_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let content_disposition = if content_disposition_bytes.is_empty() {
None
} else {
Some(
string_from_bytes(content_disposition_bytes)
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
)
};
Ok((content_disposition, content_type, key))
}
/// Gets all the media keys in our database (this includes all the metadata
/// associated with it such as width, height, content-type, etc)
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
let mut keys: Vec<Vec<u8>> = vec![];
for (key, _) in self.mediaid_file.iter() {
keys.push(key);
}
Ok(keys)
}
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF);
value.extend_from_slice(
data.title
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.description
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(
data.image
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF);
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
value.push(0xFF);
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
self.url_previews.insert(url.as_bytes(), &value)
}
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
let values = self.url_previews.get(url.as_bytes()).ok()??;
let mut values = values.split(|&b| b == 0xFF);
let _ts = values.next();
/* if we ever decide to use timestamp, this is here.
match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
Some(0) => None,
x => x,
};*/
let title = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let description = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None,
x => x,
};
let image_size = match values
.next()
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_width = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
let image_height = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None,
x => x,
};
Some(UrlPreviewData {
title,
description,
image,
image_size,
image_width,
image_height,
})
}
}

View file

@ -1,14 +0,0 @@
mod account_data;
//mod admin;
mod appservice;
mod globals;
mod key_backups;
mod media;
//mod pdu;
mod presence;
mod pusher;
mod rooms;
mod sending;
mod transaction_ids;
mod uiaa;
mod users;

View file

@ -1,128 +0,0 @@
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::{
database::KeyValueDatabase,
debug_info,
service::{self, presence::Presence},
services,
utils::{self, user_id_from_bytes},
Error, Result,
};
impl service::presence::Data for KeyValueDatabase {
fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?))
})
.transpose()
} else {
Ok(None)
}
}
fn set_presence(
&self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> {
let last_presence = self.get_presence(user_id)?;
let state_changed = match last_presence {
None => true,
Some(ref presence) => presence.1.content.presence != *presence_state,
};
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
None => 0,
Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
};
let last_active_ts = match last_active_ago {
None => now,
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
};
// tighten for state flicker?
if !state_changed && last_active_ts <= last_last_active_ts {
debug_info!(
"presence spam {:?} last_active_ts:{:?} <= {:?}",
user_id,
last_active_ts,
last_last_active_ts
);
return Ok(());
}
let status_msg = if status_msg.as_ref().is_some_and(String::is_empty) {
None
} else {
status_msg
};
let presence = Presence::new(
presence_state.to_owned(),
currently_active.unwrap_or(false),
last_active_ts,
status_msg,
);
let count = services().globals.next_count()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence
.insert(&key, &presence.to_json_bytes()?)?;
self.userid_presenceid
.insert(user_id.as_bytes(), &count.to_be_bytes())?;
if let Some((last_count, _)) = last_presence {
let key = presenceid_key(last_count, user_id);
self.presenceid_presence.remove(&key)?;
}
Ok(())
}
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? {
let count = utils::u64_from_bytes(&count_bytes)
.map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?;
let key = presenceid_key(count, user_id);
self.presenceid_presence.remove(&key)?;
self.userid_presenceid.remove(user_id.as_bytes())?;
}
Ok(())
}
fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> {
Box::new(
self.presenceid_presence
.iter()
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec<u8>)> {
let (count, user_id) = presenceid_parse(&key)?;
Ok((user_id, count, presence_bytes))
})
.filter(move |(_, count, _)| *count > since),
)
}
}
#[inline]
fn presenceid_key(count: u64, user_id: &UserId) -> Vec<u8> {
[count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat()
}
#[inline]
fn presenceid_parse(key: &[u8]) -> Result<(u64, OwnedUserId)> {
let (count, user_id) = key.split_at(8);
let user_id = user_id_from_bytes(user_id)?;
let count = utils::u64_from_bytes(count).unwrap();
Ok((count, user_id))
}

View file

@ -1,65 +0,0 @@
use ruma::{
api::client::push::{set_pusher, Pusher},
UserId,
};
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
match &pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
Ok(())
},
set_pusher::v3::PusherAction::Delete(ids) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(ids.pushkey.as_bytes());
self.senderkey_pusher.remove(&key).map_err(Into::into)
},
}
}
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
let mut senderkey = sender.as_bytes().to_vec();
senderkey.push(0xFF);
senderkey.extend_from_slice(pushkey.as_bytes());
self.senderkey_pusher
.get(&senderkey)?
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.transpose()
}
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
self.senderkey_pusher
.scan_prefix(prefix)
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
.collect()
}
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next();
let push_key = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
let push_key_string = utils::string_from_bytes(push_key)
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
Ok(push_key_string)
}))
}
}

View file

@ -1,75 +0,0 @@
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())
}
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
let mut prefix = room_id;
prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?;
}
self.alias_roomid.remove(alias.alias().as_bytes())?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
}
Ok(())
}
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.alias_roomid
.get(alias.alias().as_bytes())?
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
})
.transpose()
}
fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
}))
}
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new(
self.alias_roomid
.iter()
.map(|(room_alias_bytes, room_id_bytes)| {
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
Ok((room_id, room_alias_localpart))
}),
)
}
}

View file

@ -1,59 +0,0 @@
use std::{mem::size_of, sync::Arc};
use crate::{database::KeyValueDatabase, service, utils, Result};
impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
}
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
.collect::<Arc<[u64]>>()
});
if let Some(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(vec![key[0]], Arc::clone(&chain));
return Ok(Some(chain));
}
}
Ok(None)
}
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
&key[0].to_be_bytes(),
&auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
}
// Cache in RAM
self.auth_chain_cache
.lock()
.unwrap()
.insert(key, auth_chain);
Ok(())
}
}

View file

@ -1,23 +0,0 @@
use ruma::{OwnedRoomId, RoomId};
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
}
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
}))
}
}

View file

@ -1,53 +0,0 @@
use ruma::{DeviceId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, Result};
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())
}
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;
}
Ok(())
}
}

View file

@ -1,76 +0,0 @@
use ruma::{OwnedRoomId, RoomId};
use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};
// Look for PDUs in that room.
Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
}))
}
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.bannedroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.")
})?
.try_into()
.map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids")
})?;
Ok(room_id)
},
))
}
}

View file

@ -1,21 +0,0 @@
mod alias;
mod auth_chain;
mod directory;
mod lazy_load;
mod metadata;
mod outlier;
mod pdu_metadata;
mod read_receipt;
mod search;
mod short;
mod state;
mod state_accessor;
mod state_cache;
mod state_compressor;
mod threads;
mod timeline;
mod user;
use crate::{database::KeyValueDatabase, service};
impl service::rooms::Data for KeyValueDatabase {}

View file

@ -1,28 +0,0 @@
use ruma::{CanonicalJsonObject, EventId};
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
)
}
}

View file

@ -1,83 +0,0 @@
use std::{mem, sync::Arc};
use ruma::{EventId, RoomId, UserId};
use crate::{
database::KeyValueDatabase,
service::{self, rooms::timeline::PduCount},
services, utils, Error, PduEvent, Result,
};
impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?;
Ok(())
}
fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
current.extend_from_slice(&0_u64.to_be_bytes());
u64::MAX.saturating_sub(x).saturating_sub(1)
},
};
current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new(
self.tofrom_relation
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((PduCount::Normal(from), pdu))
}),
))
}
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
Ok(())
}
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
}
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[])
}
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
}
}

View file

@ -1,120 +0,0 @@
use std::mem;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry
if let Some((old, _)) = self
.readreceiptid_readreceipt
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) {
// This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?;
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
)?;
Ok(())
}
fn readreceipts_since<'a>(
&'a self, room_id: &RoomId, since: u64,
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since
Box::new(
self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id = UserId::parse(
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
json.remove("room_id");
Ok((
user_id,
count,
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
))
}),
)
}
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
}
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
))
})
}
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastprivatereadupdate
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
}

View file

@ -1,64 +0,0 @@
use ruma::RoomId;
use crate::{database::KeyValueDatabase, service, services, utils, Result};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
(key, Vec::new())
});
self.tokenids.insert_batch(&mut batch)
}
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.map(str::to_lowercase)
.collect();
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
let Some(common_elements) = utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
}) else {
return Ok(None);
};
Ok(Some((Box::new(common_elements), words)))
}
}

View file

@ -1,164 +0,0 @@
use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId};
use tracing::warn;
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
Ok(short)
}
fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids
.iter()
.map(|id| id.as_bytes())
.collect::<Vec<&[u8]>>();
for (i, short) in self
.eventid_shorteventid
.multi_get(&keys)?
.iter()
.enumerate()
{
match short {
Some(short) => ret.push(
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
),
None => {
let short = services().globals.next_count()?;
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
self.shorteventid_eventid
.insert(&short.to_be_bytes(), keys[i])?;
debug_assert!(ret.len() == i, "position of result must match input");
ret.push(short);
},
}
}
Ok(ret)
}
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&statekey_vec)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()?;
Ok(short)
}
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey
};
Ok(short)
}
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
Ok(event_id)
}
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes)
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
let result = (event_type, state_key);
Ok(result)
}
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
})
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose()
}
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
})
}
}

View file

@ -1,73 +0,0 @@
use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
}
fn set_room_state(
&self,
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
}
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
}
Ok(())
}
}

View file

@ -1,165 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
#[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase {
#[allow(unused_qualifications)] // async traits
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
#[allow(unused_qualifications)] // async traits
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let (_, eventid) = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
}
i = i.wrapping_add(1);
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services()
.rooms
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
let full_state = services()
.rooms
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
})
.transpose()
})
}
/// Returns the full room state.
#[allow(unused_qualifications)] // async traits
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
}

View file

@ -1,627 +0,0 @@
use std::{collections::HashSet, sync::Arc};
use itertools::Itertools;
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use tracing::error;
use crate::{
database::KeyValueDatabase,
service::{self, appservice::RegistrationInfo},
services,
utils::{self, user_id::user_is_local},
Error, Result,
};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
}
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
}
Ok(())
}
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xFF);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
if self
.roomuserid_joined
.scan_prefix(roomid_prefix.clone())
.count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count() == 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(())
}
fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
let mut real_users = HashSet::new();
for joined in self.room_members(room_id).filter_map(Result::ok) {
joined_servers.insert(joined.server_name().to_owned());
if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) {
real_users.insert(joined);
}
joinedcount = joinedcount.saturating_add(1);
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
invitedcount = invitedcount.saturating_add(1);
}
self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
self.our_real_users_cache
.write()
.unwrap()
.insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe = self
.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.cloned();
if let Some(users) = maybe {
Ok(users)
} else {
self.update_joined_count(room_id)?;
Ok(Arc::clone(
self.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.unwrap(),
))
}
}
#[tracing::instrument(skip(self, room_id, appservice))]
fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
Ok(b)
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
services().globals.server_name(),
)
.ok();
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self
.room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
Ok(in_room)
}
}
/// Makes a user forget a room.
#[tracing::instrument(skip(self))]
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(())
}
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self))]
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
}))
}
#[tracing::instrument(skip(self))]
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self))]
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
}))
}
/// Returns an iterator over all joined members of a room.
#[tracing::instrument(skip(self))]
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
}))
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self))]
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self))]
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self))]
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))]
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
}
#[tracing::instrument(skip(self))]
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
))
})
}
#[tracing::instrument(skip(self))]
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
.transpose()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))]
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
}),
)
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))]
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self))]
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self))]
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
self.roomid_inviteviaservers
.get(&key)?
.map(|servers| {
let state = serde_json::from_slice(&servers).map_err(|e| {
error!("Invalid state in userroomid_leftstate: {e}");
Error::bad_database("Invalid state in userroomid_leftstate.")
})?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(servers.to_owned().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
Ok(())
}
}

View file

@ -1,64 +0,0 @@
use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{
database::KeyValueDatabase,
service::{self, rooms::state_compressor::data::StateDiff},
utils, Error, Result,
};
impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
} else {
None
};
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
Ok(StateDiff {
parent,
added: Arc::new(added),
removed: Arc::new(removed),
})
}
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
}
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
}
}

View file

@ -1,75 +0,0 @@
use std::mem;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
impl service::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(until - 1).to_be_bytes());
Ok(Box::new(
self.threadid_userids
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services()
.rooms
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((count, pdu))
}),
))
}
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
Ok(())
}
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
})
.filter_map(Result::ok)
.collect(),
))
} else {
Ok(None)
}
}
}

View file

@ -1,296 +0,0 @@
use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use service::rooms::timeline::PduCount;
use tracing::error;
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.unwrap()
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
}) {
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
}
},
hash_map::Entry::Occupied(o) => Ok(*o.get()),
}
}
/// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu.
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)
}
/// Returns the json of a pdu.
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
/// Returns the pdu.
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self
.get_non_outlier_pdu(event_id)?
.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)?
.map(Arc::new)
{
Ok(Some(pdu))
} else {
Ok(None)
}
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
}
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.lasttimelinecount_cache
.lock()
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
}
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
}
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
}
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn pdus_after<'a>(
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
}
fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> {
let mut notifies_batch = Vec::new();
let mut highlights_batch = Vec::new();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
}
self.userroomid_notificationcount
.increment_batch(&mut notifies_batch.into_iter())?;
self.userroomid_highlightcount
.increment_batch(&mut highlights_batch.into_iter())?;
Ok(())
}
}
/// Returns the `count` of this pdu's id.
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 =
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)))
} else {
Ok(PduCount::Normal(last_u64))
}
}
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}

View file

@ -1,137 +0,0 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
})
}
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
}
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastnotificationread
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
}
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
}
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
})
.transpose()
}
fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
.saturating_add(1); // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not
// reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}),
))
}
}

View file

@ -1,195 +0,0 @@
use ruma::{ServerName, UserId};
use crate::{
database::KeyValueDatabase,
service::{
self,
sending::{Destination, SendingEvent},
},
services, utils, Error, Result,
};
impl service::sending::Data for KeyValueDatabase {
fn active_requests<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a> {
Box::new(
self.servercurrentevent_data
.iter()
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
)
}
fn active_requests_for<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a> {
let prefix = destination.get_prefix();
Box::new(
self.servercurrentevent_data
.scan_prefix(prefix)
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
)
}
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
self.servercurrentevent_data.remove(&key)?;
}
Ok(())
}
fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> {
let prefix = destination.get_prefix();
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
self.servercurrentevent_data.remove(&key).unwrap();
}
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
self.servernameevent_data.remove(&key).unwrap();
}
Ok(())
}
fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> {
let mut batch = Vec::new();
let mut keys = Vec::new();
for (destination, event) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
} else {
&[]
};
batch.push((key.clone(), value.to_owned()));
keys.push(key);
}
self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
Ok(keys)
}
fn queued_requests<'a>(
&'a self, destination: &Destination,
) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> {
let prefix = destination.get_prefix();
return Box::new(
self.servernameevent_data
.scan_prefix(prefix)
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
);
}
fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> {
for (e, key) in events {
if key.is_empty() {
continue;
}
let value = if let SendingEvent::Edu(value) = &e {
&**value
} else {
&[]
};
self.servercurrentevent_data.insert(key, value)?;
self.servernameevent_data.remove(key)?;
}
Ok(())
}
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
}
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
})
}
}
#[tracing::instrument(skip(key))]
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Appservice(server),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
} else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element");
let user_string = utils::string_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id =
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
(
Destination::Push(user_id, pushkey_string),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value)
},
)
} else {
let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
(
Destination::Normal(
ServerName::parse(server)
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
} else {
SendingEvent::Edu(value)
},
)
})
}

View file

@ -1,32 +0,0 @@
use ruma::{DeviceId, TransactionId, UserId};
use crate::{database::KeyValueDatabase, service, Result};
impl service::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(&key, data)?;
Ok(())
}
fn existing_txnid(
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction
self.userdevicetxnid_response.get(&key)
}
}

View file

@ -1,68 +0,0 @@
use ruma::{
api::client::{error::ErrorKind, uiaa::UiaaInfo},
CanonicalJsonValue, DeviceId, UserId,
};
use crate::{database::KeyValueDatabase, service, Error, Result};
impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()> {
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
self.userdevicesessionid_uiaarequest
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(ToOwned::to_owned)
}
fn update_uiaa_session(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
}
Ok(())
}
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xFF);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
&self
.userdevicesessionid_uiaainfo
.get(&userdevicesessionid)?
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
}
}

View file

@ -1,893 +0,0 @@
use std::{collections::BTreeMap, mem::size_of};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use tracing::warn;
use crate::{
database::KeyValueDatabase,
service::{self, users::clean_signatures},
services, utils, Error, Result,
};
impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) }
/// Check if account is deactivated
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self
.userid_password
.get(user_id.as_bytes())?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))?
.is_empty())
}
/// Returns the number of users registered on this server.
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
/// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid
.get(token.as_bytes())?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
let device_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
Ok(Some((
UserId::parse(
utils::string_from_bytes(user_bytes)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?,
utils::string_from_bytes(device_bytes)
.map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?,
)))
})
}
/// Returns an iterator over all users on this homeserver.
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
Box::new(self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
}))
}
/// Returns a list of local users as list of usernames.
///
/// A user account is considered `local` if the length of it's password is
/// greater then zero.
fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self
.userid_password
.iter()
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw))
.collect();
Ok(users)
}
/// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.")
})?))
})
}
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())
} else {
Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Password does not meet the requirements.",
))
}
} else {
self.userid_password.insert(user_id.as_bytes(), b"")?;
Ok(())
}
}
/// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
))
})
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to nofify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?;
} else {
self.userid_displayname.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the `avatar_url` of a user.
fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> {
self.userid_avatarurl
.get(user_id.as_bytes())?
.map(|bytes| {
let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| {
warn!("Avatar URL in db is invalid: {}", e);
Error::bad_database("Avatar URL in db is invalid.")
})?;
let mxc_uri: OwnedMxcUri = s_bytes.into();
Ok(mxc_uri)
})
.transpose()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
} else {
self.userid_avatarurl.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Get the blurhash of a user.
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash
.get(user_id.as_bytes())?
.map(|bytes| {
let s = utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
Ok(s)
})
.transpose()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash {
self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
} else {
self.userid_blurhash.remove(user_id.as_bytes())?;
}
Ok(())
}
/// Adds a new device to a user.
fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()> {
// This method should never be called for nonexistent users. We shouldn't assert
// though...
if !self.exists(user_id)? {
warn!("Called create_device for non-existent user {} in database", user_id);
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
}
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(&Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: None, // TODO
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
})
.expect("Device::to_string never fails."),
)?;
self.set_token(user_id, device_id, token)?;
Ok(())
}
/// Removes a device from a user.
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.userdeviceid_token.remove(&userdeviceid)?;
self.token_userdeviceid.remove(&old_token)?;
}
// Remove todevice events
let mut prefix = userdeviceid.clone();
prefix.push(0xFF);
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(&key)?;
}
// TODO: Remove onetimekeys
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?;
Ok(())
}
/// Returns an iterator over all device ids of this user.
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
// All devices have metadata
Box::new(
self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
.into())
}),
)
}
/// Replaces the access token of one device.
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// should not be None, but we shouldn't assert either lol...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
// Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.token_userdeviceid.remove(&old_token)?;
// It will be removed from userdeviceid_token by the insert later
}
// Assign token to user device combination
self.userdeviceid_token
.insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(())
}
fn add_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
// All devices have metadata
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&key)?.is_none() {
warn!(
"Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \
database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore)
key.extend_from_slice(
serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
.as_bytes(),
);
self.onetimekeyid_onetimekeys.insert(
&key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
)?;
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
Ok(())
}
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
})
}
fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.push(b'"'); // Annoying quotation mark
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
prefix.push(b':');
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys
.scan_prefix(prefix)
.next()
.map(|(key, value)| {
self.onetimekeyid_onetimekeys.remove(&key)?;
Ok((
serde_json::from_slice(
key.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
)
.map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?,
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?,
))
})
.transpose()
}
fn count_one_time_keys(
&self, user_id: &UserId, device_id: &DeviceId,
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let mut counts = BTreeMap::new();
for algorithm in self
.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>(
bytes
.rsplit(|&b| b == 0xFF)
.next()
.ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?,
)
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
.algorithm(),
)
}) {
*counts.entry(algorithm?).or_default() += uint!(1);
}
Ok(counts)
}
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.keyid_key.insert(
&userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?;
self.mark_device_key_update(user_id)?;
Ok(())
}
fn add_cross_signing_keys(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> {
// TODO: Check signatures
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key
if let Some(self_signing_key) = self_signing_key {
let mut self_signing_key_ids = self_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))?
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained more than one key.",
));
}
let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
self.userid_selfsigningkeyid
.insert(user_id.as_bytes(), &self_signing_key_key)?;
}
// User-signing key
if let Some(user_signing_key) = user_signing_key {
let mut user_signing_key_ids = user_signing_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))?
.keys
.into_values();
let user_signing_key_id = user_signing_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?;
if user_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"User signing key contained more than one key.",
));
}
let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key
.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
self.userid_usersigningkeyid
.insert(user_id.as_bytes(), &user_signing_key_key)?;
}
if notify {
self.mark_device_key_update(user_id)?;
}
Ok(())
}
fn sign_key(
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
) -> Result<()> {
let mut key = target_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
&self
.keyid_key
.get(&key)?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?,
)
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
let signatures = cross_signing_key
.get_mut("signatures")
.ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))?
.as_object_mut()
.ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))?
.entry(sender_id.to_string())
.or_insert_with(|| serde_json::Map::new().into());
signatures
.as_object_mut()
.ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))?
.insert(signature.0, signature.1.into());
self.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
)?;
self.mark_device_key_update(target_id)?;
Ok(())
}
fn keys_changed<'a>(
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xFF);
let mut start = prefix.clone();
start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes());
let to = to.unwrap_or(u64::MAX);
Box::new(
self.keychangeid_userid
.iter_from(&start, false)
.take_while(move |(k, _)| {
k.starts_with(&prefix)
&& if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) {
if let Ok(c) = utils::u64_from_bytes(current) {
c <= to
} else {
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
false
}
} else {
warn!("BadDatabase: Could not parse keychangeid_userid");
false
}
})
.map(|(_, bytes)| {
UserId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
}),
)
}
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes();
for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
// Don't send key updates to unencrypted rooms
if services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
{
continue;
}
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
Ok(())
}
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
))
})
}
fn parse_master_key(
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
) -> Result<(Vec<u8>, CrossSigningKey)> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
let master_key = master_key
.deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let mut master_key_ids = master_key.keys.values();
let master_key_id = master_key_ids
.next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
if master_key_ids.next().is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Master key contained more than one key.",
));
}
let mut master_key_key = prefix.clone();
master_key_key.extend_from_slice(master_key_id.as_bytes());
Ok((master_key_key, master_key))
}
fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?;
Ok(Some(Raw::from_json(
serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"),
)))
})
}
fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_masterkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_selfsigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
}
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(
serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?,
))
})
})
}
fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value,
) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into());
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
self.todeviceid_events.insert(&key, &value)?;
Ok(())
}
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
let mut events = Vec::new();
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push(
serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
);
}
Ok(events)
}
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
let mut last = prefix.clone();
last.extend_from_slice(&until.to_be_bytes());
for (key, _) in self
.todeviceid_events
.iter_from(&last, true) // this includes last
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(key, _)| {
Ok::<_, Error>((
key.clone(),
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
))
})
.filter_map(Result::ok)
.take_while(|&(_, count)| count <= until)
{
self.todeviceid_events.remove(&key)?;
}
Ok(())
}
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
// Only existing devices should be able to call this, but we shouldn't assert
// either...
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
warn!(
"Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \
metadata in database",
user_id, device_id
);
return Err(Error::bad_database(
"User does not exist or device ID has no metadata in database.",
));
}
self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"),
)?;
Ok(())
}
/// Get device metadata.
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?))
})
}
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
.map(Some)
})
}
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
Box::new(
self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes)
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
}),
)
}
/// Creates a new sync filter. Returns the filter id.
fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
Ok(filter_id)
}
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
let raw = self.userfilterid_filter.get(&key)?;
if let Some(raw) = raw {
serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db."))
} else {
Ok(None)
}
}
}
impl KeyValueDatabase {}
/// Will only return with Some(username) if the password was not empty and the
/// username could be successfully parsed.
/// If `utils::string_from_bytes`(...) returns an error that username will be
/// skipped and the error will be logged.
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
// A valid password is not empty
if password.is_empty() {
None
} else {
match utils::string_from_bytes(username) {
Ok(u) => Some(u),
Err(e) => {
warn!("Failed to parse username while calling get_local_users(): {}", e.to_string());
None
},
}
}
}

320
src/database/kvdatabase.rs Normal file
View file

@ -0,0 +1,320 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
path::Path,
sync::{Arc, Mutex, RwLock},
};
use conduit::{Config, Error, PduCount, Result, Server};
use lru_cache::LruCache;
use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedRoomId, OwnedUserId};
use tracing::debug;
use crate::{KeyValueDatabaseEngine, KvTree};
pub struct KeyValueDatabase {
pub db: Arc<dyn KeyValueDatabaseEngine>,
//pub globals: globals::Globals,
pub global: Arc<dyn KvTree>,
pub server_signingkeys: Arc<dyn KvTree>,
pub roomid_inviteviaservers: Arc<dyn KvTree>,
//pub users: users::Users,
pub userid_password: Arc<dyn KvTree>,
pub userid_displayname: Arc<dyn KvTree>,
pub userid_avatarurl: Arc<dyn KvTree>,
pub userid_blurhash: Arc<dyn KvTree>,
pub userdeviceid_token: Arc<dyn KvTree>,
pub userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub token_userdeviceid: Arc<dyn KvTree>,
pub onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
pub keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
pub userid_masterkeyid: Arc<dyn KvTree>,
pub userid_selfsigningkeyid: Arc<dyn KvTree>,
pub userid_usersigningkeyid: Arc<dyn KvTree>,
pub userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
pub todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
pub userid_presenceid: Arc<dyn KvTree>, // UserId => Count
pub presenceid_presence: Arc<dyn KvTree>, // Count + UserId => Presence
//pub uiaa: uiaa::Uiaa,
pub userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub userdevicesessionid_uiaarequest: RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
//pub edus: RoomEdus,
pub readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
pub roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
pub roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
//pub rooms: rooms::Rooms,
pub pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
pub eventid_pduid: Arc<dyn KvTree>,
pub roomid_pduleaves: Arc<dyn KvTree>,
pub alias_roomid: Arc<dyn KvTree>,
pub aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub publicroomids: Arc<dyn KvTree>,
pub threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
pub tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
pub serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
pub userroomid_joined: Arc<dyn KvTree>,
pub roomuserid_joined: Arc<dyn KvTree>,
pub roomid_joinedcount: Arc<dyn KvTree>,
pub roomid_invitedcount: Arc<dyn KvTree>,
pub roomuseroncejoinedids: Arc<dyn KvTree>,
pub userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
pub roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub userroomid_leftstate: Arc<dyn KvTree>,
pub roomuserid_leftcount: Arc<dyn KvTree>,
pub disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
pub bannedroomids: Arc<dyn KvTree>, // Rooms where local users are not allowed to join
pub lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
pub userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
pub roomuserid_lastnotificationread: Arc<dyn KvTree>, // LastNotificationRead = u64
/// Remember the current state hash of a room.
pub roomid_shortstatehash: Arc<dyn KvTree>,
pub roomsynctoken_shortstatehash: Arc<dyn KvTree>,
/// Remember the state hash at events in the past.
pub shorteventid_shortstatehash: Arc<dyn KvTree>,
pub statekey_shortstatekey: Arc<dyn KvTree>, /* StateKey = EventType + StateKey, ShortStateKey =
* Count */
pub shortstatekey_statekey: Arc<dyn KvTree>,
pub roomid_shortroomid: Arc<dyn KvTree>,
pub shorteventid_eventid: Arc<dyn KvTree>,
pub eventid_shorteventid: Arc<dyn KvTree>,
pub statehash_shortstatehash: Arc<dyn KvTree>,
pub shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
* (shortstatekey+shorteventid++) + 0_u64 +
* (shortstatekey+shorteventid--) */
pub shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub eventid_outlierpdu: Arc<dyn KvTree>,
pub softfailedeventids: Arc<dyn KvTree>,
/// ShortEventId + ShortEventId -> ().
pub tofrom_relation: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub referencedevents: Arc<dyn KvTree>,
//pub account_data: account_data::AccountData,
pub roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
pub roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
//pub media: media::Media,
pub mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
pub url_previews: Arc<dyn KvTree>,
pub mediaid_user: Arc<dyn KvTree>,
//pub key_backups: key_backups::KeyBackups,
pub backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
//pub transaction_ids: transaction_ids::TransactionIds,
pub userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
* (/send) */
//pub sending: sending::Sending,
pub servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
pub servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
* PduId / Id (for edus), Data = EDU content */
pub servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
* / Id (for edus), Data = EDU content */
//pub appservice: appservice::Appservice,
pub id_appserviceregistrations: Arc<dyn KvTree>,
//pub pusher: pusher::PushData,
pub senderkey_pusher: Arc<dyn KvTree>,
pub auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>,
pub our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
impl KeyValueDatabase {
/// Load an existing database or create a new one.
#[allow(clippy::too_many_lines)]
pub async fn load_or_create(server: &Arc<Server>) -> Result<KeyValueDatabase> {
let config = &server.config;
check_db_setup(config)?;
let builder = build(config)?;
Ok(Self {
db: builder.clone(),
userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userid_presenceid: builder.open_tree("userid_presenceid")?,
presenceid_presence: builder.open_tree("presenceid_presence")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
alias_roomid: builder.open_tree("alias_roomid")?,
aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: builder.open_tree("publicroomids")?,
threadid_userids: builder.open_tree("threadid_userids")?,
tokenids: builder.open_tree("tokenids")?,
roomserverids: builder.open_tree("roomserverids")?,
serverroomids: builder.open_tree("serverroomids")?,
userroomid_joined: builder.open_tree("userroomid_joined")?,
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
disabledroomids: builder.open_tree("disabledroomids")?,
bannedroomids: builder.open_tree("bannedroomids")?,
lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?,
url_previews: builder.open_tree("url_previews")?,
mediaid_user: builder.open_tree("mediaid_user")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?,
auth_chain_cache: Mutex::new(LruCache::new(
(f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
})
}
}
fn build(config: &Config) -> Result<Arc<dyn KeyValueDatabaseEngine>> {
match &*config.database_backend {
"sqlite" => {
debug!("Got sqlite database backend");
#[cfg(not(feature = "sqlite"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "sqlite")]
Ok(Arc::new(Arc::<crate::sqlite::Engine>::open(config)?))
},
"rocksdb" => {
debug!("Got rocksdb database backend");
#[cfg(not(feature = "rocksdb"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "rocksdb")]
Ok(Arc::new(Arc::<crate::rocksdb::Engine>::open(config)?))
},
_ => Err(Error::bad_config(
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
)),
}
}
fn check_db_setup(config: &Config) -> Result<()> {
let path = Path::new(&config.database_path);
let sqlite_exists = path.join("conduit.db").exists();
let rocksdb_exists = path.join("IDENTITY").exists();
if sqlite_exists && rocksdb_exists {
return Err(Error::bad_config("Multiple databases at database_path detected."));
}
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.",
));
}
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.",
));
}
Ok(())
}

View file

@ -3,7 +3,7 @@ use std::{error::Error, sync::Arc};
use super::{Config, KvTree};
use crate::Result;
pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
pub trait KeyValueDatabaseEngine: Send + Sync {
fn open(config: &Config) -> Result<Self>
where
Self: Sized;

View file

@ -2,7 +2,7 @@ use std::{future::Future, pin::Pin};
use crate::Result;
pub(crate) trait KvTree: Send + Sync {
pub trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
#[allow(dead_code)]

View file

@ -1,638 +0,0 @@
use std::{
collections::{HashMap, HashSet},
fs::{self},
io::Write,
mem::size_of,
sync::Arc,
};
use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier};
use itertools::Itertools;
use rand::thread_rng;
use ruma::{
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
push::Ruleset,
EventId, OwnedRoomId, RoomId, UserId,
};
use tracing::{debug, error, info, warn};
use super::KeyValueDatabase;
use crate::{services, utils, Config, Error, Result};
pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result<()> {
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if services().users.count()? > 0 {
let conduit_user =
UserId::parse_with_server_name("conduit", &config.server_name).expect("@conduit:server_name is valid");
if !services().users.exists(&conduit_user)? {
error!("The {} server user does not exist, and the database is not new.", conduit_user);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first.",
));
}
}
// If the database has any data, perform data migrations before starting
// do not increment the db version if the user is not using sha256_media
let latest_database_version = if cfg!(feature = "sha256_media") {
14
} else {
13
};
if services().users.count()? > 0 {
// MIGRATIONS
if services().globals.database_version()? < 1 {
for (roomserverid, _) in db.roomserverids.iter() {
let mut parts = roomserverid.split(|&b| b == 0xFF);
let room_id = parts.next().expect("split always returns one element");
let Some(servername) = parts.next() else {
error!("Migration: Invalid roomserverid in db.");
continue;
};
let mut serverroomid = servername.to_vec();
serverroomid.push(0xFF);
serverroomid.extend_from_slice(room_id);
db.serverroomids.insert(&serverroomid, &[])?;
}
services().globals.bump_database_version(1)?;
warn!("Migration: 0 -> 1 finished");
}
if services().globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just ""
for (userid, password) in db.userid_password.iter() {
let salt = SaltString::generate(thread_rng());
let empty_pass = services()
.globals
.argon
.hash_password(b"", &salt)
.expect("our own password to be properly hashed");
let empty_hashed_password = services()
.globals
.argon
.verify_password(&password, &empty_pass)
.is_ok();
if empty_hashed_password {
db.userid_password.insert(&userid, b"")?;
}
}
services().globals.bump_database_version(2)?;
warn!("Migration: 1 -> 2 finished");
}
if services().globals.database_version()? < 3 {
// Move media to filesystem
for (key, content) in db.mediaid_file.iter() {
if content.is_empty() {
continue;
}
#[allow(deprecated)]
let path = services().globals.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
db.mediaid_file.insert(&key, &[])?;
}
services().globals.bump_database_version(3)?;
warn!("Migration: 2 -> 3 finished");
}
if services().globals.database_version()? < 4 {
// Add federated users to services() as deactivated
for our_user in services().users.iter() {
let our_user = our_user?;
if services().users.is_deactivated(&our_user)? {
continue;
}
for room in services().rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) {
let user = user?;
if user.server_name() != config.server_name {
info!(?user, "Migration: creating user");
services().users.create(&user, None)?;
}
}
}
}
services().globals.bump_database_version(4)?;
warn!("Migration: 3 -> 4 finished");
}
if services().globals.database_version()? < 5 {
// Upgrade user data store
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
let room_id = parts.next().unwrap();
let user_id = parts.next().unwrap();
let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
let mut key = room_id.to_vec();
key.push(0xFF);
key.extend_from_slice(user_id);
key.push(0xFF);
key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
}
services().globals.bump_database_version(5)?;
warn!("Migration: 4 -> 5 finished");
}
if services().globals.database_version()? < 6 {
// Set room member count
for (roomid, _) in db.roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?;
}
services().globals.bump_database_version(6)?;
warn!("Migration: 5 -> 6 finished");
}
if services().globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let handle_state = |current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
services()
.rooms
.state_compressor
.load_shortstatehash_info(last_roomsstatehash)
},
)?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.copied()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.copied()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
services().rooms.state_compressor.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
}
current_state = HashSet::new();
current_sstatehash = Some(sstatehash);
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services()
.rooms
.timeline
.get_pdu(event_id)
.unwrap()
.unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
}
}
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_deref().unwrap(),
current_state,
&mut last_roomstates,
)?;
}
services().globals.bump_database_version(7)?;
warn!("Migration: 6 -> 7 finished");
}
if services().globals.database_version()? < 8 {
// Generate short room ids for all rooms
for (room_id, _) in db.roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes();
db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
// Update pduids db layout
let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(count);
Some((new_key, v))
});
db.pduid_pdu.insert_batch(&mut batch)?;
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
let mut parts = value.splitn(2, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
db.eventid_pduid.insert_batch(&mut batch2)?;
services().globals.bump_database_version(8)?;
warn!("Migration: 7 -> 8 finished");
}
if services().globals.database_version()? < 9 {
// Update tokenids db layout
let mut iter = db
.tokenids
.iter()
.filter_map(|(key, _)| {
if !key.starts_with(b"!") {
return None;
}
let mut parts = key.splitn(4, |&b| b == 0xFF);
let room_id = parts.next().unwrap();
let word = parts.next().unwrap();
let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap();
let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id;
new_key.extend_from_slice(word);
new_key.push(0xFF);
new_key.extend_from_slice(pdu_id_count);
Some((new_key, Vec::new()))
})
.peekable();
while iter.peek().is_some() {
db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?;
debug!("Inserted smaller batch");
}
info!("Deleting starts");
let batch2: Vec<_> = db
.tokenids
.iter()
.filter_map(|(key, _)| {
if key.starts_with(b"!") {
Some(key)
} else {
None
}
})
.collect();
for key in batch2 {
db.tokenids.remove(&key)?;
}
services().globals.bump_database_version(9)?;
warn!("Migration: 8 -> 9 finished");
}
if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
db.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?;
}
// Force E2EE device list updates so we can send them over federation
for user_id in services().users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?;
}
services().globals.bump_database_version(10)?;
warn!("Migration: 9 -> 10 finished");
}
if services().globals.database_version()? < 11 {
db.db
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished");
}
if services().globals.database_version()? < 12 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
};
let raw_rules_list = services()
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let rules_list = &mut account_data.content.global;
//content rule
{
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let rule = rules_list.content.get(content_rule_transformation[0]);
if rule.is_some() {
let mut rule = rule.unwrap().clone();
content_rule_transformation[1].clone_into(&mut rule.rule_id);
rules_list
.content
.shift_remove(content_rule_transformation[0]);
rules_list.content.insert(rule);
}
}
//underride rules
{
let underride_rule_transformation = [
[".m.rules.call", ".m.rule.call"],
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
[".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
[".m.rules.message", ".m.rule.message"],
[".m.rules.encrypted", ".m.rule.encrypted"],
];
for transformation in underride_rule_transformation {
let rule = rules_list.underride.get(transformation[0]);
if let Some(rule) = rule {
let mut rule = rule.clone();
transformation[1].clone_into(&mut rule.rule_id);
rules_list.underride.shift_remove(transformation[0]);
rules_list.underride.insert(rule);
}
}
}
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
}
services().globals.bump_database_version(12)?;
warn!("Migration: 11 -> 12 finished");
}
// This migration can be reused as-is anytime the server-default rules are
// updated.
if services().globals.database_version()? < 13 {
for username in services().users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u,
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
},
};
let raw_rules_list = services()
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
.expect("Username is invalid");
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let user_default_rules = Ruleset::server_default(&user);
account_data
.content
.global
.update_with_server_default(user_default_rules);
services().account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
}
services().globals.bump_database_version(13)?;
warn!("Migration: 12 -> 13 finished");
}
#[cfg(feature = "sha256_media")]
{
if services().globals.database_version()? < 14 && cfg!(feature = "sha256_media") {
warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
// Move old media files to new names
for (key, _) in db.mediaid_file.iter() {
let old_path = services().globals.get_media_file(&key);
debug!("Old file path: {old_path:?}");
let path = services().globals.get_media_file_new(&key);
debug!("New file path: {path:?}");
// move the file to the new location
if old_path.exists() {
tokio::fs::rename(&old_path, &path).await?;
}
}
services().globals.bump_database_version(14)?;
warn!("Migration: 13 -> 14 finished");
}
}
assert_eq!(
services().globals.database_version().unwrap(),
latest_database_version,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
services().globals.database_version().unwrap(),
latest_database_version
);
{
let patterns = &config.forbidden_usernames;
if !patterns.is_empty() {
for user_id in services()
.users
.iter()
.filter_map(Result::ok)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true))
.filter(|user| user.server_name() == config.server_name)
{
let matches = patterns.matches(user_id.localpart());
if matches.matched_any() {
warn!(
"User {} matches the following forbidden username patterns: {}",
user_id.to_string(),
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
}
}
}
{
let patterns = &config.forbidden_alias_names;
if !patterns.is_empty() {
for address in services().rooms.metadata.iter_ids() {
let room_id = address?;
let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id);
for room_alias_result in room_aliases {
let room_alias = room_alias_result?;
let matches = patterns.matches(room_alias.alias());
if matches.matched_any() {
warn!(
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
room_alias,
&room_id,
matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
);
}
}
}
}
}
info!(
"Loaded {} database with schema version {}",
config.database_backend, latest_database_version
);
} else {
services()
.globals
.bump_database_version(latest_database_version)?;
// Create the admin room and server user on first run
services().admin.create_admin_room().await?;
warn!(
"Created new {} database with version {}",
config.database_backend, latest_database_version
);
}
Ok(())
}

View file

@ -1,573 +1,23 @@
mod cork;
mod key_value;
pub mod cork;
mod kvdatabase;
mod kvengine;
mod kvtree;
mod migrations;
#[cfg(feature = "rocksdb")]
mod rocksdb;
pub(crate) mod rocksdb;
#[cfg(feature = "sqlite")]
mod sqlite;
pub(crate) mod sqlite;
#[cfg(any(feature = "sqlite", feature = "rocksdb"))]
pub(crate) mod watchers;
use std::{
collections::{BTreeMap, HashMap, HashSet},
fs::{self},
path::Path,
sync::{Arc, Mutex, RwLock},
time::Duration,
};
pub(crate) use cork::Cork;
pub(crate) use kvengine::KeyValueDatabaseEngine;
pub(crate) use kvtree::KvTree;
use lru_cache::LruCache;
use ruma::{
events::{
push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, GlobalAccountDataEvent,
GlobalAccountDataEventType,
},
push::Ruleset,
CanonicalJsonValue, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId,
};
use serde::Deserialize;
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use tokio::time::{interval, Instant};
use tracing::{debug, error, warn};
use crate::{
database::migrations::migrations, service::rooms::timeline::PduCount, services, Config, Error,
LogLevelReloadHandles, Result, Services, SERVICES,
};
pub(crate) struct KeyValueDatabase {
db: Arc<dyn KeyValueDatabaseEngine>,
//pub(crate) globals: globals::Globals,
pub(crate) global: Arc<dyn KvTree>,
pub(crate) server_signingkeys: Arc<dyn KvTree>,
pub(crate) roomid_inviteviaservers: Arc<dyn KvTree>,
//pub(crate) users: users::Users,
pub(crate) userid_password: Arc<dyn KvTree>,
pub(crate) userid_displayname: Arc<dyn KvTree>,
pub(crate) userid_avatarurl: Arc<dyn KvTree>,
pub(crate) userid_blurhash: Arc<dyn KvTree>,
pub(crate) userdeviceid_token: Arc<dyn KvTree>,
pub(crate) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub(crate) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub(crate) token_userdeviceid: Arc<dyn KvTree>,
pub(crate) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(crate) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub(crate) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
pub(crate) keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
pub(crate) userid_masterkeyid: Arc<dyn KvTree>,
pub(crate) userid_selfsigningkeyid: Arc<dyn KvTree>,
pub(crate) userid_usersigningkeyid: Arc<dyn KvTree>,
pub(crate) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
pub(crate) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
pub(crate) userid_presenceid: Arc<dyn KvTree>, // UserId => Count
pub(crate) presenceid_presence: Arc<dyn KvTree>, // Count + UserId => Presence
//pub(crate) uiaa: uiaa::Uiaa,
pub(crate) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub(crate) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
//pub(crate) edus: RoomEdus,
pub(crate) readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
pub(crate) roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
pub(crate) roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
//pub(crate) rooms: rooms::Rooms,
pub(crate) pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
pub(crate) eventid_pduid: Arc<dyn KvTree>,
pub(crate) roomid_pduleaves: Arc<dyn KvTree>,
pub(crate) alias_roomid: Arc<dyn KvTree>,
pub(crate) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub(crate) publicroomids: Arc<dyn KvTree>,
pub(crate) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
pub(crate) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub(crate) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
pub(crate) serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
pub(crate) userroomid_joined: Arc<dyn KvTree>,
pub(crate) roomuserid_joined: Arc<dyn KvTree>,
pub(crate) roomid_joinedcount: Arc<dyn KvTree>,
pub(crate) roomid_invitedcount: Arc<dyn KvTree>,
pub(crate) roomuseroncejoinedids: Arc<dyn KvTree>,
pub(crate) userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
pub(crate) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub(crate) userroomid_leftstate: Arc<dyn KvTree>,
pub(crate) roomuserid_leftcount: Arc<dyn KvTree>,
pub(crate) disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
pub(crate) bannedroomids: Arc<dyn KvTree>, // Rooms where local users are not allowed to join
pub(crate) lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub(crate) userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
pub(crate) userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
pub(crate) roomuserid_lastnotificationread: Arc<dyn KvTree>, // LastNotificationRead = u64
/// Remember the current state hash of a room.
pub(crate) roomid_shortstatehash: Arc<dyn KvTree>,
pub(crate) roomsynctoken_shortstatehash: Arc<dyn KvTree>,
/// Remember the state hash at events in the past.
pub(crate) shorteventid_shortstatehash: Arc<dyn KvTree>,
pub(crate) statekey_shortstatekey: Arc<dyn KvTree>, /* StateKey = EventType + StateKey, ShortStateKey =
* Count */
pub(crate) shortstatekey_statekey: Arc<dyn KvTree>,
pub(crate) roomid_shortroomid: Arc<dyn KvTree>,
pub(crate) shorteventid_eventid: Arc<dyn KvTree>,
pub(crate) eventid_shorteventid: Arc<dyn KvTree>,
pub(crate) statehash_shortstatehash: Arc<dyn KvTree>,
pub(crate) shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
* (shortstatekey+shorteventid++) + 0_u64 +
* (shortstatekey+shorteventid--) */
pub(crate) shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub(crate) eventid_outlierpdu: Arc<dyn KvTree>,
pub(crate) softfailedeventids: Arc<dyn KvTree>,
/// ShortEventId + ShortEventId -> ().
pub(crate) tofrom_relation: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub(crate) referencedevents: Arc<dyn KvTree>,
//pub(crate) account_data: account_data::AccountData,
pub(crate) roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
pub(crate) roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
//pub(crate) media: media::Media,
pub(crate) mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
pub(crate) url_previews: Arc<dyn KvTree>,
pub(crate) mediaid_user: Arc<dyn KvTree>,
//pub(crate) key_backups: key_backups::KeyBackups,
pub(crate) backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(crate) backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(crate) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
//pub(crate) transaction_ids: transaction_ids::TransactionIds,
pub(crate) userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
* (/send) */
//pub(crate) sending: sending::Sending,
pub(crate) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
pub(crate) servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
* PduId / Id (for edus), Data = EDU content */
pub(crate) servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
* / Id (for edus), Data = EDU content */
//pub(crate) appservice: appservice::Appservice,
pub(crate) id_appserviceregistrations: Arc<dyn KvTree>,
//pub(crate) pusher: pusher::PushData,
pub(crate) senderkey_pusher: Arc<dyn KvTree>,
pub(crate) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>,
pub(crate) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(crate) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(crate) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponseEntry {
id: u64,
date: String,
message: String,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponse {
updates: Vec<CheckForUpdatesResponseEntry>,
}
impl KeyValueDatabase {
/// Load an existing database or create a new one.
#[allow(clippy::too_many_lines)]
pub(crate) async fn load_or_create(config: Config, tracing_reload_handler: LogLevelReloadHandles) -> Result<()> {
Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
debug!("Database path does not exist, assuming this is a new setup and creating it");
fs::create_dir_all(&config.database_path).map_err(|e| {
error!("Failed to create database path: {e}");
Error::bad_config(
"Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \
create the database folder yourself or allow conduwuit the permissions to create directories and \
files.",
)
})?;
}
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
"sqlite" => {
debug!("Got sqlite database backend");
#[cfg(not(feature = "sqlite"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "sqlite")]
Arc::new(Arc::<sqlite::Engine>::open(&config)?)
},
"rocksdb" => {
debug!("Got rocksdb database backend");
#[cfg(not(feature = "rocksdb"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "rocksdb")]
Arc::new(Arc::<rocksdb::Engine>::open(&config)?)
},
_ => {
return Err(Error::bad_config(
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
));
},
};
let db_raw = Box::new(Self {
db: builder.clone(),
userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userid_presenceid: builder.open_tree("userid_presenceid")?,
presenceid_presence: builder.open_tree("presenceid_presence")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
alias_roomid: builder.open_tree("alias_roomid")?,
aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: builder.open_tree("publicroomids")?,
threadid_userids: builder.open_tree("threadid_userids")?,
tokenids: builder.open_tree("tokenids")?,
roomserverids: builder.open_tree("roomserverids")?,
serverroomids: builder.open_tree("serverroomids")?,
userroomid_joined: builder.open_tree("userroomid_joined")?,
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
disabledroomids: builder.open_tree("disabledroomids")?,
bannedroomids: builder.open_tree("bannedroomids")?,
lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?,
url_previews: builder.open_tree("url_previews")?,
mediaid_user: builder.open_tree("mediaid_user")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?,
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
auth_chain_cache: Mutex::new(LruCache::new(
(f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
});
let db = Box::leak(db_raw);
let services_raw = Box::new(Services::build(db, &config, tracing_reload_handler)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
migrations(db, &config).await?;
services().admin.start_handler();
// Set emergency access for the conduit user
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!(
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!"
);
services()
.admin
.send_message(RoomMessageEventContent::text_plain(
"The Conduit account emergency password is set! Please unset it as soon as you finish \
admin account recovery!",
))
.await;
}
},
Err(e) => {
error!("Could not set the configured emergency password for the conduit user: {}", e);
},
};
services().sending.start_handler();
if config.allow_local_presence {
services().presence.start_handler();
}
Self::start_cleanup_task().await;
if services().globals.allow_check_for_updates() {
Self::start_check_for_updates_task().await;
}
Ok(())
}
fn check_db_setup(config: &Config) -> Result<()> {
let path = Path::new(&config.database_path);
let sqlite_exists = path.join("conduit.db").exists();
let rocksdb_exists = path.join("IDENTITY").exists();
if sqlite_exists && rocksdb_exists {
return Err(Error::bad_config("Multiple databases at database_path detected."));
}
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.",
));
}
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.",
));
}
Ok(())
}
#[tracing::instrument]
async fn start_check_for_updates_task() {
let timer_interval = Duration::from_secs(7200); // 2 hours
tokio::spawn(async move {
let mut i = interval(timer_interval);
loop {
tokio::select! {
_ = i.tick() => {
debug!(target: "start_check_for_updates_task", "Timer ticked");
},
}
_ = Self::try_handle_updates().await;
}
});
}
async fn try_handle_updates() -> Result<()> {
let response = services()
.globals
.client
.default
.get("https://pupbrain.dev/check-for-updates/stable")
.send()
.await?;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
error!("Bad check for updates response: {e}");
Error::BadServerResponse("Bad version check response")
})?;
let mut last_update_id = services().globals.last_check_for_updates_id()?;
for update in response.updates {
last_update_id = last_update_id.max(update.id);
if update.id > services().globals.last_check_for_updates_id()? {
error!("{}", update.message);
services()
.admin
.send_message(RoomMessageEventContent::text_plain(format!(
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
update.date, update.message
)))
.await;
}
}
services()
.globals
.update_check_for_updates_id(last_update_id)?;
Ok(())
}
#[tracing::instrument]
async fn start_cleanup_task() {
let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
tokio::spawn(async move {
let mut i = interval(timer_interval);
#[cfg(unix)]
let mut hangup = signal(SignalKind::hangup()).expect("Failed to register SIGHUP signal receiver");
#[cfg(unix)]
let mut ctrl_c = signal(SignalKind::interrupt()).expect("Failed to register SIGINT signal receiver");
#[cfg(unix)]
let mut terminate = signal(SignalKind::terminate()).expect("Failed to register SIGTERM signal receiver");
loop {
#[cfg(unix)]
tokio::select! {
_ = i.tick() => {
debug!(target: "database-cleanup", "Timer ticked");
}
_ = hangup.recv() => {
debug!(target: "database-cleanup","Received SIGHUP");
}
_ = ctrl_c.recv() => {
debug!(target: "database-cleanup", "Received Ctrl+C");
}
_ = terminate.recv() => {
debug!(target: "database-cleanup","Received SIGTERM");
}
}
#[cfg(not(unix))]
{
i.tick().await;
debug!(target: "database-cleanup", "Timer ticked")
}
Self::perform_cleanup();
}
});
}
fn perform_cleanup() {
if !services().globals.config.rocksdb_periodic_cleanup {
return;
}
let start = Instant::now();
if let Err(e) = services().globals.cleanup() {
error!(target: "database-cleanup", "Ran into an error during cleanup: {}", e);
} else {
debug!(target: "database-cleanup", "Finished cleanup in {:#?}.", start.elapsed());
}
}
#[allow(dead_code)]
fn flush(&self) -> Result<()> {
let start = std::time::Instant::now();
let res = self.db.flush();
debug!("flush: took {:?}", start.elapsed());
res
}
}
/// Sets the emergency password and push rules for the @conduit account in case
/// emergency password is set
fn set_emergency_access() -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is a valid UserId");
services()
.users
.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
&conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
res
}
extern crate conduit_core as conduit;
pub(crate) use conduit::{Config, Result};
pub use cork::Cork;
pub use kvdatabase::KeyValueDatabase;
pub use kvengine::KeyValueDatabaseEngine;
pub use kvtree::KvTree;
conduit::mod_ctor! {}
conduit::mod_dtor! {}

View file

@ -1,9 +1,8 @@
use std::{future::Future, pin::Pin, sync::Arc};
use rust_rocksdb::WriteBatchWithTransaction;
use conduit::{utils, Result};
use super::{watchers::Watchers, Engine, KeyValueDatabaseEngine, KvTree};
use crate::{utils, Result};
use super::{rust_rocksdb::WriteBatchWithTransaction, watchers::Watchers, Engine, KeyValueDatabaseEngine, KvTree};
pub(crate) struct RocksDbEngineTree<'a> {
pub(crate) db: Arc<Engine>,

View file

@ -1,3 +1,8 @@
// no_link to prevent double-inclusion of librocksdb.a here and with
// libconduit_core.so
#[no_link]
extern crate rust_rocksdb;
use std::{
collections::HashMap,
sync::{atomic::AtomicU32, Arc},
@ -6,12 +11,12 @@ use std::{
use chrono::{DateTime, Utc};
use rust_rocksdb::{
backup::{BackupEngine, BackupEngineOptions},
perf::get_memory_usage_stats,
Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode as Db, Env, MultiThreaded, Options,
};
use tracing::{debug, error, info, warn};
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::Result;
use crate::{watchers::Watchers, Config, KeyValueDatabaseEngine, KvTree, Result};
pub(crate) mod kvtree;
pub(crate) mod opts;
@ -22,13 +27,13 @@ use opts::{cf_options, db_options};
use super::watchers;
pub(crate) struct Engine {
rocks: Db<MultiThreaded>,
config: Config,
row_cache: Cache,
col_cache: HashMap<String, Cache>,
old_cfs: Vec<String>,
opts: Options,
env: Env,
config: Config,
old_cfs: Vec<String>,
rocks: Db<MultiThreaded>,
corks: AtomicU32,
}
@ -79,13 +84,13 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
load_time.elapsed()
);
Ok(Arc::new(Engine {
rocks: db,
config: config.clone(),
row_cache,
col_cache,
old_cfs: cfs,
opts: db_opts,
env: db_env,
config: config.clone(),
old_cfs: cfs,
rocks: db,
corks: AtomicU32::new(0),
}))
}
@ -135,7 +140,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn memory_usage(&self) -> Result<String> {
let mut res = String::new();
let stats = rust_rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.row_cache]))?;
let stats = get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.row_cache]))?;
_ = std::fmt::write(
&mut res,
format_args!(
@ -258,3 +263,20 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
#[allow(dead_code)]
fn clear_caches(&self) {}
}
impl Drop for Engine {
fn drop(&mut self) {
debug!("Waiting for background tasks to finish...");
const BLOCKING: bool = true;
self.rocks.cancel_all_background_work(BLOCKING);
debug!("Shutting down background threads");
self.env.set_high_priority_background_threads(0);
self.env.set_low_priority_background_threads(0);
self.env.set_bottom_priority_background_threads(0);
self.env.set_background_threads(0);
debug!("Joining background threads...");
self.env.join_all_threads();
}
}

View file

@ -1,14 +1,14 @@
#![allow(dead_code)]
use std::collections::HashMap;
use rust_rocksdb::{
BlockBasedOptions, Cache, DBCompactionStyle, DBCompressionType, DBRecoveryMode, Env, LogLevel, Options,
UniversalCompactOptions, UniversalCompactionStopStyle,
use super::{
rust_rocksdb::{
BlockBasedOptions, Cache, DBCompactionStyle, DBCompressionType, DBRecoveryMode, Env, LogLevel, Options,
UniversalCompactOptions, UniversalCompactionStopStyle,
},
Config,
};
use super::Config;
/// Create database-wide options suitable for opening the database. This also
/// sets our default column options in case of opening a column with the same
/// resulting value. Note that we require special per-column options on some

View file

@ -6,13 +6,13 @@ use std::{
sync::Arc,
};
use conduit::{Config, Result};
use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use thread_local::ThreadLocal;
use tracing::debug;
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = const { RefCell::new(None) };
@ -224,7 +224,7 @@ impl KvTree for SqliteTable {
guard.execute("BEGIN", [])?;
for key in iter {
let old = self.get_with_guard(&guard, &key)?;
let new = crate::utils::increment(old.as_deref());
let new = conduit::utils::increment(old.as_deref());
self.insert_with_guard(&guard, &key, &new)?;
}
guard.execute("COMMIT", [])?;
@ -307,7 +307,7 @@ impl KvTree for SqliteTable {
let old = self.get_with_guard(&guard, key)?;
let new = crate::utils::increment(old.as_deref());
let new = conduit::utils::increment(old.as_deref());
self.insert_with_guard(&guard, key, &new)?;