feat: heed db backend (LMDB)

This commit is contained in:
Timo Kösters 2021-07-29 20:17:47 +02:00
parent c209775abd
commit 5c776e9ba7
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
9 changed files with 456 additions and 97 deletions

View file

@ -12,6 +12,9 @@ pub mod sled;
#[cfg(feature = "sqlite")]
pub mod sqlite;
#[cfg(feature = "heed")]
pub mod heed;
pub trait DatabaseEngine: Sized {
fn open(config: &Config) -> Result<Arc<Self>>;
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>>;

View file

@ -0,0 +1,241 @@
use super::super::Config;
use crossbeam::channel::{bounded, Sender as ChannelSender};
use threadpool::ThreadPool;
use crate::{Error, Result};
use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{Arc, Mutex, RwLock},
};
use tokio::sync::oneshot::Sender;
use super::{DatabaseEngine, Tree};
/// An owned key/value pair, as sent over the iterator channel and yielded by
/// the tree iterators in this module.
type TupleOfBytes = (Vec<u8>, Vec<u8>);
/// Database engine backed by a `heed` environment.
pub struct Engine {
/// The opened `heed` environment; read and write transactions in this
/// module are created from it.
env: heed::Env,
/// Thread pool on which iterator worker closures are executed while it has
/// spare capacity.
iter_pool: Mutex<ThreadPool>,
}
/// A single named database ("tree") living inside an [`Engine`]'s environment.
pub struct EngineTree {
/// The engine this tree belongs to; its `env` is used to open transactions.
engine: Arc<Engine>,
/// Handle to the underlying `heed` database.
tree: Arc<heed::UntypedDatabase>,
/// One-shot senders registered by `watch_prefix`, keyed by the watched
/// prefix. `insert` removes and fires the entries matching the written key.
watchers: RwLock<HashMap<Vec<u8>, Vec<Sender<()>>>>,
}
fn convert_error(error: heed::Error) -> Error {
panic!(error.to_string());
Error::HeedError {
error: error.to_string(),
}
}
impl DatabaseEngine for Engine {
    /// Opens the `heed` environment stored at `config.database_path` and
    /// prepares the thread pool used by the tree iterators.
    ///
    /// # Errors
    ///
    /// Returns the converted `heed` error if the environment cannot be opened.
    fn open(config: &Config) -> Result<Arc<Self>> {
        let mut options = heed::EnvOpenOptions::new();
        options.map_size(1024 * 1024 * 1024 * 1024); // 1 Terabyte
        options.max_readers(126);
        options.max_dbs(128);

        // NOTE(review): `flag` is unsafe in heed. These two flags look like
        // LMDB's MDB_NOSYNC / MDB_NOMETASYNC (no syncing on commit) — confirm
        // the durability trade-off is intended. `flush` below syncs explicitly.
        unsafe {
            options.flag(heed::flags::Flags::MdbNoSync);
            options.flag(heed::flags::Flags::MdbNoMetaSync);
        }

        let env = options
            .open(&config.database_path)
            .map_err(convert_error)?;
        let iter_pool = Mutex::new(ThreadPool::new(10));

        Ok(Arc::new(Engine { env, iter_pool }))
    }

    /// Opens the database called `name` and wraps it in an [`EngineTree`]
    /// with an empty watcher table.
    ///
    /// # Errors
    ///
    /// Returns the converted `heed` error if the database cannot be created
    /// or opened.
    fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
        // Creates the db if it doesn't exist already
        let database: heed::UntypedDatabase = self
            .env
            .create_database(Some(name))
            .map_err(convert_error)?;

        let tree = EngineTree {
            engine: Arc::clone(self),
            tree: Arc::new(database),
            watchers: RwLock::new(HashMap::new()),
        };

        Ok(Arc::new(tree))
    }

    /// Forces the environment to sync via `force_sync`.
    ///
    /// # Errors
    ///
    /// Returns the converted `heed` error if the sync fails.
    fn flush(self: &Arc<Self>) -> Result<()> {
        let synced = self.env.force_sync();
        synced.map_err(convert_error)?;
        Ok(())
    }
}
impl EngineTree {
    /// Iterates over `tree` on another thread and streams the entries back.
    ///
    /// A worker opens its own read transaction and pushes owned key/value
    /// pairs into a bounded channel (capacity 5); the returned iterator
    /// drains that channel. The worker runs on the engine's thread pool while
    /// the pool reports fewer active jobs than its maximum, and on a freshly
    /// spawned thread otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the pool mutex is poisoned. The worker itself panics if the
    /// read transaction cannot be opened.
    #[tracing::instrument(skip(self, tree, from, backwards))]
    fn iter_from_thread(
        &self,
        tree: Arc<heed::UntypedDatabase>,
        from: Vec<u8>,
        backwards: bool,
    ) -> Box<dyn Iterator<Item = TupleOfBytes> + Send + Sync> {
        let (sender, receiver) = bounded::<TupleOfBytes>(5);
        let engine = Arc::clone(&self.engine);

        // The same job is used regardless of where it ends up running.
        let worker = move || {
            let txn = engine.env.read_txn().unwrap();
            iter_from_thread_work(tree, &txn, from, backwards, &sender);
        };

        let pool = self.engine.iter_pool.lock().unwrap();
        if pool.active_count() < pool.max_count() {
            pool.execute(worker);
        } else {
            std::thread::spawn(worker);
        }

        Box::new(receiver.into_iter())
    }
}
/// Walks `tree` inside `txn` and sends every entry, as owned bytes, to `s`.
///
/// * `backwards == true`: entries up to and including `from`, in reverse.
/// * `backwards == false` and `from` empty: the whole tree.
/// * `backwards == false` otherwise: entries starting at `from`.
///
/// Iteration stops as soon as a send fails, which happens once the receiving
/// side of the channel has been dropped.
///
/// # Panics
///
/// Panics if the range/iterator cannot be created or an entry fails to read.
#[tracing::instrument(skip(tree, txn, from, backwards))]
fn iter_from_thread_work(
    tree: Arc<heed::UntypedDatabase>,
    txn: &heed::RoTxn<'_>,
    from: Vec<u8>,
    backwards: bool,
    s: &ChannelSender<(Vec<u8>, Vec<u8>)>,
) {
    if backwards {
        for (key, value) in tree.rev_range(txn, ..=&*from).unwrap().map(|r| r.unwrap()) {
            let entry = (key.to_vec(), value.to_vec());
            if s.send(entry).is_err() {
                break;
            }
        }
        return;
    }

    if from.is_empty() {
        for (key, value) in tree.iter(txn).unwrap().map(|r| r.unwrap()) {
            let entry = (key.to_vec(), value.to_vec());
            if s.send(entry).is_err() {
                break;
            }
        }
        return;
    }

    for (key, value) in tree.range(txn, &*from..).unwrap().map(|r| r.unwrap()) {
        let entry = (key.to_vec(), value.to_vec());
        if s.send(entry).is_err() {
            break;
        }
    }
}
impl Tree for EngineTree {
    /// Reads `key` in a fresh read transaction and returns an owned copy of
    /// its value, or `None` if the key is absent.
    #[tracing::instrument(skip(self, key))]
    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
        let txn = self.engine.env.read_txn().map_err(convert_error)?;
        let found = self.tree.get(&txn, &key).map_err(convert_error)?;
        Ok(found.map(|s| s.to_vec()))
    }

    /// Writes `key` -> `value` in its own write transaction, then wakes every
    /// watcher registered on a prefix of `key`.
    ///
    /// Watchers are only notified after the transaction has been committed.
    #[tracing::instrument(skip(self, key, value))]
    fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
        let mut txn = self.engine.env.write_txn().map_err(convert_error)?;
        self.tree
            .put(&mut txn, &key, &value)
            .map_err(convert_error)?;
        txn.commit().map_err(convert_error)?;

        // Collect the watched prefixes of `key` under the read lock only; the
        // guard is released at the end of this block.
        let triggered: Vec<&[u8]> = {
            let watchers = self.watchers.read().unwrap();
            (0..=key.len())
                .map(|length| &key[..length])
                .filter(|prefix| watchers.contains_key(*prefix))
                .collect()
        };

        // Take the write lock only when there is something to fire.
        if !triggered.is_empty() {
            let mut watchers = self.watchers.write().unwrap();
            for prefix in triggered {
                for tx in watchers.remove(prefix).into_iter().flatten() {
                    let _ = tx.send(());
                }
            }
        }

        Ok(())
    }

    /// Deletes `key` in its own write transaction.
    #[tracing::instrument(skip(self, key))]
    fn remove(&self, key: &[u8]) -> Result<()> {
        let mut txn = self.engine.env.write_txn().map_err(convert_error)?;
        let _deleted = self.tree.delete(&mut txn, &key).map_err(convert_error)?;
        txn.commit().map_err(convert_error)?;
        Ok(())
    }

    /// Iterates over every entry of the tree in forward order.
    #[tracing::instrument(skip(self))]
    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + Send + 'a> {
        let start: &[u8] = &[];
        self.iter_from(start, false)
    }

    /// Iterates starting at `from`, forwards or backwards, using a background
    /// worker (see `iter_from_thread`).
    #[tracing::instrument(skip(self, from, backwards))]
    fn iter_from(
        &self,
        from: &[u8],
        backwards: bool,
    ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + Send> {
        let tree = Arc::clone(&self.tree);
        let start = from.to_vec();
        self.iter_from_thread(tree, start, backwards)
    }

    /// Reads the current value of `key`, passes it through
    /// `crate::utils::increment`, stores the result and returns it. The read
    /// and the write share one write transaction.
    #[tracing::instrument(skip(self, key))]
    fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
        let mut txn = self.engine.env.write_txn().map_err(convert_error)?;

        let previous = self.tree.get(&txn, &key).map_err(convert_error)?;
        let new = crate::utils::increment(previous.as_deref())
            .expect("utils::increment always returns Some");

        self.tree
            .put(&mut txn, &key, &&*new)
            .map_err(convert_error)?;
        txn.commit().map_err(convert_error)?;

        Ok(new)
    }

    /// Iterates forward from `prefix` and stops at the first key that does
    /// not start with it.
    #[tracing::instrument(skip(self, prefix))]
    fn scan_prefix<'a>(
        &'a self,
        prefix: Vec<u8>,
    ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + Send + 'a> {
        let entries = self.iter_from(&prefix, false);
        Box::new(entries.take_while(move |(key, _)| key.starts_with(&prefix)))
    }

    /// Registers a watcher on `prefix` and returns a future that completes
    /// once `insert` writes a key starting with it.
    #[tracing::instrument(skip(self, prefix))]
    fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        let (notifier, notified) = tokio::sync::oneshot::channel();

        {
            let mut watchers = self.watchers.write().unwrap();
            watchers.entry(prefix.to_vec()).or_default().push(notifier);
        }

        Box::pin(async move {
            // The sender stays in `self.watchers` until `insert` fires it, so
            // the receiver is expected to always get a value (Tx is never
            // destroyed).
            notified.await.unwrap();
        })
    }
}

View file

@ -863,6 +863,7 @@ impl Rooms {
if let Some(body) = pdu.content.get("body").and_then(|b| b.as_str()) {
for word in body
.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
{
let mut key = pdu.room_id.as_bytes().to_vec();

View file

@ -81,10 +81,10 @@ pub enum SendingEventType {
pub struct Sending {
/// The state for a given state hash.
pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync
pub(super) servernamepduids: Arc<dyn Tree>, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId
pub(super) servercurrentevents: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent
pub(super) servernameevent_data: Arc<dyn Tree>, // ServernamEvent = (+ / $)SenderKey / ServerName / UserId + PduId / * (for edus), Data = EDU content
pub(super) servercurrentevent_data: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / * (for edus), Data = EDU content
pub(super) maximum_requests: Arc<Semaphore>,
pub sender: mpsc::UnboundedSender<Vec<u8>>,
pub sender: mpsc::UnboundedSender<(Vec<u8>, Vec<u8>)>,
}
enum TransactionStatus {
@ -97,7 +97,7 @@ impl Sending {
pub fn start_handler(
&self,
db: Arc<RwLock<Database>>,
mut receiver: mpsc::UnboundedReceiver<Vec<u8>>,
mut receiver: mpsc::UnboundedReceiver<(Vec<u8>, Vec<u8>)>,
) {
tokio::spawn(async move {
let mut futures = FuturesUnordered::new();
@ -109,16 +109,15 @@ impl Sending {
let guard = db.read().await;
for (key, outgoing_kind, event) in
guard
.sending
.servercurrentevents
.iter()
.filter_map(|(key, _)| {
Self::parse_servercurrentevent(&key)
.ok()
.map(|(k, e)| (key, k, e))
})
for (key, outgoing_kind, event) in guard
.sending
.servercurrentevent_data
.iter()
.filter_map(|(key, v)| {
Self::parse_servercurrentevent(&key, v)
.ok()
.map(|(k, e)| (key, k, e))
})
{
let entry = initial_transactions
.entry(outgoing_kind.clone())
@ -129,7 +128,7 @@ impl Sending {
"Dropping some current events: {:?} {:?} {:?}",
key, outgoing_kind, event
);
guard.sending.servercurrentevents.remove(&key).unwrap();
guard.sending.servercurrentevent_data.remove(&key).unwrap();
continue;
}
@ -156,17 +155,17 @@ impl Sending {
let guard = db.read().await;
let prefix = outgoing_kind.get_prefix();
for (key, _) in guard.sending.servercurrentevents
for (key, _) in guard.sending.servercurrentevent_data
.scan_prefix(prefix.clone())
{
guard.sending.servercurrentevents.remove(&key).unwrap();
guard.sending.servercurrentevent_data.remove(&key).unwrap();
}
// Find events that have been added since starting the last request
let new_events = guard.sending.servernamepduids
let new_events = guard.sending.servernameevent_data
.scan_prefix(prefix.clone())
.filter_map(|(k, _)| {
Self::parse_servercurrentevent(&k).ok().map(|ev| (ev, k))
.filter_map(|(k, v)| {
Self::parse_servercurrentevent(&k, v).ok().map(|ev| (ev, k))
})
.take(30)
.collect::<Vec<_>>();
@ -175,9 +174,10 @@ impl Sending {
if !new_events.is_empty() {
// Insert pdus we found
for (_, key) in &new_events {
guard.sending.servercurrentevents.insert(&key, &[]).unwrap();
guard.sending.servernamepduids.remove(&key).unwrap();
for (e, key) in &new_events {
let value = if let SendingEventType::Edu(value) = &e.1 { &**value } else { &[] };
guard.sending.servercurrentevent_data.insert(&key, value).unwrap();
guard.sending.servernameevent_data.remove(&key).unwrap();
}
drop(guard);
@ -205,8 +205,8 @@ impl Sending {
}
};
},
Some(key) = receiver.next() => {
if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) {
Some((key, value)) = receiver.next() => {
if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key, value) {
let guard = db.read().await;
if let Ok(Some(events)) = Self::select_events(
@ -267,18 +267,25 @@ impl Sending {
if retry {
// We retry the previous transaction
for (key, _) in db.sending.servercurrentevents.scan_prefix(prefix) {
if let Ok((_, e)) = Self::parse_servercurrentevent(&key) {
for (key, value) in db.sending.servercurrentevent_data.scan_prefix(prefix) {
if let Ok((_, e)) = Self::parse_servercurrentevent(&key, value) {
events.push(e);
}
}
} else {
for (e, full_key) in new_events {
db.sending.servercurrentevents.insert(&full_key, &[])?;
let value = if let SendingEventType::Edu(value) = &e {
&**value
} else {
&[][..]
};
db.sending
.servercurrentevent_data
.insert(&full_key, value)?;
// If it was a PDU we have to unqueue it
// TODO: don't try to unqueue EDUs
db.sending.servernamepduids.remove(&full_key)?;
db.sending.servernameevent_data.remove(&full_key)?;
events.push(e);
}
@ -380,8 +387,8 @@ impl Sending {
key.extend_from_slice(&senderkey);
key.push(0xff);
key.extend_from_slice(pdu_id);
self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
self.servernameevent_data.insert(&key, &[])?;
self.sender.unbounded_send((key, vec![])).unwrap();
Ok(())
}
@ -391,20 +398,19 @@ impl Sending {
let mut key = server.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(pdu_id);
self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
self.servernameevent_data.insert(&key, &[])?;
self.sender.unbounded_send((key, vec![])).unwrap();
Ok(())
}
#[tracing::instrument(skip(self, server, serialized))]
pub fn send_reliable_edu(&self, server: &ServerName, serialized: &[u8]) -> Result<()> {
pub fn send_reliable_edu(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> {
let mut key = server.as_bytes().to_vec();
key.push(0xff);
key.push(b'*');
key.extend_from_slice(serialized);
self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
self.servernameevent_data.insert(&key, &serialized)?;
self.sender.unbounded_send((key, serialized)).unwrap();
Ok(())
}
@ -415,8 +421,8 @@ impl Sending {
key.extend_from_slice(appservice_id.as_bytes());
key.push(0xff);
key.extend_from_slice(pdu_id);
self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
self.servernameevent_data.insert(&key, &[])?;
self.sender.unbounded_send((key, vec![])).unwrap();
Ok(())
}
@ -451,7 +457,7 @@ impl Sending {
(
kind.clone(),
Error::bad_database(
"[Appservice] Event in servernamepduids not found in db.",
"[Appservice] Event in servernameevent_data not found in db.",
),
)
})?
@ -508,7 +514,7 @@ impl Sending {
(
kind.clone(),
Error::bad_database(
"[Push] Event in servernamepduids not found in db.",
"[Push] Event in servernamevent_datas not found in db.",
),
)
})?,
@ -602,7 +608,7 @@ impl Sending {
(
OutgoingKind::Normal(server.clone()),
Error::bad_database(
"[Normal] Event in servernamepduids not found in db.",
"[Normal] Event in servernamevent_datas not found in db.",
),
)
})?,
@ -662,7 +668,10 @@ impl Sending {
}
#[tracing::instrument(skip(key))]
fn parse_servercurrentevent(key: &[u8]) -> Result<(OutgoingKind, SendingEventType)> {
fn parse_servercurrentevent(
key: &[u8],
value: Vec<u8>,
) -> Result<(OutgoingKind, SendingEventType)> {
// Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
@ -680,7 +689,7 @@ impl Sending {
Error::bad_database("Invalid server string in server_currenttransaction")
})?),
if event.starts_with(b"*") {
SendingEventType::Edu(event[1..].to_vec())
SendingEventType::Edu(value.to_vec())
} else {
SendingEventType::Pdu(event.to_vec())
},