refactor state accessor, state cache, user, uiaa
This commit is contained in:
parent
3e22bbeecd
commit
82e7f57b38
12 changed files with 116 additions and 933 deletions
|
@ -1,7 +1,5 @@
|
|||
/// Builds a StateMap by iterating over all keys that start
|
||||
/// with state_hash, this gives the full state for the given state_hash.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
|
||||
impl service::room::state_accessor::Data for KeyValueDatabase {
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
|
||||
let full_state = self
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
|
@ -21,8 +19,7 @@
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn state_full(
|
||||
async fn state_full(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
|
@ -59,8 +56,7 @@
|
|||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn state_get_id(
|
||||
fn state_get_id(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
|
@ -86,8 +82,7 @@
|
|||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn state_get(
|
||||
fn state_get(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
|
@ -98,7 +93,7 @@
|
|||
}
|
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |shorteventid| {
|
||||
|
@ -116,8 +111,7 @@
|
|||
}
|
||||
|
||||
/// Returns the full room state.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn room_state_full(
|
||||
async fn room_state_full(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
|
@ -129,8 +123,7 @@
|
|||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn room_state_get_id(
|
||||
fn room_state_get_id(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
|
@ -144,8 +137,7 @@
|
|||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn room_state_get(
|
||||
fn room_state_get(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
|
@ -157,4 +149,3 @@
|
|||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
8
src/database/key_value/rooms/state_cache.rs
Normal file
8
src/database/key_value/rooms/state_cache.rs
Normal file
|
@ -0,0 +1,8 @@
|
|||
impl service::room::state_cache::Data for KeyValueDatabase {
|
||||
fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
self.roomuseroncejoinedids.insert(&userroom_id, &[])?;
|
||||
}
|
||||
}
|
|
@ -1,6 +1,5 @@
|
|||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
impl service::room::user::Data for KeyValueDatabase {
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
@ -13,8 +12,7 @@
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
@ -28,8 +26,7 @@
|
|||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
@ -43,7 +40,7 @@
|
|||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
pub fn associate_token_shortstatehash(
|
||||
fn associate_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
|
@ -58,7 +55,7 @@
|
|||
.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
|
||||
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
|
@ -74,8 +71,7 @@
|
|||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_shared_rooms<'a>(
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self,
|
||||
users: Vec<Box<UserId>>,
|
||||
) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> {
|
||||
|
@ -111,4 +107,4 @@
|
|||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}))
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,149 +1,4 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
uiaa::{
|
||||
AuthType, IncomingAuthData, IncomingPassword,
|
||||
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo,
|
||||
},
|
||||
},
|
||||
signatures::CanonicalJsonValue,
|
||||
DeviceId, UserId,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct Uiaa {
|
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest:
|
||||
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
|
||||
}
|
||||
|
||||
impl Uiaa {
|
||||
/// Creates a new Uiaa session. Make sure the session token is unique.
|
||||
pub fn create(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
uiaainfo: &UiaaInfo,
|
||||
json_body: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.set_uiaa_request(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
|
||||
json_body,
|
||||
)?;
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session should be set"),
|
||||
Some(uiaainfo),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn try_auth(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
auth: &IncomingAuthData,
|
||||
uiaainfo: &UiaaInfo,
|
||||
users: &super::users::Users,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<(bool, UiaaInfo)> {
|
||||
let mut uiaainfo = auth
|
||||
.session()
|
||||
.map(|session| self.get_uiaa_session(user_id, device_id, session))
|
||||
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
|
||||
|
||||
if uiaainfo.session.is_none() {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
}
|
||||
|
||||
match auth {
|
||||
// Find out what the user completed
|
||||
IncomingAuthData::Password(IncomingPassword {
|
||||
identifier,
|
||||
password,
|
||||
..
|
||||
}) => {
|
||||
let username = match identifier {
|
||||
UserIdOrLocalpart(username) => username,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username.clone(), globals.server_name())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
|
||||
})?;
|
||||
|
||||
// Check if password is correct
|
||||
if let Some(hash) = users.password_hash(&user_id)? {
|
||||
let hash_matches =
|
||||
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
|
||||
|
||||
if !hash_matches {
|
||||
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
|
||||
kind: ErrorKind::Forbidden,
|
||||
message: "Invalid username or password.".to_owned(),
|
||||
});
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
}
|
||||
|
||||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push(AuthType::Password);
|
||||
}
|
||||
IncomingAuthData::Dummy(_) => {
|
||||
uiaainfo.completed.push(AuthType::Dummy);
|
||||
}
|
||||
k => error!("type not supported: {:?}", k),
|
||||
}
|
||||
|
||||
// Check if a flow now succeeds
|
||||
let mut completed = false;
|
||||
'flows: for flow in &mut uiaainfo.flows {
|
||||
for stage in &flow.stages {
|
||||
if !uiaainfo.completed.contains(stage) {
|
||||
continue 'flows;
|
||||
}
|
||||
}
|
||||
// We didn't break, so this flow succeeded!
|
||||
completed = true;
|
||||
}
|
||||
|
||||
if !completed {
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
Some(&uiaainfo),
|
||||
)?;
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
|
||||
// UIAA was successful! Remove this session and return true
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
None,
|
||||
)?;
|
||||
Ok((true, uiaainfo))
|
||||
}
|
||||
|
||||
impl service::uiaa::Data for KeyValueDatabase {
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
|
@ -162,7 +17,7 @@ impl Uiaa {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_uiaa_request(
|
||||
fn get_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue