add rustfmt.toml, format entire codebase
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
9fd521f041
commit
f419c64aca
144 changed files with 25573 additions and 31053 deletions
|
@ -1,8 +1,8 @@
|
|||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
use super::Config;
|
||||
use crate::Result;
|
||||
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub mod sqlite;
|
||||
|
||||
|
@ -13,53 +13,44 @@ pub(crate) mod rocksdb;
|
|||
pub(crate) mod watchers;
|
||||
|
||||
pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
|
||||
fn open(config: &Config) -> Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||
fn flush(&self) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
||||
}
|
||||
fn open(config: &Config) -> Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||
fn flush(&self) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()> { Ok(()) }
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
}
|
||||
|
||||
pub(crate) trait KvTree: Send + Sync {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>;
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()>;
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()>;
|
||||
fn remove(&self, key: &[u8]) -> Result<()>;
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
|
||||
fn clear(&self) -> Result<()> {
|
||||
for (key, _) in self.iter() {
|
||||
self.remove(&key)?;
|
||||
}
|
||||
fn clear(&self) -> Result<()> {
|
||||
for (key, _) in self.iter() {
|
||||
self.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,293 +1,265 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, RwLock},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{utils, Result};
|
||||
|
||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
|
||||
pub(crate) struct Engine {
|
||||
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
||||
cache: rocksdb::Cache,
|
||||
old_cfs: Vec<String>,
|
||||
config: Config,
|
||||
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
||||
cache: rocksdb::Cache,
|
||||
old_cfs: Vec<String>,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
struct RocksDbEngineTree<'a> {
|
||||
db: Arc<Engine>,
|
||||
name: &'a str,
|
||||
watchers: Watchers,
|
||||
write_lock: RwLock<()>,
|
||||
db: Arc<Engine>,
|
||||
name: &'a str,
|
||||
watchers: Watchers,
|
||||
write_lock: RwLock<()>,
|
||||
}
|
||||
|
||||
fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Options {
|
||||
// block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html#
|
||||
let mut block_based_options = rocksdb::BlockBasedOptions::default();
|
||||
// block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html#
|
||||
let mut block_based_options = rocksdb::BlockBasedOptions::default();
|
||||
|
||||
block_based_options.set_block_cache(rocksdb_cache);
|
||||
block_based_options.set_block_cache(rocksdb_cache);
|
||||
|
||||
// "Difference of spinning disk"
|
||||
// https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html
|
||||
block_based_options.set_block_size(64 * 1024);
|
||||
block_based_options.set_cache_index_and_filter_blocks(true);
|
||||
// "Difference of spinning disk"
|
||||
// https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html
|
||||
block_based_options.set_block_size(64 * 1024);
|
||||
block_based_options.set_cache_index_and_filter_blocks(true);
|
||||
|
||||
// database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html#
|
||||
let mut db_opts = rocksdb::Options::default();
|
||||
// database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html#
|
||||
let mut db_opts = rocksdb::Options::default();
|
||||
|
||||
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
|
||||
"debug" => Debug,
|
||||
"info" => Info,
|
||||
"error" => Error,
|
||||
"fatal" => Fatal,
|
||||
_ => Warn,
|
||||
};
|
||||
let rocksdb_log_level = match config.rocksdb_log_level.as_ref() {
|
||||
"debug" => Debug,
|
||||
"info" => Info,
|
||||
"error" => Error,
|
||||
"fatal" => Fatal,
|
||||
_ => Warn,
|
||||
};
|
||||
|
||||
let threads = if config.rocksdb_parallelism_threads == 0 {
|
||||
num_cpus::get_physical() // max cores if user specified 0
|
||||
} else {
|
||||
config.rocksdb_parallelism_threads
|
||||
};
|
||||
let threads = if config.rocksdb_parallelism_threads == 0 {
|
||||
num_cpus::get_physical() // max cores if user specified 0
|
||||
} else {
|
||||
config.rocksdb_parallelism_threads
|
||||
};
|
||||
|
||||
db_opts.set_log_level(rocksdb_log_level);
|
||||
db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
|
||||
db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
|
||||
db_opts.set_log_level(rocksdb_log_level);
|
||||
db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size);
|
||||
db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll);
|
||||
|
||||
if config.rocksdb_optimize_for_spinning_disks {
|
||||
db_opts.set_skip_stats_update_on_db_open(true);
|
||||
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
||||
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
||||
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for spinning hard drives. these are not really important
|
||||
} else {
|
||||
db_opts.set_skip_stats_update_on_db_open(false);
|
||||
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
||||
db_opts.set_use_direct_reads(true);
|
||||
db_opts.set_use_direct_io_for_flush_and_compaction(true);
|
||||
db_opts.set_keep_log_file_num(20);
|
||||
}
|
||||
if config.rocksdb_optimize_for_spinning_disks {
|
||||
db_opts.set_skip_stats_update_on_db_open(true);
|
||||
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
||||
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
||||
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for
|
||||
// spinning hard drives. these are not really
|
||||
// important
|
||||
} else {
|
||||
db_opts.set_skip_stats_update_on_db_open(false);
|
||||
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
||||
db_opts.set_use_direct_reads(true);
|
||||
db_opts.set_use_direct_io_for_flush_and_compaction(true);
|
||||
db_opts.set_keep_log_file_num(20);
|
||||
}
|
||||
|
||||
db_opts.set_block_based_table_factory(&block_based_options);
|
||||
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
||||
db_opts.create_if_missing(true);
|
||||
db_opts.increase_parallelism(
|
||||
threads
|
||||
.try_into()
|
||||
.expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||
);
|
||||
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
||||
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||
db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
|
||||
db_opts.optimize_level_style_compaction(10 * 1024 * 1024);
|
||||
db_opts.set_block_based_table_factory(&block_based_options);
|
||||
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
||||
db_opts.create_if_missing(true);
|
||||
db_opts.increase_parallelism(
|
||||
threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||
);
|
||||
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
||||
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||
db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
|
||||
db_opts.optimize_level_style_compaction(10 * 1024 * 1024);
|
||||
|
||||
// https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning
|
||||
db_opts.set_max_background_jobs(6);
|
||||
db_opts.set_bytes_per_sync(1_048_576);
|
||||
// https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning
|
||||
db_opts.set_max_background_jobs(6);
|
||||
db_opts.set_bytes_per_sync(1_048_576);
|
||||
|
||||
// https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
|
||||
//
|
||||
// Unclean shutdowns of a Matrix homeserver are likely to be fine when
|
||||
// recovered in this manner as it's likely any lost information will be
|
||||
// restored via federation.
|
||||
db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords);
|
||||
// https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords
|
||||
//
|
||||
// Unclean shutdowns of a Matrix homeserver are likely to be fine when
|
||||
// recovered in this manner as it's likely any lost information will be
|
||||
// restored via federation.
|
||||
db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords);
|
||||
|
||||
let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1);
|
||||
db_opts.set_prefix_extractor(prefix_extractor);
|
||||
let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1);
|
||||
db_opts.set_prefix_extractor(prefix_extractor);
|
||||
|
||||
db_opts
|
||||
db_opts
|
||||
}
|
||||
|
||||
impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
|
||||
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes);
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
|
||||
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes);
|
||||
|
||||
let db_opts = db_options(&rocksdb_cache, config);
|
||||
let db_opts = db_options(&rocksdb_cache, config);
|
||||
|
||||
debug!("Listing column families in database");
|
||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
)
|
||||
.unwrap_or_default();
|
||||
debug!("Listing column families in database");
|
||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(&db_opts, &config.database_path)
|
||||
.unwrap_or_default();
|
||||
|
||||
debug!("Opening column family descriptors in database");
|
||||
info!("RocksDB database compaction will take place now, a delay in startup is expected");
|
||||
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
cfs.iter().map(|name| {
|
||||
rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))
|
||||
}),
|
||||
)?;
|
||||
debug!("Opening column family descriptors in database");
|
||||
info!("RocksDB database compaction will take place now, a delay in startup is expected");
|
||||
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
cfs.iter().map(|name| rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))),
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(Engine {
|
||||
rocks: db,
|
||||
cache: rocksdb_cache,
|
||||
old_cfs: cfs,
|
||||
config: config.clone(),
|
||||
}))
|
||||
}
|
||||
Ok(Arc::new(Engine {
|
||||
rocks: db,
|
||||
cache: rocksdb_cache,
|
||||
old_cfs: cfs,
|
||||
config: config.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
|
||||
if !self.old_cfs.contains(&name.to_owned()) {
|
||||
// Create if it didn't exist
|
||||
debug!("Creating new column family in database: {}", name);
|
||||
let _ = self
|
||||
.rocks
|
||||
.create_cf(name, &db_options(&self.cache, &self.config));
|
||||
}
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
|
||||
if !self.old_cfs.contains(&name.to_owned()) {
|
||||
// Create if it didn't exist
|
||||
debug!("Creating new column family in database: {}", name);
|
||||
let _ = self.rocks.create_cf(name, &db_options(&self.cache, &self.config));
|
||||
}
|
||||
|
||||
Ok(Arc::new(RocksDbEngineTree {
|
||||
name,
|
||||
db: Arc::clone(self),
|
||||
watchers: Watchers::default(),
|
||||
write_lock: RwLock::new(()),
|
||||
}))
|
||||
}
|
||||
Ok(Arc::new(RocksDbEngineTree {
|
||||
name,
|
||||
db: Arc::clone(self),
|
||||
watchers: Watchers::default(),
|
||||
write_lock: RwLock::new(()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<()> {
|
||||
// TODO?
|
||||
Ok(())
|
||||
}
|
||||
fn flush(&self) -> Result<()> {
|
||||
// TODO?
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
let stats =
|
||||
rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||
Ok(format!(
|
||||
"Approximate memory usage of all the mem-tables: {:.3} MB\n\
|
||||
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\
|
||||
Approximate memory usage of all the table readers: {:.3} MB\n\
|
||||
Approximate memory usage by cache: {:.3} MB\n\
|
||||
Approximate memory usage by cache pinned: {:.3} MB\n\
|
||||
",
|
||||
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
||||
stats.cache_total as f64 / 1024.0 / 1024.0,
|
||||
self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0,
|
||||
))
|
||||
}
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||
Ok(format!(
|
||||
"Approximate memory usage of all the mem-tables: {:.3} MB\nApproximate memory usage of un-flushed \
|
||||
mem-tables: {:.3} MB\nApproximate memory usage of all the table readers: {:.3} MB\nApproximate memory \
|
||||
usage by cache: {:.3} MB\nApproximate memory usage by cache pinned: {:.3} MB\n",
|
||||
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
||||
stats.cache_total as f64 / 1024.0 / 1024.0,
|
||||
self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0,
|
||||
))
|
||||
}
|
||||
|
||||
// TODO: figure out if this is needed for rocksdb
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
// TODO: figure out if this is needed for rocksdb
|
||||
#[allow(dead_code)]
|
||||
fn clear_caches(&self) {}
|
||||
}
|
||||
|
||||
impl RocksDbEngineTree<'_> {
|
||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> {
|
||||
self.db.rocks.cf_handle(self.name).unwrap()
|
||||
}
|
||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> { self.db.rocks.cf_handle(self.name).unwrap() }
|
||||
}
|
||||
|
||||
impl KvTree for RocksDbEngineTree<'_> {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.db.rocks.get_cf(&self.cf(), key)?)
|
||||
}
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) }
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let lock = self.write_lock.read().unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
drop(lock);
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let lock = self.write_lock.read().unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
drop(lock);
|
||||
|
||||
self.watchers.wake(key);
|
||||
self.watchers.wake(key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
for (key, value) in iter {
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
}
|
||||
fn insert_batch(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
for (key, value) in iter {
|
||||
self.db.rocks.put_cf(&self.cf(), key, value)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
Ok(self.db.rocks.delete_cf(&self.cf(), key)?)
|
||||
}
|
||||
fn remove(&self, key: &[u8]) -> Result<()> { Ok(self.db.rocks.delete_cf(&self.cf(), key)?) }
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::Start)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::Start)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(
|
||||
from,
|
||||
if backwards {
|
||||
rocksdb::Direction::Reverse
|
||||
} else {
|
||||
rocksdb::Direction::Forward
|
||||
},
|
||||
),
|
||||
)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(
|
||||
from,
|
||||
if backwards {
|
||||
rocksdb::Direction::Reverse
|
||||
} else {
|
||||
rocksdb::Direction::Forward
|
||||
},
|
||||
),
|
||||
)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v))),
|
||||
)
|
||||
}
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
|
||||
let old = self.db.rocks.get_cf(&self.cf(), key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, &new)?;
|
||||
let old = self.db.rocks.get_cf(&self.cf(), key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, &new)?;
|
||||
|
||||
drop(lock);
|
||||
Ok(new)
|
||||
}
|
||||
drop(lock);
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let lock = self.write_lock.write().unwrap();
|
||||
|
||||
for key in iter {
|
||||
let old = self.db.rocks.get_cf(&self.cf(), &key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, new)?;
|
||||
}
|
||||
for key in iter {
|
||||
let old = self.db.rocks.get_cf(&self.cf(), &key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.db.rocks.put_cf(&self.cf(), key, new)?;
|
||||
}
|
||||
|
||||
drop(lock);
|
||||
drop(lock);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward),
|
||||
)
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward))
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,340 +1,305 @@
|
|||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
future::Future,
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
future::Future,
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use thread_local::ThreadLocal;
|
||||
use tracing::debug;
|
||||
|
||||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
|
||||
thread_local! {
|
||||
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
struct PreparedStatementIterator<'a> {
|
||||
pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>,
|
||||
pub _statement_ref: NonAliasingBox<rusqlite::Statement<'a>>,
|
||||
pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>,
|
||||
pub _statement_ref: NonAliasingBox<rusqlite::Statement<'a>>,
|
||||
}
|
||||
|
||||
impl Iterator for PreparedStatementIterator<'_> {
|
||||
type Item = TupleOfBytes;
|
||||
type Item = TupleOfBytes;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.iterator.next()
|
||||
}
|
||||
fn next(&mut self) -> Option<Self::Item> { self.iterator.next() }
|
||||
}
|
||||
|
||||
struct NonAliasingBox<T>(*mut T);
|
||||
impl<T> Drop for NonAliasingBox<T> {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
let _ = Box::from_raw(self.0);
|
||||
};
|
||||
}
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
let _ = Box::from_raw(self.0);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Engine {
|
||||
writer: Mutex<Connection>,
|
||||
read_conn_tls: ThreadLocal<Connection>,
|
||||
read_iterator_conn_tls: ThreadLocal<Connection>,
|
||||
writer: Mutex<Connection>,
|
||||
read_conn_tls: ThreadLocal<Connection>,
|
||||
read_iterator_conn_tls: ThreadLocal<Connection>,
|
||||
|
||||
path: PathBuf,
|
||||
cache_size_per_thread: u32,
|
||||
path: PathBuf,
|
||||
cache_size_per_thread: u32,
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> {
|
||||
let conn = Connection::open(path)?;
|
||||
fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> {
|
||||
let conn = Connection::open(path)?;
|
||||
|
||||
conn.pragma_update(Some(Main), "page_size", 2048)?;
|
||||
conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
|
||||
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
|
||||
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?;
|
||||
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
|
||||
conn.pragma_update(Some(Main), "page_size", 2048)?;
|
||||
conn.pragma_update(Some(Main), "journal_mode", "WAL")?;
|
||||
conn.pragma_update(Some(Main), "synchronous", "NORMAL")?;
|
||||
conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?;
|
||||
conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
fn write_lock(&self) -> MutexGuard<'_, Connection> {
|
||||
self.writer.lock()
|
||||
}
|
||||
fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() }
|
||||
|
||||
fn read_lock(&self) -> &Connection {
|
||||
self.read_conn_tls
|
||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
fn read_lock(&self) -> &Connection {
|
||||
self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
|
||||
fn read_lock_iterator(&self) -> &Connection {
|
||||
self.read_iterator_conn_tls
|
||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
fn read_lock_iterator(&self) -> &Connection {
|
||||
self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
|
||||
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
||||
self.write_lock()
|
||||
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
||||
self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let path = Path::new(&config.database_path).join("conduit.db");
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let path = Path::new(&config.database_path).join("conduit.db");
|
||||
|
||||
// calculates cache-size per permanent connection
|
||||
// 1. convert MB to KiB
|
||||
// 2. divide by permanent connections + permanent iter connections + write connection
|
||||
// 3. round down to nearest integer
|
||||
let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0)
|
||||
/ ((num_cpus::get().max(1) * 2) + 1) as f64)
|
||||
as u32;
|
||||
// calculates cache-size per permanent connection
|
||||
// 1. convert MB to KiB
|
||||
// 2. divide by permanent connections + permanent iter connections + write
|
||||
// connection
|
||||
// 3. round down to nearest integer
|
||||
let cache_size_per_thread: u32 =
|
||||
((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32;
|
||||
|
||||
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
||||
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
||||
|
||||
let arc = Arc::new(Engine {
|
||||
writer,
|
||||
read_conn_tls: ThreadLocal::new(),
|
||||
read_iterator_conn_tls: ThreadLocal::new(),
|
||||
path,
|
||||
cache_size_per_thread,
|
||||
});
|
||||
let arc = Arc::new(Engine {
|
||||
writer,
|
||||
read_conn_tls: ThreadLocal::new(),
|
||||
read_iterator_conn_tls: ThreadLocal::new(),
|
||||
path,
|
||||
cache_size_per_thread,
|
||||
});
|
||||
|
||||
Ok(arc)
|
||||
}
|
||||
Ok(arc)
|
||||
}
|
||||
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?;
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||
self.write_lock().execute(
|
||||
&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"),
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(SqliteTable {
|
||||
engine: Arc::clone(self),
|
||||
name: name.to_owned(),
|
||||
watchers: Watchers::default(),
|
||||
}))
|
||||
}
|
||||
Ok(Arc::new(SqliteTable {
|
||||
engine: Arc::clone(self),
|
||||
name: name.to_owned(),
|
||||
watchers: Watchers::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<()> {
|
||||
// we enabled PRAGMA synchronous=normal, so this should not be necessary
|
||||
Ok(())
|
||||
}
|
||||
fn flush(&self) -> Result<()> {
|
||||
// we enabled PRAGMA synchronous=normal, so this should not be necessary
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self.flush_wal()
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { self.flush_wal() }
|
||||
}
|
||||
|
||||
pub struct SqliteTable {
|
||||
engine: Arc<Engine>,
|
||||
name: String,
|
||||
watchers: Watchers,
|
||||
engine: Arc<Engine>,
|
||||
name: String,
|
||||
watchers: Watchers,
|
||||
}
|
||||
|
||||
type TupleOfBytes = (Vec<u8>, Vec<u8>);
|
||||
|
||||
impl SqliteTable {
|
||||
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(guard
|
||||
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())?
|
||||
.query_row([key], |row| row.get(0))
|
||||
.optional()?)
|
||||
}
|
||||
fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(guard
|
||||
.prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())?
|
||||
.query_row([key], |row| row.get(0))
|
||||
.optional()?)
|
||||
}
|
||||
|
||||
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
guard.execute(
|
||||
format!(
|
||||
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)",
|
||||
self.name
|
||||
)
|
||||
.as_str(),
|
||||
[key, value],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
guard.execute(
|
||||
format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(),
|
||||
[key, value],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn iter_with_guard<'a>(
|
||||
&'a self,
|
||||
guard: &'a Connection,
|
||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
||||
//let name = self.name.clone();
|
||||
//let name = self.name.clone();
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
let iterator = Box::new(
|
||||
statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(move |r| r.unwrap()),
|
||||
);
|
||||
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl KvTree for SqliteTable {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
self.get_with_guard(self.engine.read_lock(), key)
|
||||
}
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { self.get_with_guard(self.engine.read_lock(), key) }
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
self.insert_with_guard(&guard, key, value)?;
|
||||
drop(guard);
|
||||
self.watchers.wake(key);
|
||||
Ok(())
|
||||
}
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
self.insert_with_guard(&guard, key, value)?;
|
||||
drop(guard);
|
||||
self.watchers.wake(key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute("BEGIN", [])?;
|
||||
for (key, value) in iter {
|
||||
self.insert_with_guard(&guard, &key, &value)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
guard.execute("BEGIN", [])?;
|
||||
for (key, value) in iter {
|
||||
self.insert_with_guard(&guard, &key, &value)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
|
||||
drop(guard);
|
||||
drop(guard);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute("BEGIN", [])?;
|
||||
for key in iter {
|
||||
let old = self.get_with_guard(&guard, &key)?;
|
||||
let new = crate::utils::increment(old.as_deref())
|
||||
.expect("utils::increment always returns Some");
|
||||
self.insert_with_guard(&guard, &key, &new)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
guard.execute("BEGIN", [])?;
|
||||
for key in iter {
|
||||
let old = self.get_with_guard(&guard, &key)?;
|
||||
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
self.insert_with_guard(&guard, &key, &new)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
|
||||
drop(guard);
|
||||
drop(guard);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute(
|
||||
format!("DELETE FROM {} WHERE key = ?", self.name).as_str(),
|
||||
[key],
|
||||
)?;
|
||||
guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
|
||||
self.iter_with_guard(guard)
|
||||
}
|
||||
self.iter_with_guard(guard)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
let from = from.to_vec(); // TODO change interface?
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
let from = from.to_vec(); // TODO change interface?
|
||||
|
||||
//let name = self.name.clone();
|
||||
//let name = self.name.clone();
|
||||
|
||||
if backwards {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
if backwards {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
} else {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
} else {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
);
|
||||
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
}
|
||||
Box::new(PreparedStatementIterator {
|
||||
iterator,
|
||||
_statement_ref: statement_ref,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let guard = self.engine.write_lock();
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
let old = self.get_with_guard(&guard, key)?;
|
||||
let old = self.get_with_guard(&guard, key)?;
|
||||
|
||||
let new =
|
||||
crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
|
||||
self.insert_with_guard(&guard, key, &new)?;
|
||||
self.insert_with_guard(&guard, key, &new)?;
|
||||
|
||||
Ok(new)
|
||||
}
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
Box::new(
|
||||
self.iter_from(&prefix, false)
|
||||
.take_while(move |(key, _)| key.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix)))
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
self.watchers.watch(prefix)
|
||||
}
|
||||
|
||||
fn clear(&self) -> Result<()> {
|
||||
debug!("clear: running");
|
||||
self.engine
|
||||
.write_lock()
|
||||
.execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||
debug!("clear: ran");
|
||||
Ok(())
|
||||
}
|
||||
fn clear(&self) -> Result<()> {
|
||||
debug!("clear: running");
|
||||
self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||
debug!("clear: ran");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,56 +1,55 @@
|
|||
use std::{
|
||||
collections::{hash_map, HashMap},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::RwLock,
|
||||
collections::{hash_map, HashMap},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::RwLock,
|
||||
};
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(super) struct Watchers {
|
||||
watchers: Watcher,
|
||||
watchers: Watcher,
|
||||
}
|
||||
|
||||
impl Watchers {
|
||||
pub(super) fn watch<'a>(
|
||||
&'a self,
|
||||
prefix: &[u8],
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
||||
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
v.insert((tx, rx.clone()));
|
||||
rx
|
||||
}
|
||||
};
|
||||
pub(super) fn watch<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
||||
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
v.insert((tx, rx.clone()));
|
||||
rx
|
||||
},
|
||||
};
|
||||
|
||||
Box::pin(async move {
|
||||
// Tx is never destroyed
|
||||
rx.changed().await.unwrap();
|
||||
})
|
||||
}
|
||||
pub(super) fn wake(&self, key: &[u8]) {
|
||||
let watchers = self.watchers.read().unwrap();
|
||||
let mut triggered = Vec::new();
|
||||
Box::pin(async move {
|
||||
// Tx is never destroyed
|
||||
rx.changed().await.unwrap();
|
||||
})
|
||||
}
|
||||
|
||||
for length in 0..=key.len() {
|
||||
if watchers.contains_key(&key[..length]) {
|
||||
triggered.push(&key[..length]);
|
||||
}
|
||||
}
|
||||
pub(super) fn wake(&self, key: &[u8]) {
|
||||
let watchers = self.watchers.read().unwrap();
|
||||
let mut triggered = Vec::new();
|
||||
|
||||
drop(watchers);
|
||||
for length in 0..=key.len() {
|
||||
if watchers.contains_key(&key[..length]) {
|
||||
triggered.push(&key[..length]);
|
||||
}
|
||||
}
|
||||
|
||||
if !triggered.is_empty() {
|
||||
let mut watchers = self.watchers.write().unwrap();
|
||||
for prefix in triggered {
|
||||
if let Some(tx) = watchers.remove(prefix) {
|
||||
let _ = tx.0.send(());
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
drop(watchers);
|
||||
|
||||
if !triggered.is_empty() {
|
||||
let mut watchers = self.watchers.write().unwrap();
|
||||
for prefix in triggered {
|
||||
if let Some(tx) = watchers.remove(prefix) {
|
||||
let _ = tx.0.send(());
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,148 +1,120 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
api::client::error::ErrorKind,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::account_data::Data for KeyValueDatabase {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
fn update(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xff);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xFF);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let mut key = prefix;
|
||||
key.extend_from_slice(event_type.to_string().as_bytes());
|
||||
let mut key = prefix;
|
||||
key.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
if data.get("type").is_none() || data.get("content").is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Account data doesn't have all required fields.",
|
||||
));
|
||||
}
|
||||
if data.get("type").is_none() || data.get("content").is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Account data doesn't have all required fields.",
|
||||
));
|
||||
}
|
||||
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&roomuserdataid,
|
||||
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&roomuserdataid,
|
||||
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.insert(&key, &roomuserdataid)?;
|
||||
self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
|
||||
|
||||
// Remove old entry
|
||||
if let Some(prev) = prev {
|
||||
self.roomuserdataid_accountdata.remove(&prev)?;
|
||||
}
|
||||
// Remove old entry
|
||||
if let Some(prev) = prev {
|
||||
self.roomuserdataid_accountdata.remove(&prev)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
let mut key = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(kind.to_string().as_bytes());
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
fn get(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
let mut key = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(kind.to_string().as_bytes());
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.get(&key)?
|
||||
.and_then(|roomuserdataid| {
|
||||
self.roomuserdataid_accountdata
|
||||
.get(&roomuserdataid)
|
||||
.transpose()
|
||||
})
|
||||
.transpose()?
|
||||
.map(|data| {
|
||||
serde_json::from_slice(&data)
|
||||
.map_err(|_| Error::bad_database("could not deserialize"))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.roomusertype_roomuserdataid
|
||||
.get(&key)?
|
||||
.and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose())
|
||||
.transpose()?
|
||||
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
let mut userdata = HashMap::new();
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
fn changes_since(
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
let mut userdata = HashMap::new();
|
||||
|
||||
let mut prefix = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
// Skip the data that's exactly at since, because we sent that last time
|
||||
let mut first_possible = prefix.clone();
|
||||
first_possible.extend_from_slice(&(since + 1).to_be_bytes());
|
||||
// Skip the data that's exactly at since, because we sent that last time
|
||||
let mut first_possible = prefix.clone();
|
||||
first_possible.extend_from_slice(&(since + 1).to_be_bytes());
|
||||
|
||||
for r in self
|
||||
.roomuserdataid_accountdata
|
||||
.iter_from(&first_possible, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
RoomAccountDataEventType::from(
|
||||
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|
||||
|| Error::bad_database("RoomUserData ID in db is invalid."),
|
||||
)?)
|
||||
.map_err(|e| {
|
||||
warn!("RoomUserData ID in database is invalid: {}", e);
|
||||
Error::bad_database("RoomUserData ID in db is invalid.")
|
||||
})?,
|
||||
),
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| {
|
||||
Error::bad_database("Database contains invalid account data.")
|
||||
})?,
|
||||
))
|
||||
})
|
||||
{
|
||||
let (kind, data) = r?;
|
||||
userdata.insert(kind, data);
|
||||
}
|
||||
for r in self
|
||||
.roomuserdataid_accountdata
|
||||
.iter_from(&first_possible, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
RoomAccountDataEventType::from(
|
||||
utils::string_from_bytes(
|
||||
k.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("RoomUserData ID in database is invalid: {}", e);
|
||||
Error::bad_database("RoomUserData ID in db is invalid.")
|
||||
})?,
|
||||
),
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
|
||||
))
|
||||
}) {
|
||||
let (kind, data) = r?;
|
||||
userdata.insert(kind, data);
|
||||
}
|
||||
|
||||
Ok(userdata)
|
||||
}
|
||||
Ok(userdata)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,78 +3,58 @@ use ruma::api::appservice::Registration;
|
|||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::appservice::Data for KeyValueDatabase {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(
|
||||
id.as_bytes(),
|
||||
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
|
||||
)?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(id.to_owned(), yaml.clone());
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
|
||||
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations
|
||||
.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(id)
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid registration bytes in id_appserviceregistrations.",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
}
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations.read().unwrap().get(id).map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(
|
||||
|(id, _)| {
|
||||
utils::string_from_bytes(&id).map_err(|_| {
|
||||
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
|
||||
})
|
||||
},
|
||||
)))
|
||||
}
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||
utils::string_from_bytes(&id)
|
||||
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
|
||||
})))
|
||||
}
|
||||
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
self.iter_ids()?
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?
|
||||
.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
self.iter_ids()?
|
||||
.filter_map(std::result::Result::ok)
|
||||
.map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,9 @@ use async_trait::async_trait;
|
|||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use lru_cache::LruCache;
|
||||
use ruma::{
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
@ -16,139 +16,118 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
|
|||
|
||||
#[async_trait]
|
||||
impl service::globals::Data for KeyValueDatabase {
|
||||
fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
}
|
||||
fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
}
|
||||
|
||||
fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
self.global
|
||||
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
|
||||
.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("last check for updates count has invalid bytes.")
|
||||
})
|
||||
})
|
||||
}
|
||||
fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||
self.global
|
||||
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||
self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let userid_bytes = user_id.as_bytes().to_vec();
|
||||
let mut userid_prefix = userid_bytes.clone();
|
||||
userid_prefix.push(0xff);
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let userid_bytes = user_id.as_bytes().to_vec();
|
||||
let mut userid_prefix = userid_bytes.clone();
|
||||
userid_prefix.push(0xFF);
|
||||
|
||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||
userdeviceid_prefix.push(0xff);
|
||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||
userdeviceid_prefix.push(0xFF);
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
// Return when *any* user changed his key
|
||||
// TODO: only send for user they share a room with
|
||||
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
|
||||
// Return when *any* user changed his key
|
||||
// TODO: only send for user they share a room with
|
||||
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
|
||||
|
||||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||
futures.push(
|
||||
self.userroomid_notificationcount
|
||||
.watch_prefix(&userid_prefix),
|
||||
);
|
||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||
|
||||
// Events for rooms we are in
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
let short_roomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(&room_id)
|
||||
.ok()
|
||||
.flatten()
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
// Events for rooms we are in
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
|
||||
let short_roomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(&room_id)
|
||||
.ok()
|
||||
.flatten()
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||
let mut roomid_prefix = roomid_bytes.clone();
|
||||
roomid_prefix.push(0xff);
|
||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||
let mut roomid_prefix = roomid_bytes.clone();
|
||||
roomid_prefix.push(0xFF);
|
||||
|
||||
// PDUs
|
||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||
// PDUs
|
||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||
|
||||
// EDUs
|
||||
futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes));
|
||||
// EDUs
|
||||
futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes));
|
||||
|
||||
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
|
||||
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
|
||||
|
||||
// Key changes
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
|
||||
// Key changes
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
|
||||
|
||||
// Room account data
|
||||
let mut roomuser_prefix = roomid_prefix.clone();
|
||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||
// Room account data
|
||||
let mut roomuser_prefix = roomid_prefix.clone();
|
||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&roomuser_prefix),
|
||||
);
|
||||
}
|
||||
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix));
|
||||
}
|
||||
|
||||
let mut globaluserdata_prefix = vec![0xff];
|
||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||
let mut globaluserdata_prefix = vec![0xFF];
|
||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&globaluserdata_prefix),
|
||||
);
|
||||
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix));
|
||||
|
||||
// More key changes (used when user is not joined to any rooms)
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||
// More key changes (used when user is not joined to any rooms)
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||
|
||||
// One time keys
|
||||
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
|
||||
// One time keys
|
||||
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
|
||||
|
||||
futures.push(Box::pin(services().globals.rotate.watch()));
|
||||
futures.push(Box::pin(services().globals.rotate.watch()));
|
||||
|
||||
// Wait until one of them finds something
|
||||
futures.next().await;
|
||||
// Wait until one of them finds something
|
||||
futures.next().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self.db.cleanup()
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
|
||||
|
||||
fn memory_usage(&self) -> String {
|
||||
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
||||
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
|
||||
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
|
||||
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
|
||||
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len();
|
||||
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
|
||||
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
|
||||
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
|
||||
fn memory_usage(&self) -> String {
|
||||
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
||||
let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len();
|
||||
let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len();
|
||||
let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len();
|
||||
let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len();
|
||||
let our_real_users_cache = self.our_real_users_cache.read().unwrap().len();
|
||||
let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len();
|
||||
let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len();
|
||||
|
||||
let mut response = format!(
|
||||
"\
|
||||
let mut response = format!(
|
||||
"\
|
||||
pdu_cache: {pdu_cache}
|
||||
shorteventid_cache: {shorteventid_cache}
|
||||
auth_chain_cache: {auth_chain_cache}
|
||||
|
@ -157,155 +136,137 @@ statekeyshort_cache: {statekeyshort_cache}
|
|||
our_real_users_cache: {our_real_users_cache}
|
||||
appservice_in_room_cache: {appservice_in_room_cache}
|
||||
lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
||||
);
|
||||
if let Ok(db_stats) = self.db.memory_usage() {
|
||||
response += &db_stats;
|
||||
}
|
||||
);
|
||||
if let Ok(db_stats) = self.db.memory_usage() {
|
||||
response += &db_stats;
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
fn clear_caches(&self, amount: u32) {
|
||||
if amount > 0 {
|
||||
let c = &mut *self.pdu_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 1 {
|
||||
let c = &mut *self.shorteventid_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 2 {
|
||||
let c = &mut *self.auth_chain_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 3 {
|
||||
let c = &mut *self.eventidshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 4 {
|
||||
let c = &mut *self.statekeyshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 5 {
|
||||
let c = &mut *self.our_real_users_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 6 {
|
||||
let c = &mut *self.appservice_in_room_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 7 {
|
||||
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
}
|
||||
fn clear_caches(&self, amount: u32) {
|
||||
if amount > 0 {
|
||||
let c = &mut *self.pdu_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 1 {
|
||||
let c = &mut *self.shorteventid_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 2 {
|
||||
let c = &mut *self.auth_chain_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 3 {
|
||||
let c = &mut *self.eventidshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 4 {
|
||||
let c = &mut *self.statekeyshort_cache.lock().unwrap();
|
||||
*c = LruCache::new(c.capacity());
|
||||
}
|
||||
if amount > 5 {
|
||||
let c = &mut *self.our_real_users_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 6 {
|
||||
let c = &mut *self.appservice_in_room_cache.write().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
if amount > 7 {
|
||||
let c = &mut *self.lasttimelinecount_cache.lock().unwrap();
|
||||
*c = HashMap::new();
|
||||
}
|
||||
}
|
||||
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
|
||||
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
self.global.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
|
||||
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
self.global.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
Ok,
|
||||
)?;
|
||||
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
|
||||
|
||||
utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts
|
||||
.next()
|
||||
.expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
})
|
||||
}
|
||||
fn remove_keypair(&self) -> Result<()> {
|
||||
self.global.remove(b"keypair")
|
||||
}
|
||||
utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts.next().expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
fn add_signing_key(
|
||||
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
||||
keys.verify_keys.extend(verify_keys);
|
||||
keys.old_verify_keys.extend(old_verify_keys);
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
keys.verify_keys.extend(verify_keys);
|
||||
keys.old_verify_keys.extend(old_verify_keys);
|
||||
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
Ok(signingkeys)
|
||||
}
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||
/// for the server.
|
||||
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
|
||||
fn database_version(&self) -> Result<u64> {
|
||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
Ok(signingkeys)
|
||||
}
|
||||
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.global.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn database_version(&self) -> Result<u64> {
|
||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.global.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,364 +1,292 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::{
|
||||
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
error::ErrorKind,
|
||||
},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
api::client::{
|
||||
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
error::ErrorKind,
|
||||
},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::key_backups::Data for KeyValueDatabase {
|
||||
fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
let version = services().globals.next_count()?.to_string();
|
||||
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
let version = services().globals.next_count()?.to_string();
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version)
|
||||
}
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.remove(&key)?;
|
||||
self.backupid_etag.remove(&key)?;
|
||||
self.backupid_algorithm.remove(&key)?;
|
||||
self.backupid_etag.remove(&key)?;
|
||||
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
}
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||
}
|
||||
|
||||
self.backupid_algorithm
|
||||
.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, _)| {
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, _)| {
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||
|
||||
Ok((
|
||||
version,
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
|
||||
})?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
Ok((
|
||||
version,
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||
})
|
||||
}
|
||||
self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn add_key(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
}
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||
}
|
||||
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.insert(&key, key_data.json().get().as_bytes())?;
|
||||
self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self
|
||||
.backupid_etag
|
||||
.get(&key)?
|
||||
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||
.to_string())
|
||||
}
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||
|
||||
for result in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
let session_id = utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
|
||||
})?;
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
})
|
||||
{
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
.or_insert_with(|| RoomKeyBackup {
|
||||
sessions: BTreeMap::new(),
|
||||
})
|
||||
.sessions
|
||||
.insert(session_id, key_data);
|
||||
}
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
}) {
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
.or_insert_with(|| RoomKeyBackup {
|
||||
sessions: BTreeMap::new(),
|
||||
})
|
||||
.sessions
|
||||
.insert(session_id, key_data);
|
||||
}
|
||||
|
||||
Ok(rooms)
|
||||
}
|
||||
Ok(rooms)
|
||||
}
|
||||
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn get_room(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
let session_id = utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||
|
||||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect())
|
||||
}
|
||||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
fn get_session(
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.get(&key)?
|
||||
.map(|value| {
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.backupkeyid_backup
|
||||
.get(&key)?
|
||||
.map(|value| {
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,245 +2,182 @@ use ruma::api::client::error::ErrorKind;
|
|||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, media::UrlPreviewData},
|
||||
utils, Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, media::UrlPreviewData},
|
||||
utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::media::Data for KeyValueDatabase {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
fn create_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
debug!("MXC db prefix: {:?}", prefix);
|
||||
debug!("MXC db prefix: {:?}", prefix);
|
||||
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
debug!("Deleting key: {:?}", key);
|
||||
self.mediaid_file.remove(&key)?;
|
||||
}
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
debug!("Deleting key: {:?}", key);
|
||||
self.mediaid_file.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Searches for all files with the given MXC
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
/// Searches for all files with the given MXC
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>> {
|
||||
debug!("MXC URI: {:?}", mxc);
|
||||
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
keys.push(key);
|
||||
}
|
||||
for (key, _) in self.mediaid_file.scan_prefix(prefix) {
|
||||
keys.push(key);
|
||||
}
|
||||
|
||||
if keys.is_empty() {
|
||||
return Err(Error::bad_database(
|
||||
"Failed to find any keys in database with the provided MXC.",
|
||||
));
|
||||
}
|
||||
if keys.is_empty() {
|
||||
return Err(Error::bad_database(
|
||||
"Failed to find any keys in database with the provided MXC.",
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Got the following keys: {:?}", keys);
|
||||
debug!("Got the following keys: {:?}", keys);
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(&width.to_be_bytes());
|
||||
prefix.extend_from_slice(&height.to_be_bytes());
|
||||
prefix.push(0xff);
|
||||
fn search_file_metadata(
|
||||
&self, mxc: String, width: u32, height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(&width.to_be_bytes());
|
||||
prefix.extend_from_slice(&height.to_be_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
let (key, _) = self
|
||||
.mediaid_file
|
||||
.scan_prefix(prefix)
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||
let (key, _) = self
|
||||
.mediaid_file
|
||||
.scan_prefix(prefix)
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
let content_disposition_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
};
|
||||
Ok((content_disposition, content_type, key))
|
||||
}
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes)
|
||||
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
|
||||
)
|
||||
};
|
||||
Ok((content_disposition, content_type, key))
|
||||
}
|
||||
|
||||
/// Gets all the media keys in our database (this includes all the metadata associated with it such as width, height, content-type, etc)
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
/// Gets all the media keys in our database (this includes all the metadata
|
||||
/// associated with it such as width, height, content-type, etc)
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
|
||||
for (key, _) in self.mediaid_file.iter() {
|
||||
keys.push(key);
|
||||
}
|
||||
for (key, _) in self.mediaid_file.iter() {
|
||||
keys.push(key);
|
||||
}
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()> {
|
||||
self.url_previews.remove(url.as_bytes())
|
||||
}
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
|
||||
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &UrlPreviewData,
|
||||
timestamp: std::time::Duration,
|
||||
) -> Result<()> {
|
||||
let mut value = Vec::<u8>::new();
|
||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.title
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.description
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.image
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
||||
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
|
||||
let mut value = Vec::<u8>::new();
|
||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.title.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.description.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.image.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
||||
|
||||
self.url_previews.insert(url.as_bytes(), &value)
|
||||
}
|
||||
self.url_previews.insert(url.as_bytes(), &value)
|
||||
}
|
||||
|
||||
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
||||
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
||||
|
||||
let mut values = values.split(|&b| b == 0xff);
|
||||
let mut values = values.split(|&b| b == 0xFF);
|
||||
|
||||
let _ts = match values
|
||||
.next()
|
||||
.map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let title = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let description = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image_size = match values
|
||||
.next()
|
||||
.map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_width = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_height = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let _ts = match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
|
||||
Some(UrlPreviewData {
|
||||
title,
|
||||
description,
|
||||
image,
|
||||
image_size,
|
||||
image_width,
|
||||
image_height,
|
||||
})
|
||||
}
|
||||
Some(UrlPreviewData {
|
||||
title,
|
||||
description,
|
||||
image,
|
||||
image_size,
|
||||
image_width,
|
||||
image_height,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,79 +1,63 @@
|
|||
use ruma::{
|
||||
api::client::push::{set_pusher, Pusher},
|
||||
UserId,
|
||||
api::client::push::{set_pusher, Pusher},
|
||||
UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::pusher::Data for KeyValueDatabase {
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
|
||||
match &pusher {
|
||||
set_pusher::v3::PusherAction::Post(data) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> {
|
||||
match &pusher {
|
||||
set_pusher::v3::PusherAction::Post(data) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher
|
||||
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
|
||||
Ok(())
|
||||
},
|
||||
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher.remove(&key).map(|_| ()).map_err(Into::into)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
let mut senderkey = sender.as_bytes().to_vec();
|
||||
senderkey.push(0xff);
|
||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
let mut senderkey = sender.as_bytes().to_vec();
|
||||
senderkey.push(0xFF);
|
||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||
|
||||
self.senderkey_pusher
|
||||
.get(&senderkey)?
|
||||
.map(|push| {
|
||||
serde_json::from_slice(&push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.senderkey_pusher
|
||||
.get(&senderkey)?
|
||||
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| {
|
||||
serde_json::from_slice(&push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_pushkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||
let mut parts = k.splitn(2, |&b| b == 0xff);
|
||||
let _senderkey = parts.next();
|
||||
let push_key = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key_string = utils::string_from_bytes(push_key)
|
||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||
let mut parts = k.splitn(2, |&b| b == 0xFF);
|
||||
let _senderkey = parts.next();
|
||||
let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key_string = utils::string_from_bytes(push_key)
|
||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||
|
||||
Ok(push_key_string)
|
||||
}))
|
||||
}
|
||||
Ok(push_key_string)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,82 +3,68 @@ use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAli
|
|||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::alias::Data for KeyValueDatabase {
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.alias_roomid
|
||||
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
let mut aliasid = room_id.as_bytes().to_vec();
|
||||
aliasid.push(0xff);
|
||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
let mut aliasid = room_id.as_bytes().to_vec();
|
||||
aliasid.push(0xFF);
|
||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||
let mut prefix = room_id;
|
||||
prefix.push(0xff);
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||
let mut prefix = room_id;
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||
self.aliasid_alias.remove(&key)?;
|
||||
}
|
||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Alias does not exist.",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||
self.aliasid_alias.remove(&key)?;
|
||||
}
|
||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
self.alias_roomid
|
||||
.get(alias.alias().as_bytes())?
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||
self.alias_roomid
|
||||
.get(alias.alias().as_bytes())?
|
||||
.map(|bytes| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self, room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
|
||||
}))
|
||||
}
|
||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
|
||||
}))
|
||||
}
|
||||
|
||||
fn all_local_aliases<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
Box::new(
|
||||
self.alias_roomid
|
||||
.iter()
|
||||
.map(|(room_alias_bytes, room_id_bytes)| {
|
||||
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid alias bytes in aliasid_alias.")
|
||||
})?;
|
||||
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| {
|
||||
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
|
||||
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room_id bytes in aliasid_alias.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
||||
|
||||
Ok((room_id, room_alias_localpart))
|
||||
}),
|
||||
)
|
||||
}
|
||||
Ok((room_id, room_alias_localpart))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,59 +3,47 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
|
|||
use crate::{database::KeyValueDatabase, service, utils, Result};
|
||||
|
||||
impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
||||
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
// Check RAM cache
|
||||
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
|
||||
return Ok(Some(Arc::clone(result)));
|
||||
}
|
||||
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
// Check RAM cache
|
||||
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
|
||||
return Ok(Some(Arc::clone(result)));
|
||||
}
|
||||
|
||||
// We only save auth chains for single events in the db
|
||||
if key.len() == 1 {
|
||||
// Check DB cache
|
||||
let chain = self
|
||||
.shorteventid_authchain
|
||||
.get(&key[0].to_be_bytes())?
|
||||
.map(|chain| {
|
||||
chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
||||
.collect()
|
||||
});
|
||||
// We only save auth chains for single events in the db
|
||||
if key.len() == 1 {
|
||||
// Check DB cache
|
||||
let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| {
|
||||
chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
||||
.collect()
|
||||
});
|
||||
|
||||
if let Some(chain) = chain {
|
||||
let chain = Arc::new(chain);
|
||||
if let Some(chain) = chain {
|
||||
let chain = Arc::new(chain);
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(vec![key[0]], Arc::clone(&chain));
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain));
|
||||
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
}
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
// Only persist single events in db
|
||||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain
|
||||
.iter()
|
||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
||||
.collect::<Vec<u8>>(),
|
||||
)?;
|
||||
}
|
||||
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
// Only persist single events in db
|
||||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::<Vec<u8>>(),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(key, auth_chain);
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,26 +3,21 @@ use ruma::{OwnedRoomId, RoomId};
|
|||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::directory::Data for KeyValueDatabase {
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.insert(room_id.as_bytes(), &[])
|
||||
}
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
|
||||
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.remove(room_id.as_bytes())
|
||||
}
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
|
||||
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||
}))
|
||||
}
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,178 +1,155 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use ruma::{
|
||||
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
|
||||
};
|
||||
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::edus::presence::Presence},
|
||||
services,
|
||||
utils::{self, user_id_from_bytes},
|
||||
Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::edus::presence::Presence},
|
||||
services,
|
||||
utils::{self, user_id_from_bytes},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
|
||||
let key = presence_key(room_id, user_id);
|
||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
|
||||
let key = presence_key(room_id, user_id);
|
||||
|
||||
self.roomuserid_presence
|
||||
.get(&key)?
|
||||
.map(|presence_bytes| -> Result<PresenceEvent> {
|
||||
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.roomuserid_presence
|
||||
.get(&key)?
|
||||
.map(|presence_bytes| -> Result<PresenceEvent> {
|
||||
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let mut state_changed = false;
|
||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let mut state_changed = false;
|
||||
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
|
||||
if let Some(presence_bytes) = presence_bytes {
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
if presence.state != new_state {
|
||||
state_changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(presence_bytes) = presence_bytes {
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
if presence.state != new_state {
|
||||
state_changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = if state_changed {
|
||||
services().globals.next_count()?
|
||||
} else {
|
||||
services().globals.current_count()?
|
||||
};
|
||||
let count = if state_changed {
|
||||
services().globals.next_count()?
|
||||
} else {
|
||||
services().globals.current_count()?
|
||||
};
|
||||
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
let presence_bytes = self.roomuserid_presence.get(&key)?;
|
||||
|
||||
let new_presence = match presence_bytes {
|
||||
Some(presence_bytes) => {
|
||||
let mut presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
presence.state = new_state.clone();
|
||||
presence.currently_active = presence.state == PresenceState::Online;
|
||||
presence.last_active_ts = now;
|
||||
presence.last_count = count;
|
||||
let new_presence = match presence_bytes {
|
||||
Some(presence_bytes) => {
|
||||
let mut presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
presence.state = new_state.clone();
|
||||
presence.currently_active = presence.state == PresenceState::Online;
|
||||
presence.last_active_ts = now;
|
||||
presence.last_count = count;
|
||||
|
||||
presence
|
||||
}
|
||||
None => Presence::new(
|
||||
new_state.clone(),
|
||||
new_state == PresenceState::Online,
|
||||
now,
|
||||
count,
|
||||
None,
|
||||
),
|
||||
};
|
||||
presence
|
||||
},
|
||||
None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None),
|
||||
};
|
||||
|
||||
self.roomuserid_presence
|
||||
.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||
}
|
||||
self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||
}
|
||||
|
||||
let timeout = match new_state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
let timeout = match new_state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.presence_timer_sender
|
||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
||||
.map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})
|
||||
}
|
||||
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})
|
||||
}
|
||||
|
||||
fn set_presence(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
presence_state: PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ts = match last_active_ago {
|
||||
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
|
||||
None => now,
|
||||
};
|
||||
fn set_presence(
|
||||
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ts = match last_active_ago {
|
||||
Some(last_active_ago) => now.saturating_sub(last_active_ago.into()),
|
||||
None => now,
|
||||
};
|
||||
|
||||
let key = presence_key(room_id, user_id);
|
||||
let key = presence_key(room_id, user_id);
|
||||
|
||||
let presence = Presence::new(
|
||||
presence_state,
|
||||
currently_active.unwrap_or(false),
|
||||
last_active_ts,
|
||||
services().globals.next_count()?,
|
||||
status_msg,
|
||||
);
|
||||
let presence = Presence::new(
|
||||
presence_state,
|
||||
currently_active.unwrap_or(false),
|
||||
last_active_ts,
|
||||
services().globals.next_count()?,
|
||||
status_msg,
|
||||
);
|
||||
|
||||
let timeout = match presence.state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
let timeout = match presence.state {
|
||||
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
|
||||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.presence_timer_sender
|
||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
||||
.map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})?;
|
||||
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})?;
|
||||
|
||||
self.roomuserid_presence
|
||||
.insert(&key, &presence.to_json_bytes()?)?;
|
||||
self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
fn remove_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id) {
|
||||
let key = presence_key(&room_id?, user_id);
|
||||
|
||||
self.roomuserid_presence.remove(&key)?;
|
||||
}
|
||||
self.roomuserid_presence.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn presence_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
||||
let prefix = [room_id.as_bytes(), &[0xff]].concat();
|
||||
fn presence_since<'a>(
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
||||
let prefix = [room_id.as_bytes(), &[0xFF]].concat();
|
||||
|
||||
Box::new(
|
||||
self.roomuserid_presence
|
||||
.scan_prefix(prefix)
|
||||
.flat_map(
|
||||
|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||
let user_id = user_id_from_bytes(
|
||||
key.rsplit(|byte| *byte == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("No UserID bytes in presence key")
|
||||
})?,
|
||||
)?;
|
||||
Box::new(
|
||||
self.roomuserid_presence
|
||||
.scan_prefix(prefix)
|
||||
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||
let user_id = user_id_from_bytes(
|
||||
key.rsplit(|byte| *byte == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("No UserID bytes in presence key"))?,
|
||||
)?;
|
||||
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
let presence_event = presence.to_presence_event(&user_id)?;
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
let presence_event = presence.to_presence_event(&user_id)?;
|
||||
|
||||
Ok((user_id, presence.last_count, presence_event))
|
||||
},
|
||||
)
|
||||
.filter(move |(_, count, _)| *count > since),
|
||||
)
|
||||
}
|
||||
Ok((user_id, presence.last_count, presence_event))
|
||||
})
|
||||
.filter(move |(_, count, _)| *count > since),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec<u8> {
|
||||
[room_id.as_bytes(), &[0xff], user_id.as_bytes()].concat()
|
||||
[room_id.as_bytes(), &[0xFF], user_id.as_bytes()].concat()
|
||||
}
|
||||
|
|
|
@ -1,150 +1,113 @@
|
|||
use std::mem;
|
||||
|
||||
use ruma::{
|
||||
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||
fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: ReceiptEvent,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
== user_id.as_bytes()
|
||||
})
|
||||
{
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes()
|
||||
}) {
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xff);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xFF);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<
|
||||
dyn Iterator<
|
||||
Item = Result<(
|
||||
OwnedUserId,
|
||||
u64,
|
||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
||||
)>,
|
||||
> + 'a,
|
||||
> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let prefix2 = prefix.clone();
|
||||
fn readreceipts_since<'a>(
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
let prefix2 = prefix.clone();
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
|
||||
Box::new(
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count = utils::u64_from_bytes(
|
||||
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
Box::new(
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
|
||||
let mut json =
|
||||
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Read receipt in roomlatestid_roomlatest is invalid json.",
|
||||
)
|
||||
})?;
|
||||
json.remove("room_id");
|
||||
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
|
||||
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
|
||||
json.remove("room_id");
|
||||
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&json)
|
||||
.expect("json is valid raw value"),
|
||||
),
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.insert(&key, &count.to_be_bytes())?;
|
||||
self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
}
|
||||
self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
}
|
||||
|
||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
|
||||
Ok(Some(
|
||||
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,123 +5,111 @@ use ruma::{OwnedUserId, RoomId, UserId};
|
|||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xff);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xFF);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
|
||||
self.typingid_userid
|
||||
.insert(&room_typing_id, user_id.as_bytes())?;
|
||||
self.typingid_userid.insert(&room_typing_id, user_id.as_bytes())?;
|
||||
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &count)?;
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &count)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
let user_id = user_id.to_string();
|
||||
|
||||
let mut found_outdated = false;
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self.typingid_userid.scan_prefix(prefix).filter(|(_, v)| &**v == user_id.as_bytes()) {
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
let mut found_outdated = false;
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| {
|
||||
Error::bad_database("RoomTyping has invalid timestamp or delimiters.")
|
||||
})?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xFF)
|
||||
.nth(1)
|
||||
.ok_or_else(|| Error::bad_database("RoomTyping has invalid timestamp or delimiters."))?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut user_ids = HashSet::new();
|
||||
let mut user_ids = HashSet::new();
|
||||
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&user_id)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
|
||||
Ok(user_ids)
|
||||
}
|
||||
Ok(user_ids)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,63 +3,51 @@ use ruma::{DeviceId, RoomId, UserId};
|
|||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
||||
fn lazy_load_was_sent_before(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
ll_user: &UserId,
|
||||
) -> Result<bool> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(ll_user.as_bytes());
|
||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||
}
|
||||
fn lazy_load_was_sent_before(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
|
||||
) -> Result<bool> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(ll_user.as_bytes());
|
||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn lazy_load_confirm_delivery(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn lazy_load_confirm_delivery(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
|
||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
for ll_id in confirmed_user_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(ll_id.as_bytes());
|
||||
self.lazyloadedids.insert(&key, &[])?;
|
||||
}
|
||||
for ll_id in confirmed_user_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(ll_id.as_bytes());
|
||||
self.lazyloadedids.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn lazy_load_reset(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||
self.lazyloadedids.remove(&key)?;
|
||||
}
|
||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||
self.lazyloadedids.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,76 +4,68 @@ use tracing::error;
|
|||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::metadata::Data for KeyValueDatabase {
|
||||
fn exists(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
|
||||
Some(b) => b.to_be_bytes().to_vec(),
|
||||
None => return Ok(false),
|
||||
};
|
||||
fn exists(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
|
||||
Some(b) => b.to_be_bytes().to_vec(),
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// Look for PDUs in that room.
|
||||
Ok(self
|
||||
.pduid_pdu
|
||||
.iter_from(&prefix, false)
|
||||
.next()
|
||||
.filter(|(k, _)| k.starts_with(&prefix))
|
||||
.is_some())
|
||||
}
|
||||
// Look for PDUs in that room.
|
||||
Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some())
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||
}))
|
||||
}
|
||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
|
||||
if disabled {
|
||||
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.disabledroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
|
||||
if disabled {
|
||||
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.disabledroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
|
||||
|
||||
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
||||
if banned {
|
||||
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.bannedroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
||||
if banned {
|
||||
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.bannedroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.bannedroomids.iter().map(
|
||||
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id bytes in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids")
|
||||
})?;
|
||||
fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.bannedroomids.iter().map(
|
||||
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id bytes in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|e| {
|
||||
error!("Invalid room_id in bannedroomids: {e}");
|
||||
Error::bad_database("Invalid room_id in bannedroomids")
|
||||
})?;
|
||||
|
||||
Ok(room_id)
|
||||
},
|
||||
))
|
||||
}
|
||||
Ok(room_id)
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,26 +3,22 @@ use ruma::{CanonicalJsonObject, EventId};
|
|||
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
|
||||
|
||||
impl service::rooms::outlier::Data for KeyValueDatabase {
|
||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
|
||||
self.eventid_outlierpdu.insert(
|
||||
event_id.as_bytes(),
|
||||
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
|
||||
)
|
||||
}
|
||||
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
|
||||
self.eventid_outlierpdu.insert(
|
||||
event_id.as_bytes(),
|
||||
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,85 +3,78 @@ use std::{mem, sync::Arc};
|
|||
use ruma::{EventId, RoomId, UserId};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::timeline::PduCount},
|
||||
services, utils, Error, PduEvent, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::timeline::PduCount},
|
||||
services, utils, Error, PduEvent, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
|
||||
let mut key = to.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&from.to_be_bytes());
|
||||
self.tofrom_relation.insert(&key, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
fn add_relation(&self, from: u64, to: u64) -> Result<()> {
|
||||
let mut key = to.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&from.to_be_bytes());
|
||||
self.tofrom_relation.insert(&key, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn relations_until<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
shortroomid: u64,
|
||||
target: u64,
|
||||
until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let prefix = target.to_be_bytes().to_vec();
|
||||
let mut current = prefix.clone();
|
||||
fn relations_until<'a>(
|
||||
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let prefix = target.to_be_bytes().to_vec();
|
||||
let mut current = prefix.clone();
|
||||
|
||||
let count_raw = match until {
|
||||
PduCount::Normal(x) => x - 1,
|
||||
PduCount::Backfilled(x) => {
|
||||
current.extend_from_slice(&0_u64.to_be_bytes());
|
||||
u64::MAX - x - 1
|
||||
}
|
||||
};
|
||||
current.extend_from_slice(&count_raw.to_be_bytes());
|
||||
let count_raw = match until {
|
||||
PduCount::Normal(x) => x - 1,
|
||||
PduCount::Backfilled(x) => {
|
||||
current.extend_from_slice(&0_u64.to_be_bytes());
|
||||
u64::MAX - x - 1
|
||||
},
|
||||
};
|
||||
current.extend_from_slice(&count_raw.to_be_bytes());
|
||||
|
||||
Ok(Box::new(
|
||||
self.tofrom_relation
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(tofrom, _data)| {
|
||||
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
||||
Ok(Box::new(
|
||||
self.tofrom_relation.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(tofrom, _data)| {
|
||||
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
||||
|
||||
let mut pduid = shortroomid.to_be_bytes().to_vec();
|
||||
pduid.extend_from_slice(&from.to_be_bytes());
|
||||
let mut pduid = shortroomid.to_be_bytes().to_vec();
|
||||
pduid.extend_from_slice(&from.to_be_bytes());
|
||||
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((PduCount::Normal(from), pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((PduCount::Normal(from), pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
|
||||
for prev in event_ids {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(prev.as_bytes());
|
||||
self.referencedevents.insert(&key, &[])?;
|
||||
}
|
||||
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
|
||||
for prev in event_ids {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(prev.as_bytes());
|
||||
self.referencedevents.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
Ok(self.referencedevents.get(&key)?.is_some())
|
||||
}
|
||||
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
Ok(self.referencedevents.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
|
||||
self.softfailedeventids.insert(event_id.as_bytes(), &[])
|
||||
}
|
||||
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
|
||||
self.softfailedeventids.insert(event_id.as_bytes(), &[])
|
||||
}
|
||||
|
||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||
self.softfailedeventids
|
||||
.get(event_id.as_bytes())
|
||||
.map(|o| o.is_some())
|
||||
}
|
||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,61 +5,55 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result};
|
|||
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
|
||||
|
||||
impl service::rooms::search::Data for KeyValueDatabase {
|
||||
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
|
||||
let mut batch = message_body
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= 50)
|
||||
.map(str::to_lowercase)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
|
||||
(key, Vec::new())
|
||||
});
|
||||
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
|
||||
let mut batch = message_body
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= 50)
|
||||
.map(str::to_lowercase)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
|
||||
(key, Vec::new())
|
||||
});
|
||||
|
||||
self.tokenids.insert_batch(&mut batch)
|
||||
}
|
||||
self.tokenids.insert_batch(&mut batch)
|
||||
}
|
||||
|
||||
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
|
||||
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||
|
||||
let words: Vec<_> = search_string
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(str::to_lowercase)
|
||||
.collect();
|
||||
let words: Vec<_> = search_string
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(str::to_lowercase)
|
||||
.collect();
|
||||
|
||||
let iterators = words.clone().into_iter().map(move |word| {
|
||||
let mut prefix2 = prefix.clone();
|
||||
prefix2.extend_from_slice(word.as_bytes());
|
||||
prefix2.push(0xff);
|
||||
let prefix3 = prefix2.clone();
|
||||
let iterators = words.clone().into_iter().map(move |word| {
|
||||
let mut prefix2 = prefix.clone();
|
||||
prefix2.extend_from_slice(word.as_bytes());
|
||||
prefix2.push(0xFF);
|
||||
let prefix3 = prefix2.clone();
|
||||
|
||||
let mut last_possible_id = prefix2.clone();
|
||||
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
let mut last_possible_id = prefix2.clone();
|
||||
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.tokenids
|
||||
.iter_from(&last_possible_id, true) // Newest pdus first
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(key, _)| key[prefix3.len()..].to_vec())
|
||||
});
|
||||
self.tokenids
|
||||
.iter_from(&last_possible_id, true) // Newest pdus first
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(key, _)| key[prefix3.len()..].to_vec())
|
||||
});
|
||||
|
||||
let common_elements = match utils::common_elements(iterators, |a, b| {
|
||||
// We compare b with a because we reversed the iterator earlier
|
||||
b.cmp(a)
|
||||
}) {
|
||||
Some(it) => it,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let common_elements = match utils::common_elements(iterators, |a, b| {
|
||||
// We compare b with a because we reversed the iterator earlier
|
||||
b.cmp(a)
|
||||
}) {
|
||||
Some(it) => it,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
Ok(Some((Box::new(common_elements), words)))
|
||||
}
|
||||
Ok(Some((Box::new(common_elements), words)))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,214 +6,165 @@ use tracing::warn;
|
|||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::short::Data for KeyValueDatabase {
|
||||
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
|
||||
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(*short);
|
||||
}
|
||||
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
|
||||
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
|
||||
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
|
||||
None => {
|
||||
let shorteventid = services().globals.next_count()?;
|
||||
self.eventid_shorteventid
|
||||
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid
|
||||
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
shorteventid
|
||||
}
|
||||
};
|
||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||
Some(shorteventid) => {
|
||||
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
|
||||
},
|
||||
None => {
|
||||
let shorteventid = services().globals.next_count()?;
|
||||
self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
shorteventid
|
||||
},
|
||||
};
|
||||
|
||||
self.eventidshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), short);
|
||||
self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<u64>> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(Some(*short));
|
||||
}
|
||||
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
|
||||
if let Some(short) =
|
||||
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(Some(*short));
|
||||
}
|
||||
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xff);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xFF);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = self
|
||||
.statekey_shortstatekey
|
||||
.get(&statekey_vec)?
|
||||
.map(|shortstatekey| {
|
||||
utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
})
|
||||
.transpose()?;
|
||||
let short = self
|
||||
.statekey_shortstatekey
|
||||
.get(&statekey_vec)?
|
||||
.map(|shortstatekey| {
|
||||
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
if let Some(s) = short {
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), s);
|
||||
}
|
||||
if let Some(s) = short {
|
||||
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s);
|
||||
}
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_or_create_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<u64> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(*short);
|
||||
}
|
||||
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
|
||||
if let Some(short) =
|
||||
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xff);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xFF);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
||||
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||
None => {
|
||||
let shortstatekey = services().globals.next_count()?;
|
||||
self.statekey_shortstatekey
|
||||
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey
|
||||
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||
shortstatekey
|
||||
}
|
||||
};
|
||||
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
||||
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||
None => {
|
||||
let shortstatekey = services().globals.next_count()?;
|
||||
self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||
shortstatekey
|
||||
},
|
||||
};
|
||||
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), short);
|
||||
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||
if let Some(id) = self
|
||||
.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shorteventid)
|
||||
{
|
||||
return Ok(Arc::clone(id));
|
||||
}
|
||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||
if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) {
|
||||
return Ok(Arc::clone(id));
|
||||
}
|
||||
|
||||
let bytes = self
|
||||
.shorteventid_eventid
|
||||
.get(&shorteventid.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||
let bytes = self
|
||||
.shorteventid_eventid
|
||||
.get(&shorteventid.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||
|
||||
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||
let event_id = EventId::parse_arc(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||
|
||||
self.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shorteventid, Arc::clone(&event_id));
|
||||
self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id));
|
||||
|
||||
Ok(event_id)
|
||||
}
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||
if let Some(id) = self
|
||||
.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shortstatekey)
|
||||
{
|
||||
return Ok(id.clone());
|
||||
}
|
||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||
if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) {
|
||||
return Ok(id.clone());
|
||||
}
|
||||
|
||||
let bytes = self
|
||||
.shortstatekey_statekey
|
||||
.get(&shortstatekey.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||
let bytes = self
|
||||
.shortstatekey_statekey
|
||||
.get(&shortstatekey.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xff);
|
||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||
let statekey_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
|
||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||
let statekey_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
|
||||
let event_type =
|
||||
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
||||
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
||||
})?);
|
||||
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
||||
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
||||
})?);
|
||||
|
||||
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
|
||||
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
|
||||
})?;
|
||||
let state_key = utils::string_from_bytes(statekey_bytes)
|
||||
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
|
||||
|
||||
let result = (event_type, state_key);
|
||||
let result = (event_type, state_key);
|
||||
|
||||
self.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shortstatekey, result.clone());
|
||||
self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns (shortstatehash, already_existed)
|
||||
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
|
||||
Ok(match self.statehash_shortstatehash.get(state_hash)? {
|
||||
Some(shortstatehash) => (
|
||||
utils::u64_from_bytes(&shortstatehash)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
|
||||
true,
|
||||
),
|
||||
None => {
|
||||
let shortstatehash = services().globals.next_count()?;
|
||||
self.statehash_shortstatehash
|
||||
.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
(shortstatehash, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Returns (shortstatehash, already_existed)
|
||||
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
|
||||
Ok(match self.statehash_shortstatehash.get(state_hash)? {
|
||||
Some(shortstatehash) => (
|
||||
utils::u64_from_bytes(&shortstatehash)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
|
||||
true,
|
||||
),
|
||||
None => {
|
||||
let shortstatehash = services().globals.next_count()?;
|
||||
self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
(shortstatehash, false)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||
Some(short) => utils::u64_from_bytes(&short)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
|
||||
None => {
|
||||
let short = services().globals.next_count()?;
|
||||
self.roomid_shortroomid
|
||||
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
short
|
||||
}
|
||||
})
|
||||
}
|
||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||
Some(short) => {
|
||||
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
|
||||
},
|
||||
None => {
|
||||
let short = services().globals.next_count()?;
|
||||
self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
short
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,73 +1,69 @@
|
|||
use ruma::{EventId, OwnedEventId, RoomId};
|
||||
use std::collections::HashSet;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use std::sync::Arc;
|
||||
use ruma::{EventId, OwnedEventId, RoomId};
|
||||
use tokio::sync::MutexGuard;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::state::Data for KeyValueDatabase {
|
||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortstatehash
|
||||
.get(room_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
fn set_room_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
new_shortstatehash: u64,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
self.roomid_shortstatehash
|
||||
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn set_room_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
new_shortstatehash: u64,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||
self.shorteventid_shortstatehash
|
||||
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||
self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.roomid_pduleaves
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, bytes)| {
|
||||
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
self.roomid_pduleaves
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, bytes)| {
|
||||
EventId::parse_arc(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn set_forward_extremities(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<OwnedEventId>,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn set_forward_extremities(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<OwnedEventId>,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||
self.roomid_pduleaves.remove(&key)?;
|
||||
}
|
||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||
self.roomid_pduleaves.remove(&key)?;
|
||||
}
|
||||
|
||||
for event_id in event_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
|
||||
}
|
||||
for event_id in event_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,186 +1,144 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
use async_trait::async_trait;
|
||||
use ruma::{events::StateEventType, EventId, RoomId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
#[async_trait]
|
||||
impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let parsed = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
result.insert(parsed.0, parsed.1);
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||
result.insert(parsed.0, parsed.1);
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn state_full(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let (_, eventid) = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||
result.insert(
|
||||
(
|
||||
pdu.kind.to_string().into(),
|
||||
pdu.state_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))?
|
||||
.clone(),
|
||||
),
|
||||
pdu,
|
||||
);
|
||||
}
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||
result.insert(
|
||||
(
|
||||
pdu.kind.to_string().into(),
|
||||
pdu.state_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))?
|
||||
.clone(),
|
||||
),
|
||||
pdu,
|
||||
);
|
||||
}
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn state_get_id(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
let shortstatekey = match services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortstatekey(event_type, state_key)?
|
||||
{
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
Ok(full_state
|
||||
.iter()
|
||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
|
||||
.and_then(|compressed| {
|
||||
services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)
|
||||
.ok()
|
||||
.map(|(_, id)| id)
|
||||
}))
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn state_get_id(
|
||||
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? {
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
Ok(
|
||||
full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| {
|
||||
services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn state_get(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||
.map_or(Ok(None), |event_id| {
|
||||
services().rooms.timeline.get_pdu(&event_id)
|
||||
})
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn state_get(
|
||||
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
|
||||
}
|
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |shorteventid| {
|
||||
self.shorteventid_shortstatehash
|
||||
.get(&shorteventid)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
/// Returns the state hash for this pdu.
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| {
|
||||
self.shorteventid_shortstatehash
|
||||
.get(&shorteventid)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash"))
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the full room state.
|
||||
async fn room_state_full(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_full(current_shortstatehash).await
|
||||
} else {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
/// Returns the full room state.
|
||||
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_full(current_shortstatehash).await
|
||||
} else {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn room_state_get_id(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn room_state_get_id(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn room_state_get(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_get(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn room_state_get(
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_get(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,61 +1,63 @@
|
|||
use std::{collections::HashSet, mem::size_of, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::state_compressor::data::StateDiff},
|
||||
utils, Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::state_compressor::data::StateDiff},
|
||||
utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::state_compressor::Data for KeyValueDatabase {
|
||||
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
|
||||
let value = self
|
||||
.shortstatehash_statediff
|
||||
.get(&shortstatehash.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
||||
let parent =
|
||||
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
||||
let parent = if parent != 0 { Some(parent) } else { None };
|
||||
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
|
||||
let value = self
|
||||
.shortstatehash_statediff
|
||||
.get(&shortstatehash.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
||||
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
||||
let parent = if parent != 0 {
|
||||
Some(parent)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut add_mode = true;
|
||||
let mut added = HashSet::new();
|
||||
let mut removed = HashSet::new();
|
||||
let mut add_mode = true;
|
||||
let mut added = HashSet::new();
|
||||
let mut removed = HashSet::new();
|
||||
|
||||
let mut i = size_of::<u64>();
|
||||
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
|
||||
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
|
||||
add_mode = false;
|
||||
i += size_of::<u64>();
|
||||
continue;
|
||||
}
|
||||
if add_mode {
|
||||
added.insert(v.try_into().expect("we checked the size above"));
|
||||
} else {
|
||||
removed.insert(v.try_into().expect("we checked the size above"));
|
||||
}
|
||||
i += 2 * size_of::<u64>();
|
||||
}
|
||||
let mut i = size_of::<u64>();
|
||||
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
|
||||
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
|
||||
add_mode = false;
|
||||
i += size_of::<u64>();
|
||||
continue;
|
||||
}
|
||||
if add_mode {
|
||||
added.insert(v.try_into().expect("we checked the size above"));
|
||||
} else {
|
||||
removed.insert(v.try_into().expect("we checked the size above"));
|
||||
}
|
||||
i += 2 * size_of::<u64>();
|
||||
}
|
||||
|
||||
Ok(StateDiff {
|
||||
parent,
|
||||
added: Arc::new(added),
|
||||
removed: Arc::new(removed),
|
||||
})
|
||||
}
|
||||
Ok(StateDiff {
|
||||
parent,
|
||||
added: Arc::new(added),
|
||||
removed: Arc::new(removed),
|
||||
})
|
||||
}
|
||||
|
||||
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
|
||||
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
|
||||
for new in diff.added.iter() {
|
||||
value.extend_from_slice(&new[..]);
|
||||
}
|
||||
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
|
||||
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
|
||||
for new in diff.added.iter() {
|
||||
value.extend_from_slice(&new[..]);
|
||||
}
|
||||
|
||||
if !diff.removed.is_empty() {
|
||||
value.extend_from_slice(&0_u64.to_be_bytes());
|
||||
for removed in diff.removed.iter() {
|
||||
value.extend_from_slice(&removed[..]);
|
||||
}
|
||||
}
|
||||
if !diff.removed.is_empty() {
|
||||
value.extend_from_slice(&0_u64.to_be_bytes());
|
||||
for removed in diff.removed.iter() {
|
||||
value.extend_from_slice(&removed[..]);
|
||||
}
|
||||
}
|
||||
|
||||
self.shortstatehash_statediff
|
||||
.insert(&shortstatehash.to_be_bytes(), &value)
|
||||
}
|
||||
self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,74 +7,58 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven
|
|||
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
|
||||
|
||||
impl service::rooms::threads::Data for KeyValueDatabase {
|
||||
fn threads_until<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
room_id: &'a RoomId,
|
||||
until: u64,
|
||||
_include: &'a IncludeThreads,
|
||||
) -> PduEventIterResult<'a> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
fn threads_until<'a>(
|
||||
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
|
||||
) -> PduEventIterResult<'a> {
|
||||
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(until - 1).to_be_bytes());
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(until - 1).to_be_bytes());
|
||||
|
||||
Ok(Box::new(
|
||||
self.threadid_userids
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pduid, _users)| {
|
||||
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("Invalid pduid reference in threadid_userids")
|
||||
})?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
Ok(Box::new(
|
||||
self.threadid_userids.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pduid, _users)| {
|
||||
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((count, pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
||||
let users = participants
|
||||
.iter()
|
||||
.map(|user| user.as_bytes())
|
||||
.collect::<Vec<_>>()
|
||||
.join(&[0xff][..]);
|
||||
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
||||
let users = participants.iter().map(|user| user.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..]);
|
||||
|
||||
self.threadid_userids.insert(root_id, &users)?;
|
||||
self.threadid_userids.insert(root_id, &users)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
|
||||
if let Some(users) = self.threadid_userids.get(root_id)? {
|
||||
Ok(Some(
|
||||
users
|
||||
.split(|b| *b == 0xff)
|
||||
.map(|bytes| {
|
||||
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid UserId bytes in threadid_userids.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect(),
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
|
||||
if let Some(users) = self.threadid_userids.get(root_id)? {
|
||||
Ok(Some(
|
||||
users
|
||||
.split(|b| *b == 0xFF)
|
||||
.map(|bytes| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect(),
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,364 +1,286 @@
|
|||
use std::{collections::hash_map, mem::size_of, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
|
||||
use service::rooms::timeline::PduCount;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
use service::rooms::timeline::PduCount;
|
||||
|
||||
impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
||||
match self
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
{
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
if let Some(last_count) = self
|
||||
.pdus_until(sender_user, room_id, PduCount::max())?
|
||||
.find_map(|r| {
|
||||
// Filter out buggy events
|
||||
if r.is_err() {
|
||||
error!("Bad pdu in pdus_since: {:?}", r);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
{
|
||||
Ok(*v.insert(last_count.0))
|
||||
} else {
|
||||
Ok(PduCount::Normal(0))
|
||||
}
|
||||
}
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
}
|
||||
}
|
||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
||||
match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) {
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| {
|
||||
// Filter out buggy events
|
||||
if r.is_err() {
|
||||
error!("Bad pdu in pdus_since: {:?}", r);
|
||||
}
|
||||
r.ok()
|
||||
}) {
|
||||
Ok(*v.insert(last_count.0))
|
||||
} else {
|
||||
Ok(PduCount::Normal(0))
|
||||
}
|
||||
},
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu_id| pdu_count(&pdu_id))
|
||||
.transpose()
|
||||
}
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
||||
self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose()
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)
|
||||
}
|
||||
/// Returns the json of a pdu.
|
||||
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
/// Returns the json of a pdu.
|
||||
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu's id.
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
|
||||
self.eventid_pduid.get(event_id.as_bytes())
|
||||
}
|
||||
/// Returns the pdu's id.
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
|
||||
|
||||
/// Returns the pdu.
|
||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
/// Returns the pdu.
|
||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(Some(Arc::clone(p)));
|
||||
}
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(Some(Arc::clone(p)));
|
||||
}
|
||||
|
||||
if let Some(pdu) = self
|
||||
.get_non_outlier_pdu(event_id)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)?
|
||||
.map(Arc::new)
|
||||
{
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
Ok(Some(pdu))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
if let Some(pdu) = self
|
||||
.get_non_outlier_pdu(event_id)?
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)?
|
||||
.map(Arc::new)
|
||||
{
|
||||
self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
Ok(Some(pdu))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn append_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu: &PduEvent,
|
||||
json: &CanonicalJsonObject,
|
||||
count: u64,
|
||||
) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
|
||||
self.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||
self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||
|
||||
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepend_backfill_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
event_id: &EventId,
|
||||
json: &CanonicalJsonObject,
|
||||
) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
|
||||
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
|
||||
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
fn replace_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu_json: &CanonicalJsonObject,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PDU does not exist.",
|
||||
));
|
||||
}
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
|
||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
|
||||
}
|
||||
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&(*pdu.event_id).to_owned());
|
||||
self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over all events and their tokens in a room that happened before the
|
||||
/// event with id `until` in reverse-chronological order.
|
||||
fn pdus_until<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
||||
/// Returns an iterator over all events and their tokens in a room that
|
||||
/// happened before the event with id `until` in reverse-chronological
|
||||
/// order.
|
||||
fn pdus_until<'a>(
|
||||
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn pdus_after<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
from: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
||||
fn pdus_after<'a>(
|
||||
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(¤t, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu.iter_from(¤t, false).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn increment_notification_counts(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
notifies: Vec<OwnedUserId>,
|
||||
highlights: Vec<OwnedUserId>,
|
||||
) -> Result<()> {
|
||||
let mut notifies_batch = Vec::new();
|
||||
let mut highlights_batch = Vec::new();
|
||||
for user in notifies {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
notifies_batch.push(userroom_id);
|
||||
}
|
||||
for user in highlights {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
highlights_batch.push(userroom_id);
|
||||
}
|
||||
fn increment_notification_counts(
|
||||
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
|
||||
) -> Result<()> {
|
||||
let mut notifies_batch = Vec::new();
|
||||
let mut highlights_batch = Vec::new();
|
||||
for user in notifies {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
notifies_batch.push(userroom_id);
|
||||
}
|
||||
for user in highlights {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
highlights_batch.push(userroom_id);
|
||||
}
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount
|
||||
.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
Ok(())
|
||||
}
|
||||
self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
||||
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
||||
let second_last_u64 = utils::u64_from_bytes(
|
||||
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()],
|
||||
);
|
||||
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
||||
let second_last_u64 =
|
||||
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
|
||||
|
||||
if matches!(second_last_u64, Ok(0)) {
|
||||
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
||||
} else {
|
||||
Ok(PduCount::Normal(last_u64))
|
||||
}
|
||||
if matches!(second_last_u64, Ok(0)) {
|
||||
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
||||
} else {
|
||||
Ok(PduCount::Normal(last_u64))
|
||||
}
|
||||
}
|
||||
|
||||
fn count_to_id(
|
||||
room_id: &RoomId,
|
||||
count: PduCount,
|
||||
offset: u64,
|
||||
subtract: bool,
|
||||
) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
let mut pdu_id = prefix.clone();
|
||||
// +1 so we don't send the base event
|
||||
let count_raw = match count {
|
||||
PduCount::Normal(x) => {
|
||||
if subtract {
|
||||
x - offset
|
||||
} else {
|
||||
x + offset
|
||||
}
|
||||
}
|
||||
PduCount::Backfilled(x) => {
|
||||
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
||||
let num = u64::MAX - x;
|
||||
if subtract {
|
||||
if num > 0 {
|
||||
num - offset
|
||||
} else {
|
||||
num
|
||||
}
|
||||
} else {
|
||||
num + offset
|
||||
}
|
||||
}
|
||||
};
|
||||
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
||||
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
let mut pdu_id = prefix.clone();
|
||||
// +1 so we don't send the base event
|
||||
let count_raw = match count {
|
||||
PduCount::Normal(x) => {
|
||||
if subtract {
|
||||
x - offset
|
||||
} else {
|
||||
x + offset
|
||||
}
|
||||
},
|
||||
PduCount::Backfilled(x) => {
|
||||
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
||||
let num = u64::MAX - x;
|
||||
if subtract {
|
||||
if num > 0 {
|
||||
num - offset
|
||||
} else {
|
||||
num
|
||||
}
|
||||
} else {
|
||||
num + offset
|
||||
}
|
||||
},
|
||||
};
|
||||
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
||||
|
||||
Ok((prefix, pdu_id))
|
||||
Ok((prefix, pdu_id))
|
||||
}
|
||||
|
|
|
@ -3,147 +3,122 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
|
|||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::user::Data for KeyValueDatabase {
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastnotificationread.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
self.userroomid_notificationcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_highlightcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
self.userroomid_highlightcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastnotificationread
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
Ok(self
|
||||
.roomuserid_lastnotificationread
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn associate_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
shortstatehash: u64,
|
||||
) -> Result<()> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
|
||||
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
self.roomsynctoken_shortstatehash
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self,
|
||||
users: Vec<OwnedUserId>,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
||||
let iterators = users.into_iter().map(move |user_id| {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self, users: Vec<OwnedUserId>,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
||||
let iterators = users.into_iter().map(move |user_id| {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.userroomid_joined
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
let roomid_index = key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &b)| b == 0xff)
|
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||
.0
|
||||
+ 1; // +1 because the room id starts AFTER the separator
|
||||
self.userroomid_joined
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
let roomid_index = key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &b)| b == 0xFF)
|
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||
.0 + 1; // +1 because the room id starts AFTER the separator
|
||||
|
||||
let room_id = key[roomid_index..].to_vec();
|
||||
let room_id = key[roomid_index..].to_vec();
|
||||
|
||||
Ok::<_, Error>(room_id)
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
});
|
||||
Ok::<_, Error>(room_id)
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
});
|
||||
|
||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
||||
Ok(Box::new(
|
||||
utils::common_elements(iterators, Ord::cmp)
|
||||
.expect("users is not empty")
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}),
|
||||
))
|
||||
}
|
||||
// We use the default compare function because keys are sorted correctly (not
|
||||
// reversed)
|
||||
Ok(Box::new(
|
||||
utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,205 +1,181 @@
|
|||
use ruma::{ServerName, UserId};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{
|
||||
self,
|
||||
sending::{OutgoingKind, SendingEventType},
|
||||
},
|
||||
services, utils, Error, Result,
|
||||
database::KeyValueDatabase,
|
||||
service::{
|
||||
self,
|
||||
sending::{OutgoingKind, SendingEventType},
|
||||
},
|
||||
services, utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::sending::Data for KeyValueDatabase {
|
||||
fn active_requests<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.iter()
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
|
||||
)
|
||||
}
|
||||
fn active_requests<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.iter()
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
|
||||
)
|
||||
}
|
||||
|
||||
fn active_requests_for<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
|
||||
)
|
||||
}
|
||||
fn active_requests_for<'a>(
|
||||
&'a self, outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
|
||||
)
|
||||
}
|
||||
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
|
||||
self.servercurrentevent_data.remove(&key)
|
||||
}
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
|
||||
|
||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
|
||||
self.servercurrentevent_data.remove(&key)?;
|
||||
}
|
||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) {
|
||||
self.servercurrentevent_data.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
|
||||
self.servernameevent_data.remove(&key).unwrap();
|
||||
}
|
||||
for (key, _) in self.servernameevent_data.scan_prefix(prefix) {
|
||||
self.servernameevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn queue_requests(
|
||||
&self,
|
||||
requests: &[(&OutgoingKind, SendingEventType)],
|
||||
) -> Result<Vec<Vec<u8>>> {
|
||||
let mut batch = Vec::new();
|
||||
let mut keys = Vec::new();
|
||||
for (outgoing_kind, event) in requests {
|
||||
let mut key = outgoing_kind.get_prefix();
|
||||
if let SendingEventType::Pdu(value) = &event {
|
||||
key.extend_from_slice(value);
|
||||
} else {
|
||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
}
|
||||
let value = if let SendingEventType::Edu(value) = &event {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
batch.push((key.clone(), value.to_owned()));
|
||||
keys.push(key);
|
||||
}
|
||||
self.servernameevent_data
|
||||
.insert_batch(&mut batch.into_iter())?;
|
||||
Ok(keys)
|
||||
}
|
||||
fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result<Vec<Vec<u8>>> {
|
||||
let mut batch = Vec::new();
|
||||
let mut keys = Vec::new();
|
||||
for (outgoing_kind, event) in requests {
|
||||
let mut key = outgoing_kind.get_prefix();
|
||||
if let SendingEventType::Pdu(value) = &event {
|
||||
key.extend_from_slice(value);
|
||||
} else {
|
||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
}
|
||||
let value = if let SendingEventType::Edu(value) = &event {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
batch.push((key.clone(), value.to_owned()));
|
||||
keys.push(key);
|
||||
}
|
||||
self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn queued_requests<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
return Box::new(
|
||||
self.servernameevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
|
||||
);
|
||||
}
|
||||
fn queued_requests<'a>(
|
||||
&'a self, outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
return Box::new(
|
||||
self.servernameevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
|
||||
);
|
||||
}
|
||||
|
||||
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
|
||||
for (e, key) in events {
|
||||
let value = if let SendingEventType::Edu(value) = &e {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
self.servercurrentevent_data.insert(key, value)?;
|
||||
self.servernameevent_data.remove(key)?;
|
||||
}
|
||||
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
|
||||
for (e, key) in events {
|
||||
let value = if let SendingEventType::Edu(value) = &e {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
self.servercurrentevent_data.insert(key, value)?;
|
||||
self.servernameevent_data.remove(key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||
self.servername_educount
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
}
|
||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||
self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||
self.servername_educount
|
||||
.get(server_name.as_bytes())?
|
||||
.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})
|
||||
}
|
||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||
self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(key))]
|
||||
fn parse_servercurrentevent(
|
||||
key: &[u8],
|
||||
value: Vec<u8>,
|
||||
) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
||||
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
|
||||
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id = UserId::parse(user_string)
|
||||
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id =
|
||||
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
|
||||
let pushkey = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||
let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Push(user_id, pushkey_string),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
// I'm pretty sure this should never be called
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
(
|
||||
OutgoingKind::Push(user_id, pushkey_string),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
// I'm pretty sure this should never be called
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xFF);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
||||
})?),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
})
|
||||
(
|
||||
OutgoingKind::Normal(
|
||||
ServerName::parse(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
|
||||
),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,37 +3,30 @@ use ruma::{DeviceId, TransactionId, UserId};
|
|||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
impl service::transaction_ids::Data for KeyValueDatabase {
|
||||
fn add_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
data: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
fn add_txnid(
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn existing_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
fn existing_txnid(
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
// If there's no entry, this is a new transaction
|
||||
self.userdevicetxnid_response.get(&key)
|
||||
}
|
||||
// If there's no entry, this is a new transaction
|
||||
self.userdevicetxnid_response.get(&key)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,89 +1,64 @@
|
|||
use ruma::{
|
||||
api::client::{error::ErrorKind, uiaa::UiaaInfo},
|
||||
CanonicalJsonValue, DeviceId, UserId,
|
||||
api::client::{error::ErrorKind, uiaa::UiaaInfo},
|
||||
CanonicalJsonValue, DeviceId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, Error, Result};
|
||||
|
||||
impl service::uiaa::Data for KeyValueDatabase {
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
fn set_uiaa_request(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest.write().unwrap().insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
|
||||
.map(std::borrow::ToOwned::to_owned)
|
||||
}
|
||||
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
|
||||
.map(std::borrow::ToOwned::to_owned)
|
||||
}
|
||||
|
||||
fn update_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
fn update_uiaa_session(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo
|
||||
.remove(&userdevicesessionid)?;
|
||||
}
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"UIAA session does not exist.",
|
||||
))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "UIAA session does not exist."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
2255
src/database/mod.rs
2255
src/database/mod.rs
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue