Refactor for structured insertions

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-10-07 17:54:27 +00:00 committed by strawberry
parent 8258d16a94
commit 2ed0c267eb
31 changed files with 364 additions and 621 deletions

View file

@ -3,18 +3,19 @@ use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc};
use conduit::{
debug_warn, err, utils,
utils::{stream::TryIgnore, string::Unquoted, ReadyExt},
warn, Err, Error, Result, Server,
Err, Error, Result, Server,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt};
use database::{Deserialized, Ignore, Interfix, Json, Map};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType, StateEventType},
events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use serde_json::json;
use crate::{account_data, admin, globals, rooms, Dep};
@ -194,22 +195,16 @@ impl Service {
/// Hash and set the user's password to the Argon2 hash
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = utils::hash::password(password) {
self.db
.userid_password
.insert(user_id.as_bytes(), hash.as_bytes());
Ok(())
} else {
Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Password does not meet the requirements.",
))
}
} else {
self.db.userid_password.insert(user_id.as_bytes(), b"");
Ok(())
}
password
.map(utils::hash::password)
.transpose()
.map_err(|e| err!(Request(InvalidParam("Password does not meet the requirements: {e}"))))?
.map_or_else(
|| self.db.userid_password.insert(user_id, b""),
|hash| self.db.userid_password.insert(user_id, hash),
);
Ok(())
}
/// Returns the displayname of a user on this homeserver.
@ -221,11 +216,9 @@ impl Service {
/// need to nofify all rooms of this change.
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
if let Some(displayname) = displayname {
self.db
.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes());
self.db.userid_displayname.insert(user_id, displayname);
} else {
self.db.userid_displayname.remove(user_id.as_bytes());
self.db.userid_displayname.remove(user_id);
}
}
@ -237,11 +230,9 @@ impl Service {
/// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
if let Some(avatar_url) = avatar_url {
self.db
.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes());
self.db.userid_avatarurl.insert(user_id, &avatar_url);
} else {
self.db.userid_avatarurl.remove(user_id.as_bytes());
self.db.userid_avatarurl.remove(user_id);
}
}
@ -253,11 +244,9 @@ impl Service {
/// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) {
if let Some(blurhash) = blurhash {
self.db
.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes());
self.db.userid_blurhash.insert(user_id, blurhash);
} else {
self.db.userid_blurhash.remove(user_id.as_bytes());
self.db.userid_blurhash.remove(user_id);
}
}
@ -269,41 +258,29 @@ impl Service {
// This method should never be called for nonexistent users. We shouldn't assert
// though...
if !self.exists(user_id).await {
warn!("Called create_device for non-existent user {} in database", user_id);
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
return Err!(Request(InvalidParam(error!("Called create_device for non-existent {user_id}"))));
}
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let key = (user_id, device_id);
let val = Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: client_ip,
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
};
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(&Device {
device_id: device_id.into(),
display_name: initial_device_display_name,
last_seen_ip: client_ip,
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
})
.expect("Device::to_string never fails."),
);
self.set_token(user_id, device_id, token).await?;
Ok(())
self.db.userdeviceid_metadata.put(key, Json(val));
self.set_token(user_id, device_id, token).await
}
/// Removes a device from a user.
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
let userdeviceid = (user_id, device_id);
// Remove tokens
if let Ok(old_token) = self.db.userdeviceid_token.get(&userdeviceid).await {
self.db.userdeviceid_token.remove(&userdeviceid);
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
self.db.userdeviceid_token.del(userdeviceid);
self.db.token_userdeviceid.remove(&old_token);
}
@ -320,7 +297,7 @@ impl Service {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
self.db.userdeviceid_metadata.remove(&userdeviceid);
self.db.userdeviceid_metadata.del(userdeviceid);
}
/// Returns an iterator over all device ids of this user.
@ -333,6 +310,11 @@ impl Service {
.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
}
pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
let key = (user_id, device_id);
self.db.userdeviceid_token.qry(&key).await.deserialized()
}
/// Replaces the access token of one device.
pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let key = (user_id, device_id);
@ -352,15 +334,8 @@ impl Service {
}
// Assign token to user device combination
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.db
.userdeviceid_token
.insert(&userdeviceid, token.as_bytes());
self.db
.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid);
self.db.userdeviceid_token.put_raw(key, token);
self.db.token_userdeviceid.raw_put(token, key);
Ok(())
}
@ -393,14 +368,12 @@ impl Service {
.as_bytes(),
);
self.db.onetimekeyid_onetimekeys.insert(
&key,
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
);
self.db
.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());
.onetimekeyid_onetimekeys
.raw_put(key, Json(one_time_key_value));
let count = self.services.globals.next_count().unwrap();
self.db.userid_lastonetimekeyupdate.raw_put(user_id, count);
Ok(())
}
@ -417,9 +390,8 @@ impl Service {
pub async fn take_one_time_key(
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
) -> Result<(OwnedDeviceKeyId, Raw<OneTimeKey>)> {
self.db
.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes());
let count = self.services.globals.next_count()?.to_be_bytes();
self.db.userid_lastonetimekeyupdate.insert(user_id, count);
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
@ -488,15 +460,9 @@ impl Service {
}
pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.db.keyid_key.insert(
&userdeviceid,
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
);
let key = (user_id, device_id);
self.db.keyid_key.put(key, Json(device_keys));
self.mark_device_key_update(user_id).await;
}
@ -611,13 +577,8 @@ impl Service {
.ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))?
.insert(signature.0, signature.1.into());
let mut key = target_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(key_id.as_bytes());
self.db.keyid_key.insert(
&key,
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
);
let key = (target_id, key_id);
self.db.keyid_key.put(key, Json(cross_signing_key));
self.mark_device_key_update(target_id).await;
@ -640,34 +601,21 @@ impl Service {
}
pub async fn mark_device_key_update(&self, user_id: &UserId) {
let count = self.services.globals.next_count().unwrap().to_be_bytes();
let count = self.services.globals.next_count().unwrap();
let rooms_joined = self.services.state_cache.rooms_joined(user_id);
pin_mut!(rooms_joined);
while let Some(room_id) = rooms_joined.next().await {
self.services
.state_cache
.rooms_joined(user_id)
// Don't send key updates to unencrypted rooms
if self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomEncryption, "")
.await
.is_err()
{
continue;
}
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
.ready_for_each(|room_id| {
let key = (room_id, count);
self.db.keychangeid_userid.put_raw(key, user_id);
})
.await;
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
}
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&count);
self.db.keychangeid_userid.insert(&key, user_id.as_bytes());
let key = (user_id, count);
self.db.keychangeid_userid.put_raw(key, user_id);
}
pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
@ -681,12 +629,7 @@ impl Service {
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let key = self
.db
.keyid_key
.get(key_id)
.await
.deserialized::<serde_json::Value>()?;
let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?;
let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
let raw_value = serde_json::value::to_raw_value(&cleaned)?;
@ -718,29 +661,29 @@ impl Service {
}
pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
let key_id = self.db.userid_usersigningkeyid.get(user_id).await?;
self.db.keyid_key.get(&*key_id).await.deserialized()
self.db
.userid_usersigningkeyid
.get(user_id)
.and_then(|key_id| self.db.keyid_key.get(&*key_id))
.await
.deserialized()
}
pub async fn add_to_device_event(
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value,
) {
let mut key = target_user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
let count = self.services.globals.next_count().unwrap();
let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into());
json.insert("sender".to_owned(), sender.to_string().into());
json.insert("content".to_owned(), content);
let value = serde_json::to_vec(&json).expect("Map::to_vec always works");
self.db.todeviceid_events.insert(&key, &value);
let key = (target_user_id, target_device_id, count);
self.db.todeviceid_events.put(
key,
Json(json!({
"type": event_type,
"sender": sender,
"content": content,
})),
);
}
pub fn get_to_device_events<'a>(
@ -783,13 +726,8 @@ impl Service {
pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes());
self.db.userdeviceid_metadata.insert(
&userdeviceid,
&serde_json::to_vec(device).expect("Device::to_string always works"),
);
let key = (user_id, device_id);
self.db.userdeviceid_metadata.put(key, Json(device));
Ok(())
}
@ -824,23 +762,15 @@ impl Service {
pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
let filter_id = utils::random_string(4);
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes());
self.db
.userfilterid_filter
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"));
let key = (user_id, &filter_id);
self.db.userfilterid_filter.put(key, Json(filter));
filter_id
}
pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
self.db
.userfilterid_filter
.qry(&(user_id, filter_id))
.await
.deserialized()
let key = (user_id, filter_id);
self.db.userfilterid_filter.qry(&key).await.deserialized()
}
/// Creates an OpenID token, which can be used to prove that a user has
@ -913,17 +843,13 @@ impl Service {
/// Sets a new profile key value, removes the key if value is None
pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(profile_key.as_bytes());
// TODO: insert to the stable MSC4175 key when it's stable
if let Some(value) = profile_key_value {
let value = serde_json::to_vec(&value).unwrap();
let key = (user_id, profile_key);
self.db.useridprofilekey_value.insert(&key, &value);
if let Some(value) = profile_key_value {
self.db.useridprofilekey_value.put(key, value);
} else {
self.db.useridprofilekey_value.remove(&key);
self.db.useridprofilekey_value.del(key);
}
}
@ -945,17 +871,13 @@ impl Service {
/// Sets a new timezone or removes it if timezone is None.
pub fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(b"us.cloke.msc4175.tz");
// TODO: insert to the stable MSC4175 key when it's stable
let key = (user_id, "us.cloke.msc4175.tz");
if let Some(timezone) = timezone {
self.db
.useridprofilekey_value
.insert(&key, timezone.as_bytes());
self.db.useridprofilekey_value.put_raw(key, &timezone);
} else {
self.db.useridprofilekey_value.remove(&key);
self.db.useridprofilekey_value.del(key);
}
}
}
@ -1012,5 +934,5 @@ where
fn increment(db: &Arc<Map>, key: &[u8]) {
let old = db.get_blocking(key);
let new = utils::increment(old.ok().as_deref());
db.insert(key, &new);
db.insert(key, new);
}