refactor: split database into multiple files, more error handling, cleaner code

This commit is contained in:
timokoesters 2020-05-03 17:25:31 +02:00
parent 4b191a9311
commit 8f67c01efd
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
17 changed files with 1573 additions and 1630 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,92 +1,19 @@
use crate::utils;
pub(self) mod account_data;
pub(self) mod globals;
pub(self) mod rooms;
pub(self) mod users;
use directories::ProjectDirs;
use sled::IVec;
use std::fs::remove_dir_all;
pub struct MultiValue(sled::Tree);
pub const COUNTER: &str = "c";
impl MultiValue {
/// Get an iterator over all values.
pub fn iter_all(&self) -> sled::Iter {
self.0.scan_prefix(b"d")
}
/// Get an iterator over all values of this id.
pub fn get_iter(&self, id: &[u8]) -> sled::Iter {
// Data keys start with d
let mut key = vec![b'd'];
key.extend_from_slice(id.as_ref());
key.push(0xff); // Add delimiter so we don't find keys starting with the same id
self.0.scan_prefix(key)
}
pub fn clear(&self, id: &[u8]) {
for key in self.get_iter(id).keys() {
self.0.remove(key.unwrap()).unwrap();
}
}
pub fn remove_value(&self, id: &[u8], value: &[u8]) {
if let Some(key) = self
.get_iter(id)
.find(|t| &t.as_ref().unwrap().1 == value)
.map(|t| t.unwrap().0)
{
self.0.remove(key).unwrap();
}
}
/// Add another value to the id.
pub fn add(&self, id: &[u8], value: IVec) {
// The new value will need a new index. We store the last used index in 'n' + id
let mut count_key: Vec<u8> = vec![b'n'];
count_key.extend_from_slice(id.as_ref());
// Increment the last index and use that
let index = self
.0
.update_and_fetch(&count_key, utils::increment)
.unwrap()
.unwrap();
// Data keys start with d
let mut key = vec![b'd'];
key.extend_from_slice(id.as_ref());
key.push(0xff);
key.extend_from_slice(&index);
self.0.insert(key, value).unwrap();
}
}
pub struct Database {
pub userid_password: sled::Tree,
pub userid_displayname: sled::Tree,
pub userid_avatarurl: sled::Tree,
pub userid_deviceids: MultiValue,
pub userdeviceid_token: sled::Tree,
pub token_userid: sled::Tree,
pub pduid_pdu: sled::Tree, // PduId = RoomId + Count
pub eventid_pduid: sled::Tree,
pub roomid_pduleaves: MultiValue,
pub roomstateid_pdu: sled::Tree, // Room + StateType + StateKey
pub roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type
pub roomuserid_lastread: sled::Tree, // RoomUserId = Room + User
pub roomid_joinuserids: MultiValue,
pub roomid_inviteuserids: MultiValue,
pub userid_joinroomids: MultiValue,
pub userid_inviteroomids: MultiValue,
pub userid_leftroomids: MultiValue,
// EDUs:
pub roomlatestid_roomlatest: sled::Tree, // Read Receipts, RoomLatestId = RoomId + Count + UserId TODO: Types
pub roomactiveid_roomactive: sled::Tree, // Typing, RoomActiveId = TimeoutTime + Count
pub globalallid_globalall: sled::Tree, // ToDevice, GlobalAllId = UserId + Count
pub globallatestid_globallatest: sled::Tree, // Presence, GlobalLatestId = Count + Type + UserId
pub keypair: ruma_signatures::Ed25519KeyPair,
pub global: sled::Db,
pub globals: globals::Globals,
pub users: users::Users,
pub rooms: rooms::Rooms,
pub account_data: account_data::AccountData,
//pub globalallid_globalall: sled::Tree, // ToDevice, GlobalAllId = UserId + Count
//pub globallatestid_globallatest: sled::Tree, // Presence, GlobalLatestId = Count + Type + UserId
pub _db: sled::Db,
}
impl Database {
@ -110,166 +37,38 @@ impl Database {
let db = sled::open(&path).unwrap();
Self {
userid_password: db.open_tree("userid_password").unwrap(),
userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()),
userid_displayname: db.open_tree("userid_displayname").unwrap(),
userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(),
userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(),
token_userid: db.open_tree("token_userid").unwrap(),
pduid_pdu: db.open_tree("pduid_pdu").unwrap(),
eventid_pduid: db.open_tree("eventid_pduid").unwrap(),
roomid_pduleaves: MultiValue(db.open_tree("roomid_pduleaves").unwrap()),
roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(),
roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(),
roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(),
roomid_joinuserids: MultiValue(db.open_tree("roomid_joinuserids").unwrap()),
roomid_inviteuserids: MultiValue(db.open_tree("roomid_inviteuserids").unwrap()),
userid_joinroomids: MultiValue(db.open_tree("userid_joinroomids").unwrap()),
userid_inviteroomids: MultiValue(db.open_tree("userid_inviteroomids").unwrap()),
userid_leftroomids: MultiValue(db.open_tree("userid_leftroomids").unwrap()),
roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(),
roomactiveid_roomactive: db.open_tree("roomactiveid_roomactive").unwrap(),
globalallid_globalall: db.open_tree("globalallid_globalall").unwrap(),
globallatestid_globallatest: db.open_tree("globallatestid_globallatest").unwrap(),
keypair: ruma_signatures::Ed25519KeyPair::new(
&*db.update_and_fetch("keypair", utils::generate_keypair)
.unwrap()
.unwrap(),
"key1".to_owned(),
)
.unwrap(),
global: db,
}
}
globals: globals::Globals::load(db.open_tree("global").unwrap(), hostname.to_owned()),
users: users::Users {
userid_password: db.open_tree("userid_password").unwrap(),
userdeviceid: db.open_tree("userdeviceid").unwrap(),
userid_displayname: db.open_tree("userid_displayname").unwrap(),
userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(),
userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(),
token_userid: db.open_tree("token_userid").unwrap(),
},
rooms: rooms::Rooms {
edus: rooms::RoomEdus {
roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(),
roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(),
roomactiveid_roomactive: db.open_tree("roomactiveid_roomactive").unwrap(),
},
pduid_pdu: db.open_tree("pduid_pdu").unwrap(),
eventid_pduid: db.open_tree("eventid_pduid").unwrap(),
roomid_pduleaves: db.open_tree("roomid_pduleaves").unwrap(),
roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(),
pub fn debug(&self) {
println!("# UserId -> Password:");
for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# UserId -> DeviceIds:");
for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# UserId -> Displayname:");
for (k, v) in self.userid_displayname.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# UserId -> AvatarURL:");
for (k, v) in self.userid_avatarurl.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# UserId+DeviceId -> Token:");
for (k, v) in self.userdeviceid_token.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# Token -> UserId:");
for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# RoomId -> PDU leaves:");
for (k, v) in self.roomid_pduleaves.iter_all().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# RoomStateId -> PDU:");
for (k, v) in self.roomstateid_pdu.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# RoomId -> UserIds:");
for (k, v) in self.roomid_joinuserids.iter_all().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# UserId -> RoomIds:");
for (k, v) in self.userid_joinroomids.iter_all().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# PDU Id -> PDU:");
for (k, v) in self.pduid_pdu.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# EventId -> PDU Id:");
for (k, v) in self.eventid_pduid.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# RoomLatestId -> RoomLatest:");
for (k, v) in self.roomlatestid_roomlatest.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# RoomActiveId -> RoomActives:");
for (k, v) in self.roomactiveid_roomactive.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# GlobalAllId -> GlobalAll:");
for (k, v) in self.globalallid_globalall.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# GlobalLatestId -> GlobalLatest:");
for (k, v) in self.globallatestid_globallatest.iter().map(|r| r.unwrap()) {
println!(
"{:?} -> {:?}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
userroomid_joined: db.open_tree("userroomid_joined").unwrap(),
roomuserid_joined: db.open_tree("roomuserid_joined").unwrap(),
userroomid_invited: db.open_tree("userroomid_invited").unwrap(),
roomuserid_invited: db.open_tree("roomuserid_invited").unwrap(),
userroomid_left: db.open_tree("userroomid_left").unwrap(),
},
account_data: account_data::AccountData {
roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(),
},
//globalallid_globalall: db.open_tree("globalallid_globalall").unwrap(),
//globallatestid_globallatest: db.open_tree("globallatestid_globallatest").unwrap(),
_db: db,
}
}
}

