improvement: faster incoming transaction handling

This commit is contained in:
Timo Kösters 2021-08-19 11:01:18 +02:00
parent bf7e019a68
commit 46d8a46e1f
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
12 changed files with 365 additions and 280 deletions

View file

@ -110,6 +110,7 @@ pub struct Rooms {
impl Rooms {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))]
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> {
let full_state = self
.load_shortstatehash_info(shortstatehash)?
@ -122,6 +123,7 @@ impl Rooms {
.collect()
}
#[tracing::instrument(skip(self))]
pub fn state_full(
&self,
shortstatehash: u64,
@ -220,6 +222,7 @@ impl Rooms {
}
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self))]
pub fn get_auth_events(
&self,
room_id: &RoomId,
@ -261,6 +264,7 @@ impl Rooms {
}
/// Checks if a room exists.
#[tracing::instrument(skip(self))]
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match self.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
@ -277,6 +281,7 @@ impl Rooms {
}
/// Checks if a room exists.
#[tracing::instrument(skip(self))]
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
let prefix = self
.get_shortroomid(room_id)?
@ -300,6 +305,7 @@ impl Rooms {
/// Force the creation of a new StateHash and insert it into the db.
///
/// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
#[tracing::instrument(skip(self, new_state, db))]
pub fn force_state(
&self,
room_id: &RoomId,
@ -412,6 +418,7 @@ impl Rooms {
}
/// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(skip(self))]
pub fn load_shortstatehash_info(
&self,
shortstatehash: u64,
@ -480,6 +487,7 @@ impl Rooms {
}
}
#[tracing::instrument(skip(self, globals))]
pub fn compress_state_event(
&self,
shortstatekey: u64,
@ -495,6 +503,7 @@ impl Rooms {
Ok(v.try_into().expect("we checked the size above"))
}
#[tracing::instrument(skip(self, compressed_event))]
pub fn parse_compressed_state_event(
&self,
compressed_event: CompressedStateEvent,
@ -518,6 +527,13 @@ impl Rooms {
/// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid
/// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer
/// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer
#[tracing::instrument(skip(
self,
statediffnew,
statediffremoved,
diff_to_sibling,
parent_states
))]
pub fn save_state_from_diff(
&self,
shortstatehash: u64,
@ -642,6 +658,7 @@ impl Rooms {
}
/// Returns (shortstatehash, already_existed)
#[tracing::instrument(skip(self, globals))]
fn get_or_create_shortstatehash(
&self,
state_hash: &StateHashId,
@ -662,6 +679,7 @@ impl Rooms {
})
}
#[tracing::instrument(skip(self, globals))]
pub fn get_or_create_shorteventid(
&self,
event_id: &EventId,
@ -692,6 +710,7 @@ impl Rooms {
Ok(short)
}
#[tracing::instrument(skip(self))]
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(&room_id.as_bytes())?
@ -702,6 +721,7 @@ impl Rooms {
.transpose()
}
#[tracing::instrument(skip(self))]
pub fn get_shortstatekey(
&self,
event_type: &EventType,
@ -739,6 +759,7 @@ impl Rooms {
Ok(short)
}
#[tracing::instrument(skip(self, globals))]
pub fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
@ -756,6 +777,7 @@ impl Rooms {
})
}
#[tracing::instrument(skip(self, globals))]
pub fn get_or_create_shortstatekey(
&self,
event_type: &EventType,
@ -794,6 +816,7 @@ impl Rooms {
Ok(short)
}
#[tracing::instrument(skip(self))]
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> {
if let Some(id) = self
.shorteventid_cache
@ -876,12 +899,14 @@ impl Rooms {
}
/// Returns the `count` of this pdu's id.
#[tracing::instrument(skip(self))]
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu_id| self.pdu_count(&pdu_id).map(Some))
}
#[tracing::instrument(skip(self))]
pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> {
let prefix = self
.get_shortroomid(room_id)?
@ -902,6 +927,7 @@ impl Rooms {
}
/// Returns the json of a pdu.
#[tracing::instrument(skip(self))]
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
@ -920,6 +946,7 @@ impl Rooms {
}
/// Returns the json of a pdu.
#[tracing::instrument(skip(self))]
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
@ -930,6 +957,7 @@ impl Rooms {
}
/// Returns the json of a pdu.
#[tracing::instrument(skip(self))]
pub fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
@ -951,6 +979,7 @@ impl Rooms {
}
/// Returns the pdu's id.
#[tracing::instrument(skip(self))]
pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
self.eventid_pduid
.get(event_id.as_bytes())?
@ -960,6 +989,7 @@ impl Rooms {
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
#[tracing::instrument(skip(self))]
pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
@ -980,6 +1010,7 @@ impl Rooms {
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
#[tracing::instrument(skip(self))]
pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(&event_id) {
return Ok(Some(Arc::clone(p)));
@ -1019,6 +1050,7 @@ impl Rooms {
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
#[tracing::instrument(skip(self))]
pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
@ -1029,6 +1061,7 @@ impl Rooms {
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
#[tracing::instrument(skip(self))]
pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
@ -1039,6 +1072,7 @@ impl Rooms {
}
/// Removes a pdu and creates a new one with the same id.
#[tracing::instrument(skip(self))]
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(&pdu_id)?.is_some() {
self.pduid_pdu.insert(
@ -2298,6 +2332,7 @@ impl Rooms {
Ok(())
}
#[tracing::instrument(skip(self))]
pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut joined_servers = HashSet::new();
@ -2347,6 +2382,7 @@ impl Rooms {
Ok(())
}
#[tracing::instrument(skip(self, db))]
pub async fn leave_room(
&self,
user_id: &UserId,
@ -2419,6 +2455,7 @@ impl Rooms {
Ok(())
}
#[tracing::instrument(skip(self, db))]
async fn remote_leave_room(
&self,
user_id: &UserId,
@ -2650,6 +2687,7 @@ impl Rooms {
})
}
#[tracing::instrument(skip(self))]
pub fn search_pdus<'a>(
&'a self,
room_id: &RoomId,
@ -2809,6 +2847,7 @@ impl Rooms {
})
}
#[tracing::instrument(skip(self))]
pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
Ok(self
.roomid_joinedcount

View file

@ -4,11 +4,14 @@ use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
use ruma::{
api::client::{
error::ErrorKind,
r0::uiaa::{IncomingAuthData, UiaaInfo},
r0::uiaa::{
IncomingAuthData, IncomingPassword, IncomingUserIdentifier::MatrixId, UiaaInfo,
},
},
signatures::CanonicalJsonValue,
DeviceId, UserId,
};
use tracing::error;
use super::abstraction::Tree;
@ -49,126 +52,91 @@ impl Uiaa {
users: &super::users::Users,
globals: &super::globals::Globals,
) -> Result<(bool, UiaaInfo)> {
if let IncomingAuthData::DirectRequest {
kind,
session,
auth_parameters,
} = &auth
{
let mut uiaainfo = session
.as_ref()
.map(|session| self.get_uiaa_session(&user_id, &device_id, session))
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
let mut uiaainfo = auth
.session()
.map(|session| self.get_uiaa_session(&user_id, &device_id, session))
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
}
if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
}
match auth {
// Find out what the user completed
match &**kind {
"m.login.password" => {
let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"m.login.password needs identifier.",
))?;
let identifier_type = identifier.get("type").ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Identifier needs a type.",
))?;
if identifier_type != "m.id.user" {
IncomingAuthData::Password(IncomingPassword {
identifier,
password,
..
}) => {
let username = match identifier {
MatrixId(username) => username,
_ => {
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
));
))
}
};
let username = identifier
.get("user")
.ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Identifier needs user field.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"User is not a string.",
))?;
let user_id = UserId::parse_with_server_name(username, globals.server_name())
let user_id =
UserId::parse_with_server_name(username.clone(), globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?;
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?;
let password = auth_parameters
.get("password")
.ok_or(Error::BadRequest(
ErrorKind::MissingParam,
"Password is missing.",
))?
.as_str()
.ok_or(Error::BadRequest(
ErrorKind::BadJson,
"Password is not a string.",
))?;
// Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? {
let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
// Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? {
let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
}
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push("m.login.password".to_owned());
}
"m.login.dummy" => {
uiaainfo.completed.push("m.login.dummy".to_owned());
}
k => panic!("type not supported: {}", k),
}
// Check if a flow now succeeds
let mut completed = false;
'flows: for flow in &mut uiaainfo.flows {
for stage in &flow.stages {
if !uiaainfo.completed.contains(stage) {
continue 'flows;
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
}
// We didn't break, so this flow succeeded!
completed = true;
}
if !completed {
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
Some(&uiaainfo),
)?;
return Ok((false, uiaainfo));
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push("m.login.password".to_owned());
}
IncomingAuthData::Dummy(_) => {
uiaainfo.completed.push("m.login.dummy".to_owned());
}
k => error!("type not supported: {:?}", k),
}
// UIAA was successful! Remove this session and return true
// Check if a flow now succeeds
let mut completed = false;
'flows: for flow in &mut uiaainfo.flows {
for stage in &flow.stages {
if !uiaainfo.completed.contains(stage) {
continue 'flows;
}
}
// We didn't break, so this flow succeeded!
completed = true;
}
if !completed {
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
Some(&uiaainfo),
)?;
Ok((true, uiaainfo))
} else {
panic!("FallbackAcknowledgement is not supported yet");
return Ok((false, uiaainfo));
}
// UIAA was successful! Remove this session and return true
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
)?;
Ok((true, uiaainfo))
}
fn set_uiaa_request(