feat: swappable database backend

This commit is contained in:
Timo Kösters 2021-06-08 18:10:00 +02:00
parent 81715bd84d
commit d0ee823254
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
47 changed files with 1434 additions and 981 deletions

View file

@ -13,22 +13,23 @@ use std::{
use tokio::sync::Semaphore;
use trust_dns_resolver::TokioAsyncResolver;
pub const COUNTER: &str = "c";
use super::abstraction::Tree;
pub const COUNTER: &[u8] = b"c";
type WellKnownMap = HashMap<Box<ServerName>, (String, String)>;
type TlsNameMap = HashMap<String, webpki::DNSName>;
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
#[derive(Clone)]
pub struct Globals {
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
pub(super) globals: sled::Tree,
pub(super) globals: Arc<dyn Tree>,
config: Config,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
reqwest_client: reqwest::Client,
dns_resolver: TokioAsyncResolver,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>,
pub(super) server_signingkeys: sled::Tree,
pub(super) server_signingkeys: Arc<dyn Tree>,
pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>,
pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>,
pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>,
@ -69,15 +70,20 @@ impl ServerCertVerifier for MatrixServerVerifier {
impl Globals {
pub fn load(
globals: sled::Tree,
server_signingkeys: sled::Tree,
globals: Arc<dyn Tree>,
server_signingkeys: Arc<dyn Tree>,
config: Config,
) -> Result<Self> {
let bytes = &*globals
.update_and_fetch("keypair", utils::generate_keypair)?
.expect("utils::generate_keypair always returns Some");
let keypair_bytes = globals.get(b"keypair")?.map_or_else(
|| {
let keypair = utils::generate_keypair();
globals.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
|s| Ok(s.to_vec()),
)?;
let mut parts = bytes.splitn(2, |&b| b == 0xff);
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
let keypair = utils::string_from_bytes(
// 1. version
@ -102,7 +108,7 @@ impl Globals {
Ok(k) => k,
Err(e) => {
error!("Keypair invalid. Deleting...");
globals.remove("keypair")?;
globals.remove(b"keypair")?;
return Err(e);
}
};
@ -159,13 +165,8 @@ impl Globals {
}
pub fn next_count(&self) -> Result<u64> {
Ok(utils::u64_from_bytes(
&self
.globals
.update_and_fetch(COUNTER, utils::increment)?
.expect("utils::increment will always put in a value"),
)
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
Ok(utils::u64_from_bytes(&self.globals.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
}
pub fn current_count(&self) -> Result<u64> {
@ -211,21 +212,30 @@ impl Globals {
/// Remove the outdated keys and insert the new ones.
///
/// This doesn't actually check that the keys provided are newer than the old set.
pub fn add_signing_key(&self, origin: &ServerName, new_keys: &ServerSigningKeys) -> Result<()> {
self.server_signingkeys
.update_and_fetch(origin.as_bytes(), |signingkeys| {
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
keys.verify_keys
.extend(new_keys.verify_keys.clone().into_iter());
keys.old_verify_keys
.extend(new_keys.old_verify_keys.clone().into_iter());
Some(serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"))
})?;
pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> {
// Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
});
let ServerSigningKeys {
verify_keys,
old_verify_keys,
..
} = new_keys;
keys.verify_keys.extend(verify_keys.into_iter());
keys.old_verify_keys.extend(old_verify_keys.into_iter());
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
Ok(())
}
@ -254,14 +264,15 @@ impl Globals {
}
pub fn database_version(&self) -> Result<u64> {
self.globals.get("version")?.map_or(Ok(0), |version| {
self.globals.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version)
.map_err(|_| Error::bad_database("Database version id is invalid."))
})
}
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.globals.insert("version", &new_version.to_be_bytes())?;
self.globals
.insert(b"version", &new_version.to_be_bytes())?;
Ok(())
}
}