devirtualize service Data traits

Signed-off-by: Jason Volk <jason@zemos.net>

parent a6edaad6fc
commit 7ad7badd60
64 changed files with 1190 additions and 1176 deletions
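For orientation, a minimal sketch of the pattern this commit applies across the service Data modules (the names `KvTree`, `DataBefore`, `DataAfter`, `ServiceBefore`, `ServiceAfter` and the method bodies below are illustrative stand-ins, not the actual conduwuit API): the service-facing `Data` stops being a trait object called through dynamic dispatch and becomes a concrete struct holding handles into the shared key-value database, so its former trait methods become inherent `pub fn`s resolved statically.

use std::sync::Arc;

// Illustrative stand-in for a per-tree database handle.
pub trait KvTree: Send + Sync {
	fn get(&self, key: &[u8]) -> Option<Vec<u8>>;
}

// Before: the service held its data layer as a trait object,
// so every call went through a vtable.
pub trait DataBefore: Send + Sync {
	fn current_count(&self) -> u64;
}
pub struct ServiceBefore {
	pub db: Arc<dyn DataBefore>, // dynamic dispatch on each call
}

// After: Data is a concrete struct over the database trees;
// its methods are inherent and dispatch statically.
pub struct DataAfter {
	global: Arc<dyn KvTree>,
}
impl DataAfter {
	pub fn current_count(&self) -> u64 {
		// Toy body; the real method decodes a stored u64 counter.
		self.global.get(b"c").map_or(0, |v| v.len() as u64)
	}
}
pub struct ServiceAfter {
	pub db: DataAfter, // plain field, direct method calls
}

The diff below shows exactly this shape: the trait `Data` and its `impl Data for KeyValueDatabase` collapse into `pub struct Data` plus an inherent `impl Data`, whose methods gain `pub` and reach the underlying database through `self.db`.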
@@ -1,6 +1,9 @@
-use std::collections::{BTreeMap, HashMap};
+use std::{
+	collections::{BTreeMap, HashMap},
+	sync::Arc,
+};
 
-use async_trait::async_trait;
+use database::{Cork, KeyValueDatabase, KvTree};
 use futures_util::{stream::FuturesUnordered, StreamExt};
 use lru_cache::LruCache;
 use ruma::{
@@ -10,56 +13,60 @@ use ruma::{
 };
 use tracing::trace;
 
-use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result};
+use crate::{services, utils, Error, Result};
 
 const COUNTER: &[u8] = b"c";
 const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
 
-#[async_trait]
-pub trait Data: Send + Sync {
-	fn next_count(&self) -> Result<u64>;
-	fn current_count(&self) -> Result<u64>;
-	fn last_check_for_updates_id(&self) -> Result<u64>;
-	fn update_check_for_updates_id(&self, id: u64) -> Result<()>;
-	#[allow(unused_qualifications)] // async traits
-	async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
-	fn cleanup(&self) -> Result<()>;
-
-	fn cork(&self) -> Cork;
-	fn cork_and_flush(&self) -> Cork;
-
-	fn memory_usage(&self) -> String;
-	fn clear_caches(&self, amount: u32);
-	fn load_keypair(&self) -> Result<Ed25519KeyPair>;
-	fn remove_keypair(&self) -> Result<()>;
-	fn add_signing_key(
-		&self, origin: &ServerName, new_keys: ServerSigningKeys,
-	) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
-
-	/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
-	/// for the server.
-	fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
-	fn database_version(&self) -> Result<u64>;
-	fn bump_database_version(&self, new_version: u64) -> Result<()>;
-	fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { unimplemented!() }
-	fn backup_list(&self) -> Result<String> { Ok(String::new()) }
-	fn file_list(&self) -> Result<String> { Ok(String::new()) }
+pub struct Data {
+	global: Arc<dyn KvTree>,
+	todeviceid_events: Arc<dyn KvTree>,
+	userroomid_joined: Arc<dyn KvTree>,
+	userroomid_invitestate: Arc<dyn KvTree>,
+	userroomid_leftstate: Arc<dyn KvTree>,
+	userroomid_notificationcount: Arc<dyn KvTree>,
+	userroomid_highlightcount: Arc<dyn KvTree>,
+	pduid_pdu: Arc<dyn KvTree>,
+	keychangeid_userid: Arc<dyn KvTree>,
+	roomusertype_roomuserdataid: Arc<dyn KvTree>,
+	server_signingkeys: Arc<dyn KvTree>,
+	readreceiptid_readreceipt: Arc<dyn KvTree>,
+	userid_lastonetimekeyupdate: Arc<dyn KvTree>,
+	pub(super) db: Arc<KeyValueDatabase>,
 }
 
-#[async_trait]
-impl Data for KeyValueDatabase {
-	fn next_count(&self) -> Result<u64> {
+impl Data {
+	pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
+		Self {
+			global: db.global.clone(),
+			todeviceid_events: db.todeviceid_events.clone(),
+			userroomid_joined: db.userroomid_joined.clone(),
+			userroomid_invitestate: db.userroomid_invitestate.clone(),
+			userroomid_leftstate: db.userroomid_leftstate.clone(),
+			userroomid_notificationcount: db.userroomid_notificationcount.clone(),
+			userroomid_highlightcount: db.userroomid_highlightcount.clone(),
+			pduid_pdu: db.pduid_pdu.clone(),
+			keychangeid_userid: db.keychangeid_userid.clone(),
+			roomusertype_roomuserdataid: db.roomusertype_roomuserdataid.clone(),
+			server_signingkeys: db.server_signingkeys.clone(),
+			readreceiptid_readreceipt: db.readreceiptid_readreceipt.clone(),
+			userid_lastonetimekeyupdate: db.userid_lastonetimekeyupdate.clone(),
+			db: db.clone(),
+		}
+	}
+
+	pub fn next_count(&self) -> Result<u64> {
 		utils::u64_from_bytes(&self.global.increment(COUNTER)?)
 			.map_err(|_| Error::bad_database("Count has invalid bytes."))
 	}
 
-	fn current_count(&self) -> Result<u64> {
+	pub fn current_count(&self) -> Result<u64> {
 		self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
 			utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
 		})
 	}
 
-	fn last_check_for_updates_id(&self) -> Result<u64> {
+	pub fn last_check_for_updates_id(&self) -> Result<u64> {
 		self.global
 			.get(LAST_CHECK_FOR_UPDATES_COUNT)?
 			.map_or(Ok(0_u64), |bytes| {
@@ -68,16 +75,15 @@ impl Data for KeyValueDatabase {
 			})
 	}
 
-	fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
+	pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
 		self.global
 			.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
 
 		Ok(())
 	}
 
-	#[allow(unused_qualifications)] // async traits
 	#[tracing::instrument(skip(self))]
-	async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
+	pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
 		let userid_bytes = user_id.as_bytes().to_vec();
 		let mut userid_prefix = userid_bytes.clone();
 		userid_prefix.push(0xFF);
@@ -177,20 +183,20 @@ impl Data for KeyValueDatabase {
 		Ok(())
 	}
 
-	fn cleanup(&self) -> Result<()> { self.db.cleanup() }
+	pub fn cleanup(&self) -> Result<()> { self.db.db.cleanup() }
 
-	fn cork(&self) -> Cork { Cork::new(&self.db, false, false) }
+	pub fn cork(&self) -> Cork { Cork::new(&self.db.db, false, false) }
 
-	fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) }
+	pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db.db, true, false) }
 
-	fn memory_usage(&self) -> String {
-		let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
-		let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
-		let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
+	pub fn memory_usage(&self) -> String {
+		let auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().len();
+		let appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().len();
+		let lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().len();
 
-		let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity();
-		let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity();
-		let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity();
+		let max_auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().capacity();
+		let max_appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().capacity();
+		let max_lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().capacity();
 
 		format!(
 			"\
@@ -198,26 +204,26 @@ auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}
 appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache}
 lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n
 {}",
-			self.db.memory_usage().unwrap_or_default()
+			self.db.db.memory_usage().unwrap_or_default()
 		)
 	}
 
-	fn clear_caches(&self, amount: u32) {
+	pub fn clear_caches(&self, amount: u32) {
 		if amount > 1 {
-			let c = &mut *self.auth_chain_cache.lock().unwrap();
+			let c = &mut *self.db.auth_chain_cache.lock().unwrap();
 			*c = LruCache::new(c.capacity());
 		}
 		if amount > 2 {
-			let c = &mut *self.appservice_in_room_cache.write().unwrap();
+			let c = &mut *self.db.appservice_in_room_cache.write().unwrap();
 			*c = HashMap::new();
 		}
 		if amount > 3 {
-			let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
+			let c = &mut *self.db.lasttimelinecount_cache.lock().unwrap();
 			*c = HashMap::new();
 		}
 	}
 
-	fn load_keypair(&self) -> Result<Ed25519KeyPair> {
+	pub fn load_keypair(&self) -> Result<Ed25519KeyPair> {
 		let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
 			|| {
 				let keypair = utils::generate_keypair();
@@ -249,9 +255,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
 		})
 	}
 
-	fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
+	pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
 
-	fn add_signing_key(
+	pub fn add_signing_key(
 		&self, origin: &ServerName, new_keys: ServerSigningKeys,
 	) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
 		// Not atomic, but this is not critical
@@ -290,8 +296,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
 
 	/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
 	/// for the server.
-	fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
+	pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
 		let signingkeys = self
+			.db
 			.server_signingkeys
 			.get(origin.as_bytes())?
 			.and_then(|bytes| serde_json::from_slice(&bytes).ok())
@@ -308,20 +315,20 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach
 		Ok(signingkeys)
 	}
 
-	fn database_version(&self) -> Result<u64> {
+	pub fn database_version(&self) -> Result<u64> {
 		self.global.get(b"version")?.map_or(Ok(0), |version| {
 			utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
 		})
 	}
 
-	fn bump_database_version(&self, new_version: u64) -> Result<()> {
+	pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
 		self.global.insert(b"version", &new_version.to_be_bytes())?;
 		Ok(())
 	}
 
-	fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.backup() }
+	pub fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.db.backup() }
 
-	fn backup_list(&self) -> Result<String> { self.db.backup_list() }
+	pub fn backup_list(&self) -> Result<String> { self.db.db.backup_list() }
 
-	fn file_list(&self) -> Result<String> { self.db.file_list() }
+	pub fn file_list(&self) -> Result<String> { self.db.db.file_list() }
 }