View file

@ -0,0 +1,120 @@
use crate::Result;
use ruma_events::{collections::only::Event as EduEvent, EventJson};
use ruma_identifiers::{RoomId, UserId};
use std::collections::HashMap;
pub struct AccountData {
pub(super) roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type
}
impl AccountData {
/// Places one event in the account data of the user and removes the previous entry.
pub fn update(
&self,
room_id: Option<&RoomId>,
user_id: &UserId,
event: EduEvent,
globals: &super::globals::Globals,
) -> Result<()> {
let mut prefix = room_id
.map(|r| r.to_string())
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xff);
prefix.extend_from_slice(&user_id.to_string().as_bytes());
prefix.push(0xff);
// Remove old entry
if let Some(old) = self
.roomuserdataid_accountdata
.scan_prefix(&prefix)
.keys()
.rev()
.filter_map(|r| r.ok())
.take_while(|key| key.starts_with(&prefix))
.filter(|key| {
key.split(|&b| b == 0xff)
.nth(1)
.filter(|&user| user == user_id.to_string().as_bytes())
.is_some()
})
.next()
{
// This is the old room_latest
self.roomuserdataid_accountdata.remove(old)?;
println!("removed old account data");
}
let mut key = prefix;
key.extend_from_slice(&globals.next_count()?.to_be_bytes());
key.push(0xff);
let json = serde_json::to_value(&event)?;
key.extend_from_slice(json["type"].as_str().unwrap().as_bytes());
self.roomuserdataid_accountdata
.insert(key, &*json.to_string())
.unwrap();
Ok(())
}
// TODO: Optimize
/// Searches the account data for a specific kind.
pub fn get(
&self,
room_id: Option<&RoomId>,
user_id: &UserId,
kind: &str,
) -> Result<Option<EventJson<EduEvent>>> {
Ok(self.all(room_id, user_id)?.remove(kind))
}
/// Returns all changes to the account data that happened after `since`.
pub fn changes_since(
&self,
room_id: Option<&RoomId>,
user_id: &UserId,
since: u64,
) -> Result<HashMap<String, EventJson<EduEvent>>> {
let mut userdata = HashMap::new();
let mut prefix = room_id
.map(|r| r.to_string())
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xff);
prefix.extend_from_slice(&user_id.to_string().as_bytes());
prefix.push(0xff);
// Skip the data that's exactly at since, because we sent that last time
let mut first_possible = prefix.clone();
first_possible.extend_from_slice(&(since + 1).to_be_bytes());
for json in self
.roomuserdataid_accountdata
.range(&*first_possible..)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| serde_json::from_slice::<serde_json::Value>(&v).unwrap())
{
userdata.insert(
json["type"].as_str().unwrap().to_owned(),
serde_json::from_value::<EventJson<EduEvent>>(json)
.expect("userdata in db is valid"),
);
}
Ok(userdata)
}
/// Returns all account data.
pub fn all(
&self,
room_id: Option<&RoomId>,
user_id: &UserId,
) -> Result<HashMap<String, EventJson<EduEvent>>> {
self.changes_since(room_id, user_id, 0)
}
}

61
src/database/globals.rs Normal file
View file

@ -0,0 +1,61 @@
use crate::{utils, Result};
pub const COUNTER: &str = "c";
pub struct Globals {
pub(super) globals: sled::Tree,
hostname: String,
keypair: ruma_signatures::Ed25519KeyPair,
reqwest_client: reqwest::Client,
}
impl Globals {
pub fn load(globals: sled::Tree, hostname: String) -> Self {
let keypair = ruma_signatures::Ed25519KeyPair::new(
&*globals
.update_and_fetch("keypair", utils::generate_keypair)
.unwrap()
.unwrap(),
"key1".to_owned(),
)
.unwrap();
Self {
globals,
hostname,
keypair,
reqwest_client: reqwest::Client::new(),
}
}
/// Returns the hostname of the server.
pub fn hostname(&self) -> &str {
&self.hostname
}
/// Returns this server's keypair.
pub fn keypair(&self) -> &ruma_signatures::Ed25519KeyPair {
&self.keypair
}
/// Returns a reqwest client which can be used to send requests.
pub fn reqwest_client(&self) -> &reqwest::Client {
&self.reqwest_client
}
pub fn next_count(&self) -> Result<u64> {
Ok(utils::u64_from_bytes(
&self
.globals
.update_and_fetch(COUNTER, utils::increment)?
.expect("utils::increment will always put in a value"),
))
}
pub fn current_count(&self) -> Result<u64> {
Ok(self
.globals
.get(COUNTER)?
.map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes)))
}
}

547
src/database/rooms.rs Normal file
View file

@ -0,0 +1,547 @@
mod edus;
pub use edus::RoomEdus;
use crate::{utils, Error, PduEvent, Result};
use ruma_events::{room::power_levels::PowerLevelsEventContent, EventJson, EventType};
use ruma_identifiers::{EventId, RoomId, UserId};
use serde_json::json;
use std::{
collections::HashMap,
convert::{TryFrom, TryInto},
mem,
};
pub struct Rooms {
pub edus: edus::RoomEdus,
pub(super) pduid_pdu: sled::Tree, // PduId = RoomId + Count
pub(super) eventid_pduid: sled::Tree,
pub(super) roomid_pduleaves: sled::Tree,
pub(super) roomstateid_pdu: sled::Tree, // Room + StateType + StateKey
pub(super) userroomid_joined: sled::Tree,
pub(super) roomuserid_joined: sled::Tree,
pub(super) userroomid_invited: sled::Tree,
pub(super) roomuserid_invited: sled::Tree,
pub(super) userroomid_left: sled::Tree,
}
impl Rooms {
/// Checks if a room exists.
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
// Look for PDUs in that room.
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
Ok(self
.pduid_pdu
.get_gt(&prefix)?
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
// TODO: Remove and replace with public room dir
/// Returns a vector over all rooms.
pub fn all_rooms(&self) -> Vec<RoomId> {
let mut room_ids = self
.roomid_pduleaves
.iter()
.keys()
.map(|key| {
RoomId::try_from(
&*utils::string_from_bytes(
&key.unwrap()
.iter()
.copied()
.take_while(|&x| x != 0xff) // until delimiter
.collect::<Vec<_>>(),
)
.unwrap(),
)
.unwrap()
})
.collect::<Vec<_>>();
room_ids.dedup();
room_ids
}
/// Returns the full room state.
pub fn room_state(&self, room_id: &RoomId) -> Result<HashMap<(EventType, String), PduEvent>> {
let mut hashmap = HashMap::new();
for pdu in self
.roomstateid_pdu
.scan_prefix(&room_id.to_string().as_bytes())
.values()
.map(|value| Ok::<_, Error>(serde_json::from_slice::<PduEvent>(&value?)?))
{
let pdu = pdu?;
hashmap.insert(
(
pdu.kind.clone(),
pdu.state_key
.clone()
.expect("state events have a state key"),
),
pdu,
);
}
Ok(hashmap)
}
/// Returns the `count` of this pdu's id.
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
Ok(self
.eventid_pduid
.get(event_id.to_string().as_bytes())?
.map(|pdu_id| {
utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()])
}))
}
/// Returns the json of a pdu.
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<serde_json::Value>> {
self.eventid_pduid
.get(event_id.to_string().as_bytes())?
.map_or(Ok(None), |pdu_id| {
Ok(serde_json::from_slice(
&self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase(
"eventid_pduid points to nonexistent pdu",
))?,
)?)
.map(Some)
})
}
/// Returns the leaf pdus of a room.
pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<Vec<EventId>> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
let mut events = Vec::new();
for event in self
.roomid_pduleaves
.scan_prefix(prefix)
.values()
.map(|bytes| Ok::<_, Error>(EventId::try_from(&*utils::string_from_bytes(&bytes?)?)?))
{
events.push(event?);
}
Ok(events)
}
/// Replace the leaves of a room with a new event.
pub fn replace_pdu_leaves(&self, room_id: &RoomId, event_id: &EventId) -> Result<()> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
for key in self.roomid_pduleaves.scan_prefix(&prefix).keys() {
self.roomid_pduleaves.remove(key?)?;
}
prefix.extend_from_slice(event_id.to_string().as_bytes());
self.roomid_pduleaves
.insert(&prefix, &*event_id.to_string())?;
Ok(())
}
/// Creates a new persisted data unit and adds it to a room.
pub fn append_pdu(
&self,
room_id: RoomId,
sender: UserId,
event_type: EventType,
content: serde_json::Value,
unsigned: Option<serde_json::Map<String, serde_json::Value>>,
state_key: Option<String>,
globals: &super::globals::Globals,
) -> Result<EventId> {
// Is the event authorized?
if state_key.is_some() {
if let Some(pdu) = self
.room_state(&room_id)?
.get(&(EventType::RoomPowerLevels, "".to_owned()))
{
let power_levels = serde_json::from_value::<EventJson<PowerLevelsEventContent>>(
pdu.content.clone(),
)?
.deserialize()?;
match event_type {
EventType::RoomMember => {
// Member events are okay for now (TODO)
}
_ if power_levels
.users
.get(&sender)
.unwrap_or(&power_levels.users_default)
<= &0.into() =>
{
// Not authorized
return Err(Error::BadRequest("event not authorized"));
}
// User has sufficient power
_ => {}
}
}
}
// prev_events are the leaves of the current graph. This method removes all leaves from the
// room and replaces them with our event
// TODO: Make sure this isn't called twice in parallel
let prev_events = self.get_pdu_leaves(&room_id)?;
// Our depth is the maximum depth of prev_events + 1
let depth = prev_events
.iter()
.filter_map(|event_id| Some(self.get_pdu_json(event_id).ok()??.get("depth")?.as_u64()?))
.max()
.unwrap_or(0_u64)
+ 1;
let mut unsigned = unsigned.unwrap_or_default();
// TODO: Optimize this to not load the whole room state?
if let Some(state_key) = &state_key {
if let Some(prev_pdu) = self
.room_state(&room_id)?
.get(&(event_type.clone(), state_key.clone()))
{
unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone());
}
}
let mut pdu = PduEvent {
event_id: EventId::try_from("$thiswillbefilledinlater").expect("we know this is valid"),
room_id: room_id.clone(),
sender: sender.clone(),
origin: globals.hostname().to_owned(),
origin_server_ts: utils::millis_since_unix_epoch()
.try_into()
.expect("this only fails many years in the future"),
kind: event_type,
content,
state_key,
prev_events,
depth: depth
.try_into()
.expect("depth can overflow and should be deprecated..."),
auth_events: Vec::new(),
redacts: None,
unsigned,
hashes: ruma_federation_api::EventHash {
sha256: "aaa".to_owned(),
},
signatures: HashMap::new(),
};
// Generate event id
pdu.event_id = EventId::try_from(&*format!(
"${}",
ruma_signatures::reference_hash(&serde_json::to_value(&pdu)?)
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are correct");
let mut pdu_json = serde_json::to_value(&pdu)?;
ruma_signatures::hash_and_sign_event(globals.hostname(), globals.keypair(), &mut pdu_json)
.expect("our new event can be hashed and signed");
self.replace_pdu_leaves(&room_id, &pdu.event_id)?;
// Increment the last index and use that
// This is also the next_batch/since value
let index = globals.next_count()?;
let mut pdu_id = room_id.to_string().as_bytes().to_vec();
pdu_id.push(0xff);
pdu_id.extend_from_slice(&index.to_be_bytes());
self.pduid_pdu.insert(&pdu_id, &*pdu_json.to_string())?;
self.eventid_pduid
.insert(pdu.event_id.to_string(), pdu_id.clone())?;
if let Some(state_key) = pdu.state_key {
let mut key = room_id.to_string().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(pdu.kind.to_string().as_bytes());
key.push(0xff);
key.extend_from_slice(state_key.to_string().as_bytes());
self.roomstateid_pdu.insert(key, &*pdu_json.to_string())?;
}
self.edus.room_read_set(&room_id, &sender, index)?;
Ok(pdu.event_id)
}
/// Returns an iterator over all PDUs in a room.
pub fn all_pdus(&self, room_id: &RoomId) -> Result<impl Iterator<Item = Result<PduEvent>>> {
self.pdus_since(room_id, 0)
}
/// Returns an iterator over all events in a room that happened after the event with id `since`.
pub fn pdus_since(
&self,
room_id: &RoomId,
since: u64,
) -> Result<impl Iterator<Item = Result<PduEvent>>> {
// Create the first part of the full pdu id
let mut pdu_id = room_id.to_string().as_bytes().to_vec();
pdu_id.push(0xff);
pdu_id.extend_from_slice(&(since).to_be_bytes());
self.pdus_since_pduid(room_id, &pdu_id)
}
/// Returns an iterator over all events in a room that happened after the event with id `since`.
pub fn pdus_since_pduid(
&self,
room_id: &RoomId,
pdu_id: &[u8],
) -> Result<impl Iterator<Item = Result<PduEvent>>> {
// Create the first part of the full pdu id
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
Ok(self
.pduid_pdu
.range(pdu_id..)
// Skip the first pdu if it's exactly at since, because we sent that last time
.skip(if self.pduid_pdu.get(pdu_id)?.is_some() {
1
} else {
0
})
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)))
}
/// Returns an iterator over all events in a room that happened before the event with id
/// `until` in reverse-chronological order.
pub fn pdus_until(
&self,
room_id: &RoomId,
until: u64,
) -> impl Iterator<Item = Result<PduEvent>> {
// Create the first part of the full pdu id
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
let mut current = prefix.clone();
current.extend_from_slice(&until.to_be_bytes());
let current: &[u8] = &current;
self.pduid_pdu
.range(..current)
.rev()
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))
}
/// Makes a user join a room.
pub fn join(
&self,
room_id: &RoomId,
user_id: &UserId,
displayname: Option<String>,
globals: &super::globals::Globals,
) -> Result<()> {
if !self.exists(room_id)? {
return Err(Error::BadRequest("room does not exist"));
}
let mut userroom_id = user_id.to_string().as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.to_string().as_bytes());
let mut roomuser_id = room_id.to_string().as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.to_string().as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invited.remove(&userroom_id)?;
self.roomuserid_invited.remove(&roomuser_id)?;
self.userroomid_left.remove(&userroom_id)?;
let mut content = json!({"membership": "join"});
if let Some(displayname) = displayname {
content
.as_object_mut()
.unwrap()
.insert("displayname".to_owned(), displayname.into());
}
self.append_pdu(
room_id.clone(),
user_id.clone(),
EventType::RoomMember,
content,
None,
Some(user_id.to_string()),
globals,
)?;
Ok(())
}
/// Makes a user leave a room.
pub fn leave(
&self,
sender: &UserId,
room_id: &RoomId,
user_id: &UserId,
globals: &super::globals::Globals,
) -> Result<()> {
let mut userroom_id = user_id.to_string().as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.to_string().as_bytes());
let mut roomuser_id = room_id.to_string().as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.to_string().as_bytes());
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invited.remove(&userroom_id)?;
self.roomuserid_invited.remove(&userroom_id)?;
self.userroomid_left.insert(&userroom_id, &[])?;
self.append_pdu(
room_id.clone(),
sender.clone(),
EventType::RoomMember,
json!({"membership": "leave"}),
None,
Some(user_id.to_string()),
globals,
)?;
Ok(())
}
/// Makes a user forget a room.
pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
let mut userroom_id = user_id.to_string().as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.to_string().as_bytes());
self.userroomid_left.remove(userroom_id)?;
Ok(())
}
/// Makes a user invite another user into room.
pub fn invite(
&self,
sender: &UserId,
room_id: &RoomId,
user_id: &UserId,
globals: &super::globals::Globals,
) -> Result<()> {
let mut userroom_id = user_id.to_string().as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.to_string().as_bytes());
let mut roomuser_id = room_id.to_string().as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.to_string().as_bytes());
self.userroomid_invited.insert(userroom_id, &[])?;
self.roomuserid_invited.insert(roomuser_id, &[])?;
self.append_pdu(
room_id.clone(),
sender.clone(),
EventType::RoomMember,
json!({"membership": "invite"}),
None,
Some(user_id.to_string()),
globals,
)?;
Ok(())
}
/// Returns an iterator over all rooms a user joined.
pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> {
self.roomuserid_joined
.scan_prefix(room_id.to_string())
.values()
.map(|key| {
Ok(UserId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
})
}
/// Returns an iterator over all rooms a user joined.
pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> {
self.roomuserid_invited
.scan_prefix(room_id.to_string())
.keys()
.map(|key| {
Ok(UserId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
})
}
/// Returns an iterator over all rooms a user joined.
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
self.userroomid_joined
.scan_prefix(user_id.to_string())
.keys()
.map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
})
}
/// Returns an iterator over all rooms a user was invited to.
pub fn rooms_invited(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
self.userroomid_invited
.scan_prefix(&user_id.to_string())
.keys()
.map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
})
}
/// Returns an iterator over all rooms a user left.
pub fn rooms_left(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
self.userroomid_left
.scan_prefix(&user_id.to_string())
.keys()
.map(|key| {
Ok(RoomId::try_from(&*utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.ok_or(Error::BadDatabase("userroomid is invalid"))?,
)?)?)
})
}
}

190
src/database/rooms/edus.rs Normal file
View file

@ -0,0 +1,190 @@
use crate::{utils, Result};
use ruma_events::{collections::only::Event as EduEvent, EventJson};
use ruma_identifiers::{RoomId, UserId};
pub struct RoomEdus {
pub(in super::super) roomuserid_lastread: sled::Tree, // RoomUserId = Room + User
pub(in super::super) roomlatestid_roomlatest: sled::Tree, // Read Receipts, RoomLatestId = RoomId + Count + UserId
pub(in super::super) roomactiveid_roomactive: sled::Tree, // Typing, RoomActiveId = RoomId + TimeoutTime + Count
}
impl RoomEdus {
/// Adds an event which will be saved until a new event replaces it (e.g. read receipt).
pub fn roomlatest_update(
&self,
user_id: &UserId,
room_id: &RoomId,
event: EduEvent,
globals: &super::super::globals::Globals,
) -> Result<()> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
// Remove old entry
if let Some(old) = self
.roomlatestid_roomlatest
.scan_prefix(&prefix)
.keys()
.rev()
.filter_map(|r| r.ok())
.take_while(|key| key.starts_with(&prefix))
.find(|key| {
key.rsplit(|&b| b == 0xff).next().unwrap() == user_id.to_string().as_bytes()
})
{
// This is the old room_latest
self.roomlatestid_roomlatest.remove(old)?;
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes());
room_latest_id.push(0xff);
room_latest_id.extend_from_slice(&user_id.to_string().as_bytes());
self.roomlatestid_roomlatest
.insert(room_latest_id, &*serde_json::to_string(&event)?)?;
Ok(())
}
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
pub fn roomlatests_since(
&self,
room_id: &RoomId,
since: u64,
) -> Result<impl Iterator<Item = Result<EventJson<EduEvent>>>> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&since.to_be_bytes());
Ok(self
.roomlatestid_roomlatest
.range(&*first_possible_edu..)
// Skip the first pdu if it's exactly at since, because we sent that last time
.skip(
if self
.roomlatestid_roomlatest
.get(first_possible_edu)?
.is_some()
{
1
} else {
0
},
)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?)))
}
/// Returns a vector of the most recent read_receipts in a room that happened after the event with id `since`.
pub fn roomlatests_all(
&self,
room_id: &RoomId,
) -> Result<impl Iterator<Item = Result<EventJson<EduEvent>>>> {
self.roomlatests_since(room_id, 0)
}
/// Adds an event that will be saved until the `timeout` timestamp (e.g. typing notifications).
pub fn roomactive_add(
&self,
event: EduEvent,
room_id: &RoomId,
timeout: u64,
globals: &super::super::globals::Globals,
) -> Result<()> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
// Cleanup all outdated edus before inserting a new one
for outdated_edu in self
.roomactiveid_roomactive
.scan_prefix(&prefix)
.keys()
.filter_map(|r| r.ok())
.take_while(|k| {
utils::u64_from_bytes(
k.split(|&c| c == 0xff)
.nth(1)
.expect("roomactive has valid timestamp and delimiters"),
) < utils::millis_since_unix_epoch()
})
{
// This is an outdated edu (time > timestamp)
self.roomlatestid_roomlatest.remove(outdated_edu)?;
}
let mut room_active_id = prefix;
room_active_id.extend_from_slice(&timeout.to_be_bytes());
room_active_id.push(0xff);
room_active_id.extend_from_slice(&globals.next_count()?.to_be_bytes());
self.roomactiveid_roomactive
.insert(room_active_id, &*serde_json::to_string(&event)?)?;
Ok(())
}
/// Removes an active event manually (before the timeout is reached).
pub fn roomactive_remove(&self, event: EduEvent, room_id: &RoomId) -> Result<()> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
let json = serde_json::to_string(&event)?;
// Remove outdated entries
for outdated_edu in self
.roomactiveid_roomactive
.scan_prefix(&prefix)
.filter_map(|r| r.ok())
.filter(|(_, v)| v == json.as_bytes())
{
self.roomactiveid_roomactive.remove(outdated_edu.0)?;
}
Ok(())
}
/// Returns an iterator over all active events (e.g. typing notifications).
pub fn roomactives_all(
&self,
room_id: &RoomId,
) -> impl Iterator<Item = Result<EventJson<EduEvent>>> {
let mut prefix = room_id.to_string().as_bytes().to_vec();
prefix.push(0xff);
let mut first_active_edu = prefix.clone();
first_active_edu.extend_from_slice(&utils::millis_since_unix_epoch().to_be_bytes());
self.roomactiveid_roomactive
.range(first_active_edu..)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, v)| Ok(serde_json::from_slice(&v)?))
}
/// Sets a private read marker at `count`.
pub fn room_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
let mut key = room_id.to_string().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&user_id.to_string().as_bytes());
self.roomuserid_lastread.insert(key, &count.to_be_bytes())?;
Ok(())
}
/// Returns the private read marker.
pub fn room_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.to_string().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&user_id.to_string().as_bytes());
Ok(self
.roomuserid_lastread
.get(key)?
.map(|v| utils::u64_from_bytes(&v)))
}
}

144
src/database/users.rs Normal file
View file

@ -0,0 +1,144 @@
use crate::{utils, Error, Result};
use ruma_identifiers::UserId;
use std::convert::TryFrom;
pub struct Users {
pub(super) userid_password: sled::Tree,
pub(super) userid_displayname: sled::Tree,
pub(super) userid_avatarurl: sled::Tree,
pub(super) userdeviceid: sled::Tree,
pub(super) userdeviceid_token: sled::Tree,
pub(super) token_userid: sled::Tree,
}
impl Users {
/// Check if a user has an account on this homeserver.
pub fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.contains_key(user_id.to_string())?)
}
/// Create a new user account on this homeserver.
pub fn create(&self, user_id: &UserId, hash: &str) -> Result<()> {
self.userid_password.insert(user_id.to_string(), hash)?;
Ok(())
}
/// Find out which user an access token belongs to.
pub fn find_from_token(&self, token: &str) -> Result<Option<UserId>> {
self.token_userid.get(token)?.map_or(Ok(None), |bytes| {
utils::string_from_bytes(&bytes)
.and_then(|string| Ok(UserId::try_from(string)?))
.map(Some)
})
}
/// Returns an iterator over all users on this homeserver.
pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> {
self.userid_password.iter().keys().map(|r| {
utils::string_from_bytes(&r?).and_then(|string| Ok(UserId::try_from(&*string)?))
})
}
/// Returns the password hash for the given user.
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some))
}
/// Returns the displayname of a user on this homeserver.
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some))
}
/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
.insert(user_id.to_string(), &*displayname)?;
} else {
self.userid_displayname.remove(user_id.to_string())?;
}
Ok(())
/* TODO:
for room_id in self.rooms_joined(user_id) {
self.pdu_append(
room_id.clone(),
user_id.clone(),
EventType::RoomMember,
json!({"membership": "join", "displayname": displayname}),
None,
Some(user_id.to_string()),
);
}
*/
}
/// Get a the avatar_url of a user.
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_avatarurl
.get(user_id.to_string())?
.map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some))
}
/// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<String>) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
.insert(user_id.to_string(), &*avatar_url)?;
} else {
self.userid_avatarurl.remove(user_id.to_string())?;
}
Ok(())
}
/// Adds a new device to a user.
pub fn create_device(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> {
if !self.exists(user_id)? {
return Err(Error::BadRequest(
"tried to create device for nonexistent user",
));
}
let mut key = user_id.to_string().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(device_id.as_bytes());
self.userdeviceid.insert(key, &[])?;
self.set_token(user_id, device_id, token)?;
Ok(())
}
/// Replaces the access token of one device.
pub fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> {
let mut key = user_id.to_string().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(device_id.as_bytes());
if self.userdeviceid.get(&key)?.is_none() {
return Err(Error::BadRequest(
"Tried to set token for nonexistent device",
));
}
// Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&key)? {
self.token_userid.remove(old_token)?;
// It will be removed from userdeviceid_token by the insert later
}
// Assign token to device_id
self.userdeviceid_token.insert(key, &*token)?;
// Assign token to user
self.token_userid.insert(token, &*user_id.to_string())?;
Ok(())
}
}

36
src/error.rs Normal file
View file

@ -0,0 +1,36 @@
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
#[error("problem with the database")]
SledError {
#[from]
source: sled::Error,
},
#[error("tried to parse invalid string")]
StringFromBytesError {
#[from]
source: std::string::FromUtf8Error,
},
#[error("tried to parse invalid identifier")]
SerdeJsonError {
#[from]
source: serde_json::Error,
},
#[error("tried to parse invalid identifier")]
RumaIdentifierError {
#[from]
source: ruma_identifiers::Error,
},
#[error("tried to parse invalid event")]
RumaEventError {
#[from]
source: ruma_events::InvalidEvent,
},
#[error("bad request")]
BadRequest(&'static str),
#[error("problem in that database")]
BadDatabase(&'static str),
}

View file

@ -1,8 +1,9 @@
#![feature(proc_macro_hygiene, decl_macro)]
#![warn(rust_2018_idioms)]
mod client_server;
mod data;
mod database;
mod error;
mod pdu;
mod ruma_wrapper;
mod server_server;
@ -11,8 +12,8 @@ mod utils;
#[cfg(test)]
mod test;
pub use data::Data;
pub use database::Database;
pub use error::{Error, Result};
pub use pdu::PduEvent;
pub use ruma_wrapper::{MatrixResult, Ruma};
@ -75,7 +76,7 @@ fn setup_rocket() -> rocket::Rocket {
)
.attach(AdHoc::on_attach("Config", |rocket| {
let hostname = rocket.config().get_str("hostname").unwrap_or("localhost");
let data = Data::load_or_create(&hostname);
let data = Database::load_or_create(&hostname);
Ok(rocket.manage(data))
}))
@ -86,7 +87,6 @@ fn main() {
if let Err(_) = std::env::var("RUST_LOG") {
std::env::set_var("RUST_LOG", "warn");
}
pretty_env_logger::init();
setup_rocket().launch().unwrap();
}

View file

@ -27,21 +27,21 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma<T> {
type Borrowed = Self::Owned;
fn transform<'r>(
_req: &'r Request,
_req: &'r Request<'_>,
data: Data,
) -> TransformFuture<'r, Self::Owned, Self::Error> {
Box::pin(async move { Transform::Owned(Success(data)) })
}
fn from_data(
request: &'a Request,
request: &'a Request<'_>,
outcome: Transformed<'a, Self>,
) -> FromDataFuture<'a, Self, Self::Error> {
Box::pin(async move {
let data = rocket::try_outcome!(outcome.owned());
let user_id = if T::METADATA.requires_authentication {
let data = request.guard::<State<crate::Data>>().await.unwrap();
let db = request.guard::<State<'_, crate::Database>>().await.unwrap();
// Get token from header or query value
let token = match request
@ -56,7 +56,7 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma<T> {
};
// Check if token is valid
match data.user_from_token(&token) {
match db.users.find_from_token(&token).unwrap() {
// TODO: M_UNKNOWN_TOKEN
None => return Failure((Status::Unauthorized, ())),
Some(user_id) => Some(user_id),

View file

@ -1,7 +1,7 @@
use crate::{Data, MatrixResult};
use crate::{Database, MatrixResult};
use http::header::{HeaderValue, AUTHORIZATION};
use log::error;
use rocket::{get, post, put, response::content::Json, State};
use rocket::{get, response::content::Json, State};
use ruma_api::Endpoint;
use ruma_client_api::error::Error;
use ruma_federation_api::{v1::get_server_version, v2::get_server_keys};
@ -12,9 +12,9 @@ use std::{
time::{Duration, SystemTime},
};
pub async fn request_well_known(data: &crate::Data, destination: &str) -> Option<String> {
pub async fn request_well_known(db: &crate::Database, destination: &str) -> Option<String> {
let body: serde_json::Value = serde_json::from_str(
&data
&db.globals
.reqwest_client()
.get(&format!(
"https://{}/.well-known/matrix/server",
@ -32,14 +32,14 @@ pub async fn request_well_known(data: &crate::Data, destination: &str) -> Option
}
pub async fn send_request<T: Endpoint>(
data: &crate::Data,
db: &crate::Database,
destination: String,
request: T,
) -> Option<T::Response> {
let mut http_request: http::Request<_> = request.try_into().unwrap();
let actual_destination = "https://".to_owned()
+ &request_well_known(data, &destination)
+ &request_well_known(db, &destination)
.await
.unwrap_or(destination.clone() + ":8448");
*http_request.uri_mut() = (actual_destination + T::METADATA.path).parse().unwrap();
@ -55,11 +55,11 @@ pub async fn send_request<T: Endpoint>(
request_map.insert("method".to_owned(), T::METADATA.method.to_string().into());
request_map.insert("uri".to_owned(), T::METADATA.path.into());
request_map.insert("origin".to_owned(), data.hostname().into());
request_map.insert("origin".to_owned(), db.globals.hostname().into());
request_map.insert("destination".to_owned(), destination.into());
let mut request_json = request_map.into();
ruma_signatures::sign_json(data.hostname(), data.keypair(), &mut request_json).unwrap();
ruma_signatures::sign_json(db.globals.hostname(), db.globals.keypair(), &mut request_json).unwrap();
let signatures = request_json["signatures"]
.as_object()
@ -77,7 +77,7 @@ pub async fn send_request<T: Endpoint>(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"X-Matrix origin={},key=\"{}\",sig=\"{}\"",
data.hostname(),
db.globals.hostname(),
s.0,
s.1
))
@ -85,7 +85,7 @@ pub async fn send_request<T: Endpoint>(
);
}
let reqwest_response = data.reqwest_client().execute(http_request.into()).await;
let reqwest_response = db.globals.reqwest_client().execute(http_request.into()).await;
// Because reqwest::Response -> http::Response is complicated:
match reqwest_response {
@ -120,7 +120,7 @@ pub async fn send_request<T: Endpoint>(
}
#[get("/.well-known/matrix/server")]
pub fn well_known_server(data: State<Data>) -> Json<String> {
pub fn well_known_server() -> Json<String> {
rocket::response::content::Json(
json!({ "m.server": "matrixtesting.koesters.xyz:14004"}).to_string(),
)
@ -137,17 +137,17 @@ pub fn get_server_version() -> MatrixResult<get_server_version::Response, Error>
}
#[get("/_matrix/key/v2/server")]
pub fn get_server_keys(data: State<Data>) -> Json<String> {
pub fn get_server_keys(db: State<'_, Database>) -> Json<String> {
let mut verify_keys = BTreeMap::new();
verify_keys.insert(
format!("ed25519:{}", data.keypair().version()),
format!("ed25519:{}", db.globals.keypair().version()),
get_server_keys::VerifyKey {
key: base64::encode_config(data.keypair().public_key(), base64::STANDARD_NO_PAD),
key: base64::encode_config(db.globals.keypair().public_key(), base64::STANDARD_NO_PAD),
},
);
let mut response = serde_json::from_slice(
http::Response::try_from(get_server_keys::Response {
server_name: data.hostname().to_owned(),
server_name: db.globals.hostname().to_owned(),
verify_keys,
old_verify_keys: BTreeMap::new(),
signatures: BTreeMap::new(),
@ -157,11 +157,11 @@ pub fn get_server_keys(data: State<Data>) -> Json<String> {
.body(),
)
.unwrap();
ruma_signatures::sign_json(data.hostname(), data.keypair(), &mut response).unwrap();
ruma_signatures::sign_json(db.globals.hostname(), db.globals.keypair(), &mut response).unwrap();
Json(response.to_string())
}
#[get("/_matrix/key/v2/server/<_key_id>")]
pub fn get_server_keys_deprecated(data: State<Data>, _key_id: String) -> Json<String> {
get_server_keys(data)
pub fn get_server_keys_deprecated(db: State<'_, Database>, _key_id: String) -> Json<String> {
get_server_keys(db)
}

View file

@ -1,8 +1,6 @@
use super::*;
use rocket::{http::Status, local::Client};
use ruma_client_api::error::ErrorKind;
use rocket::local::Client;
use serde_json::{json, Value};
use std::time::Duration;
fn setup_client() -> Client {
Database::try_remove("localhost");

View file

@ -1,3 +1,4 @@
use crate::Result;
use argon2::{Config, Variant};
use rand::prelude::*;
use std::{
@ -32,13 +33,15 @@ pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> {
)
}
/// Parses the bytes into an u64.
pub fn u64_from_bytes(bytes: &[u8]) -> u64 {
let array: [u8; 8] = bytes.try_into().expect("bytes are valid u64");
u64::from_be_bytes(array)
}
pub fn string_from_bytes(bytes: &[u8]) -> String {
String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8")
/// Parses the bytes into a string.
pub fn string_from_bytes(bytes: &[u8]) -> Result<String> {
Ok(String::from_utf8(bytes.to_vec())?)
}
pub fn random_string(length: usize) -> String {
@ -49,7 +52,7 @@ pub fn random_string(length: usize) -> String {
}
/// Calculate a new hash for the given password
pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> {
pub fn calculate_hash(password: &str) -> std::result::Result<String, argon2::Error> {
let hashing_config = Config {
variant: Variant::Argon2id,
..Default::default()