devirtualize service Data traits

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-27 03:17:20 +00:00
parent a6edaad6fc
commit 7ad7badd60
64 changed files with 1190 additions and 1176 deletions

View file

@ -1,6 +1,9 @@
use std::collections::{BTreeMap, HashMap};
use std::{
collections::{BTreeMap, HashMap},
sync::Arc,
};
use async_trait::async_trait;
use database::{Cork, KeyValueDatabase, KvTree};
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
use ruma::{
@ -10,56 +13,60 @@ use ruma::{
};
use tracing::trace;
use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result};
use crate::{services, utils, Error, Result};
const COUNTER: &[u8] = b"c";
const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
#[async_trait]
pub trait Data: Send + Sync {
fn next_count(&self) -> Result<u64>;
fn current_count(&self) -> Result<u64>;
fn last_check_for_updates_id(&self) -> Result<u64>;
fn update_check_for_updates_id(&self, id: u64) -> Result<()>;
#[allow(unused_qualifications)] // async traits
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
fn cleanup(&self) -> Result<()>;
fn cork(&self) -> Cork;
fn cork_and_flush(&self) -> Cork;
fn memory_usage(&self) -> String;
fn clear_caches(&self, amount: u32);
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
fn remove_keypair(&self) -> Result<()>;
fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
fn database_version(&self) -> Result<u64>;
fn bump_database_version(&self, new_version: u64) -> Result<()>;
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { unimplemented!() }
fn backup_list(&self) -> Result<String> { Ok(String::new()) }
fn file_list(&self) -> Result<String> { Ok(String::new()) }
pub struct Data {
global: Arc<dyn KvTree>,
todeviceid_events: Arc<dyn KvTree>,
userroomid_joined: Arc<dyn KvTree>,
userroomid_invitestate: Arc<dyn KvTree>,
userroomid_leftstate: Arc<dyn KvTree>,
userroomid_notificationcount: Arc<dyn KvTree>,
userroomid_highlightcount: Arc<dyn KvTree>,
pduid_pdu: Arc<dyn KvTree>,
keychangeid_userid: Arc<dyn KvTree>,
roomusertype_roomuserdataid: Arc<dyn KvTree>,
server_signingkeys: Arc<dyn KvTree>,
readreceiptid_readreceipt: Arc<dyn KvTree>,
userid_lastonetimekeyupdate: Arc<dyn KvTree>,
pub(super) db: Arc<KeyValueDatabase>,
}
#[async_trait]
impl Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
impl Data {
pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
Self {
global: db.global.clone(),
todeviceid_events: db.todeviceid_events.clone(),
userroomid_joined: db.userroomid_joined.clone(),
userroomid_invitestate: db.userroomid_invitestate.clone(),
userroomid_leftstate: db.userroomid_leftstate.clone(),
userroomid_notificationcount: db.userroomid_notificationcount.clone(),
userroomid_highlightcount: db.userroomid_highlightcount.clone(),
pduid_pdu: db.pduid_pdu.clone(),
keychangeid_userid: db.keychangeid_userid.clone(),
roomusertype_roomuserdataid: db.roomusertype_roomuserdataid.clone(),
server_signingkeys: db.server_signingkeys.clone(),
readreceiptid_readreceipt: db.readreceiptid_readreceipt.clone(),
userid_lastonetimekeyupdate: db.userid_lastonetimekeyupdate.clone(),
db: db.clone(),
}
}
pub fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
}
fn current_count(&self) -> Result<u64> {
pub fn current_count(&self) -> Result<u64> {
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
})
}
fn last_check_for_updates_id(&self) -> Result<u64> {
pub fn last_check_for_updates_id(&self) -> Result<u64> {
self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
@ -68,16 +75,15 @@ impl Data for KeyValueDatabase {
})
}
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(())
}
#[allow(unused_qualifications)] // async traits
#[tracing::instrument(skip(self))]
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let userid_bytes = user_id.as_bytes().to_vec();
let mut userid_prefix = userid_bytes.clone();
userid_prefix.push(0xFF);
@ -177,20 +183,20 @@ impl Data for KeyValueDatabase {
Ok(())
}
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
pub fn cleanup(&self) -> Result<()> { self.db.db.cleanup() }
fn cork(&self) -> Cork { Cork::new(&self.db, false, false) }
pub fn cork(&self) -> Cork { Cork::new(&self.db.db, false, false) }
fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) }
pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db.db, true, false) }
fn memory_usage(&self) -> String {
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
pub fn memory_usage(&self) -> String {
let auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().len();
let appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().len();
let lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().len();
let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity();
let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity();
let max_auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().capacity();
let max_appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().capacity();
let max_lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().capacity();
format!(
"\
@ -198,26 +204,26 @@ auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
{}",
self.db.memory_usage().unwrap_or_default()
self.db.db.memory_usage().unwrap_or_default()
)
}
fn clear_caches(&self, amount: u32) {
pub fn clear_caches(&self, amount: u32) {
if amount > 1 {
let c = &mut *self.auth_chain_cache.lock().unwrap();
let c = &mut *self.db.auth_chain_cache.lock().unwrap();
*c = LruCache::new(c.capacity());
}
if amount > 2 {
let c = &mut *self.appservice_in_room_cache.write().unwrap();
let c = &mut *self.db.appservice_in_room_cache.write().unwrap();
*c = HashMap::new();
}
if amount > 3 {
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
let c = &mut *self.db.lasttimelinecount_cache.lock().unwrap();
*c = HashMap::new();
}
}
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
pub fn load_keypair(&self) -> Result<Ed25519KeyPair> {
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
@ -249,9 +255,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
})
}
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
fn add_signing_key(
pub fn add_signing_key(
&self, origin: &ServerName, new_keys: ServerSigningKeys,
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical
@ -290,8 +296,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
/// for the server.
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
let signingkeys = self
.db
.server_signingkeys
.get(origin.as_bytes())?
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
@ -308,20 +315,20 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
Ok(signingkeys)
}
fn database_version(&self) -> Result<u64> {
pub fn database_version(&self) -> Result<u64> {
self.global.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
fn bump_database_version(&self, new_version: u64) -> Result<()> {
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.global.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() }
pub fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.db.backup() }
fn backup_list(&self) -> Result<String> { self.db.backup_list() }
pub fn backup_list(&self) -> Result<String> { self.db.db.backup_list() }
fn file_list(&self) -> Result<String> { self.db.file_list() }
pub fn file_list(&self) -> Result<String> { self.db.db.file_list() }
}