Refactor for structured insertions
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
8258d16a94
commit
2ed0c267eb
31 changed files with 364 additions and 621 deletions
|
@ -3,18 +3,19 @@ use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc};
|
|||
use conduit::{
|
||||
debug_warn, err, utils,
|
||||
utils::{stream::TryIgnore, string::Unquoted, ReadyExt},
|
||||
warn, Err, Error, Result, Server,
|
||||
Err, Error, Result, Server,
|
||||
};
|
||||
use database::{Deserialized, Ignore, Interfix, Map};
|
||||
use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt};
|
||||
use database::{Deserialized, Ignore, Interfix, Json, Map};
|
||||
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
|
||||
use ruma::{
|
||||
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
|
||||
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
||||
events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType, StateEventType},
|
||||
events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType},
|
||||
serde::Raw,
|
||||
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
|
||||
OwnedMxcUri, OwnedUserId, UInt, UserId,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{account_data, admin, globals, rooms, Dep};
|
||||
|
||||
|
@ -194,22 +195,16 @@ impl Service {
|
|||
|
||||
/// Hash and set the user's password to the Argon2 hash
|
||||
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||
if let Some(password) = password {
|
||||
if let Ok(hash) = utils::hash::password(password) {
|
||||
self.db
|
||||
.userid_password
|
||||
.insert(user_id.as_bytes(), hash.as_bytes());
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Password does not meet the requirements.",
|
||||
))
|
||||
}
|
||||
} else {
|
||||
self.db.userid_password.insert(user_id.as_bytes(), b"");
|
||||
Ok(())
|
||||
}
|
||||
password
|
||||
.map(utils::hash::password)
|
||||
.transpose()
|
||||
.map_err(|e| err!(Request(InvalidParam("Password does not meet the requirements: {e}"))))?
|
||||
.map_or_else(
|
||||
|| self.db.userid_password.insert(user_id, b""),
|
||||
|hash| self.db.userid_password.insert(user_id, hash),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the displayname of a user on this homeserver.
|
||||
|
@ -221,11 +216,9 @@ impl Service {
|
|||
/// need to nofify all rooms of this change.
|
||||
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
|
||||
if let Some(displayname) = displayname {
|
||||
self.db
|
||||
.userid_displayname
|
||||
.insert(user_id.as_bytes(), displayname.as_bytes());
|
||||
self.db.userid_displayname.insert(user_id, displayname);
|
||||
} else {
|
||||
self.db.userid_displayname.remove(user_id.as_bytes());
|
||||
self.db.userid_displayname.remove(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -237,11 +230,9 @@ impl Service {
|
|||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
|
||||
if let Some(avatar_url) = avatar_url {
|
||||
self.db
|
||||
.userid_avatarurl
|
||||
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes());
|
||||
self.db.userid_avatarurl.insert(user_id, &avatar_url);
|
||||
} else {
|
||||
self.db.userid_avatarurl.remove(user_id.as_bytes());
|
||||
self.db.userid_avatarurl.remove(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,11 +244,9 @@ impl Service {
|
|||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) {
|
||||
if let Some(blurhash) = blurhash {
|
||||
self.db
|
||||
.userid_blurhash
|
||||
.insert(user_id.as_bytes(), blurhash.as_bytes());
|
||||
self.db.userid_blurhash.insert(user_id, blurhash);
|
||||
} else {
|
||||
self.db.userid_blurhash.remove(user_id.as_bytes());
|
||||
self.db.userid_blurhash.remove(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -269,41 +258,29 @@ impl Service {
|
|||
// This method should never be called for nonexistent users. We shouldn't assert
|
||||
// though...
|
||||
if !self.exists(user_id).await {
|
||||
warn!("Called create_device for non-existent user {} in database", user_id);
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
|
||||
return Err!(Request(InvalidParam(error!("Called create_device for non-existent {user_id}"))));
|
||||
}
|
||||
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
let key = (user_id, device_id);
|
||||
let val = Device {
|
||||
device_id: device_id.into(),
|
||||
display_name: initial_device_display_name,
|
||||
last_seen_ip: client_ip,
|
||||
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
};
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
self.db.userdeviceid_metadata.insert(
|
||||
&userdeviceid,
|
||||
&serde_json::to_vec(&Device {
|
||||
device_id: device_id.into(),
|
||||
display_name: initial_device_display_name,
|
||||
last_seen_ip: client_ip,
|
||||
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
})
|
||||
.expect("Device::to_string never fails."),
|
||||
);
|
||||
|
||||
self.set_token(user_id, device_id, token).await?;
|
||||
|
||||
Ok(())
|
||||
self.db.userdeviceid_metadata.put(key, Json(val));
|
||||
self.set_token(user_id, device_id, token).await
|
||||
}
|
||||
|
||||
/// Removes a device from a user.
|
||||
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
let userdeviceid = (user_id, device_id);
|
||||
|
||||
// Remove tokens
|
||||
if let Ok(old_token) = self.db.userdeviceid_token.get(&userdeviceid).await {
|
||||
self.db.userdeviceid_token.remove(&userdeviceid);
|
||||
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
|
||||
self.db.userdeviceid_token.del(userdeviceid);
|
||||
self.db.token_userdeviceid.remove(&old_token);
|
||||
}
|
||||
|
||||
|
@ -320,7 +297,7 @@ impl Service {
|
|||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
self.db.userdeviceid_metadata.remove(&userdeviceid);
|
||||
self.db.userdeviceid_metadata.del(userdeviceid);
|
||||
}
|
||||
|
||||
/// Returns an iterator over all device ids of this user.
|
||||
|
@ -333,6 +310,11 @@ impl Service {
|
|||
.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
|
||||
}
|
||||
|
||||
pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
|
||||
let key = (user_id, device_id);
|
||||
self.db.userdeviceid_token.qry(&key).await.deserialized()
|
||||
}
|
||||
|
||||
/// Replaces the access token of one device.
|
||||
pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
||||
let key = (user_id, device_id);
|
||||
|
@ -352,15 +334,8 @@ impl Service {
|
|||
}
|
||||
|
||||
// Assign token to user device combination
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
self.db
|
||||
.userdeviceid_token
|
||||
.insert(&userdeviceid, token.as_bytes());
|
||||
self.db
|
||||
.token_userdeviceid
|
||||
.insert(token.as_bytes(), &userdeviceid);
|
||||
self.db.userdeviceid_token.put_raw(key, token);
|
||||
self.db.token_userdeviceid.raw_put(token, key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -393,14 +368,12 @@ impl Service {
|
|||
.as_bytes(),
|
||||
);
|
||||
|
||||
self.db.onetimekeyid_onetimekeys.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
|
||||
);
|
||||
|
||||
self.db
|
||||
.userid_lastonetimekeyupdate
|
||||
.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());
|
||||
.onetimekeyid_onetimekeys
|
||||
.raw_put(key, Json(one_time_key_value));
|
||||
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
self.db.userid_lastonetimekeyupdate.raw_put(user_id, count);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -417,9 +390,8 @@ impl Service {
|
|||
pub async fn take_one_time_key(
|
||||
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
|
||||
) -> Result<(OwnedDeviceKeyId, Raw<OneTimeKey>)> {
|
||||
self.db
|
||||
.userid_lastonetimekeyupdate
|
||||
.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());
|
||||
let count = self.services.globals.next_count()?.to_be_bytes();
|
||||
self.db.userid_lastonetimekeyupdate.insert(user_id, count);
|
||||
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
@ -488,15 +460,9 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.db.keyid_key.insert(
|
||||
&userdeviceid,
|
||||
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
|
||||
);
|
||||
let key = (user_id, device_id);
|
||||
|
||||
self.db.keyid_key.put(key, Json(device_keys));
|
||||
self.mark_device_key_update(user_id).await;
|
||||
}
|
||||
|
||||
|
@ -611,13 +577,8 @@ impl Service {
|
|||
.ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))?
|
||||
.insert(signature.0, signature.1.into());
|
||||
|
||||
let mut key = target_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(key_id.as_bytes());
|
||||
self.db.keyid_key.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
|
||||
);
|
||||
let key = (target_id, key_id);
|
||||
self.db.keyid_key.put(key, Json(cross_signing_key));
|
||||
|
||||
self.mark_device_key_update(target_id).await;
|
||||
|
||||
|
@ -640,34 +601,21 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn mark_device_key_update(&self, user_id: &UserId) {
|
||||
let count = self.services.globals.next_count().unwrap().to_be_bytes();
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
|
||||
let rooms_joined = self.services.state_cache.rooms_joined(user_id);
|
||||
|
||||
pin_mut!(rooms_joined);
|
||||
while let Some(room_id) = rooms_joined.next().await {
|
||||
self.services
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
// Don't send key updates to unencrypted rooms
|
||||
if self
|
||||
.services
|
||||
.state_accessor
|
||||
.room_state_get(room_id, &StateEventType::RoomEncryption, "")
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
|
||||
.ready_for_each(|room_id| {
|
||||
let key = (room_id, count);
|
||||
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&count);
|
||||
|
||||
self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
|
||||
}
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&count);
|
||||
self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
|
||||
let key = (user_id, count);
|
||||
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||
}
|
||||
|
||||
pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
|
||||
|
@ -681,12 +629,7 @@ impl Service {
|
|||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
{
|
||||
let key = self
|
||||
.db
|
||||
.keyid_key
|
||||
.get(key_id)
|
||||
.await
|
||||
.deserialized::<serde_json::Value>()?;
|
||||
let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?;
|
||||
|
||||
let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
|
||||
let raw_value = serde_json::value::to_raw_value(&cleaned)?;
|
||||
|
@ -718,29 +661,29 @@ impl Service {
|
|||
}
|
||||
|
||||
pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
|
||||
let key_id = self.db.userid_usersigningkeyid.get(user_id).await?;
|
||||
|
||||
self.db.keyid_key.get(&*key_id).await.deserialized()
|
||||
self.db
|
||||
.userid_usersigningkeyid
|
||||
.get(user_id)
|
||||
.and_then(|key_id| self.db.keyid_key.get(&*key_id))
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
pub async fn add_to_device_event(
|
||||
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
|
||||
content: serde_json::Value,
|
||||
) {
|
||||
let mut key = target_user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(target_device_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
|
||||
let mut json = serde_json::Map::new();
|
||||
json.insert("type".to_owned(), event_type.to_owned().into());
|
||||
json.insert("sender".to_owned(), sender.to_string().into());
|
||||
json.insert("content".to_owned(), content);
|
||||
|
||||
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
|
||||
|
||||
self.db.todeviceid_events.insert(&key, &value);
|
||||
let key = (target_user_id, target_device_id, count);
|
||||
self.db.todeviceid_events.put(
|
||||
key,
|
||||
Json(json!({
|
||||
"type": event_type,
|
||||
"sender": sender,
|
||||
"content": content,
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get_to_device_events<'a>(
|
||||
|
@ -783,13 +726,8 @@ impl Service {
|
|||
pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
self.db.userdeviceid_metadata.insert(
|
||||
&userdeviceid,
|
||||
&serde_json::to_vec(device).expect("Device::to_string always works"),
|
||||
);
|
||||
let key = (user_id, device_id);
|
||||
self.db.userdeviceid_metadata.put(key, Json(device));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -824,23 +762,15 @@ impl Service {
|
|||
pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
|
||||
let filter_id = utils::random_string(4);
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(filter_id.as_bytes());
|
||||
|
||||
self.db
|
||||
.userfilterid_filter
|
||||
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"));
|
||||
let key = (user_id, &filter_id);
|
||||
self.db.userfilterid_filter.put(key, Json(filter));
|
||||
|
||||
filter_id
|
||||
}
|
||||
|
||||
pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
|
||||
self.db
|
||||
.userfilterid_filter
|
||||
.qry(&(user_id, filter_id))
|
||||
.await
|
||||
.deserialized()
|
||||
let key = (user_id, filter_id);
|
||||
self.db.userfilterid_filter.qry(&key).await.deserialized()
|
||||
}
|
||||
|
||||
/// Creates an OpenID token, which can be used to prove that a user has
|
||||
|
@ -913,17 +843,13 @@ impl Service {
|
|||
|
||||
/// Sets a new profile key value, removes the key if value is None
|
||||
pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(profile_key.as_bytes());
|
||||
|
||||
// TODO: insert to the stable MSC4175 key when it's stable
|
||||
if let Some(value) = profile_key_value {
|
||||
let value = serde_json::to_vec(&value).unwrap();
|
||||
let key = (user_id, profile_key);
|
||||
|
||||
self.db.useridprofilekey_value.insert(&key, &value);
|
||||
if let Some(value) = profile_key_value {
|
||||
self.db.useridprofilekey_value.put(key, value);
|
||||
} else {
|
||||
self.db.useridprofilekey_value.remove(&key);
|
||||
self.db.useridprofilekey_value.del(key);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -945,17 +871,13 @@ impl Service {
|
|||
|
||||
/// Sets a new timezone or removes it if timezone is None.
|
||||
pub fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(b"us.cloke.msc4175.tz");
|
||||
|
||||
// TODO: insert to the stable MSC4175 key when it's stable
|
||||
let key = (user_id, "us.cloke.msc4175.tz");
|
||||
|
||||
if let Some(timezone) = timezone {
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
.insert(&key, timezone.as_bytes());
|
||||
self.db.useridprofilekey_value.put_raw(key, &timezone);
|
||||
} else {
|
||||
self.db.useridprofilekey_value.remove(&key);
|
||||
self.db.useridprofilekey_value.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1012,5 +934,5 @@ where
|
|||
fn increment(db: &Arc<Map>, key: &[u8]) {
|
||||
let old = db.get_blocking(key);
|
||||
let new = utils::increment(old.ok().as_deref());
|
||||
db.insert(key, &new);
|
||||
db.insert(key, new);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue