improvement: better e2ee over fed, faster incoming event handling

This commit is contained in:
Timo Kösters 2021-08-24 19:10:31 +02:00
parent 72dd95f500
commit 81e056417c
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
9 changed files with 407 additions and 256 deletions

View file

@ -23,13 +23,13 @@ use ruma::{
uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto},
mem::size_of,
sync::{Arc, Mutex},
};
use tokio::sync::MutexGuard;
use tracing::{debug, error, warn};
use tracing::{error, warn};
use super::{abstraction::Tree, admin::AdminCommand, pusher};
@ -73,8 +73,8 @@ pub struct Rooms {
pub(super) shorteventid_shortstatehash: Arc<dyn Tree>,
/// StateKey = EventType + StateKey, ShortStateKey = Count
pub(super) statekey_shortstatekey: Arc<dyn Tree>,
pub(super) shortstatekey_statekey: Arc<dyn Tree>,
pub(super) shortroomid_roomid: Arc<dyn Tree>,
pub(super) roomid_shortroomid: Arc<dyn Tree>,
pub(super) shorteventid_eventid: Arc<dyn Tree>,
@ -95,6 +95,7 @@ pub struct Rooms {
pub(super) shorteventid_cache: Mutex<LruCache<u64, EventId>>,
pub(super) eventidshort_cache: Mutex<LruCache<EventId, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(EventType, String), u64>>,
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (EventType, String)>>,
pub(super) stateinfo_cache: Mutex<
LruCache<
u64,
@ -112,7 +113,7 @@ impl Rooms {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))]
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> {
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, EventId>> {
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
@ -138,7 +139,7 @@ impl Rooms {
.into_iter()
.map(|compressed| self.parse_compressed_state_event(compressed))
.filter_map(|r| r.ok())
.map(|eventid| self.get_pdu(&eventid))
.map(|(_, eventid)| self.get_pdu(&eventid))
.filter_map(|r| r.ok().flatten())
.map(|pdu| {
Ok::<_, Error>((
@ -176,7 +177,11 @@ impl Rooms {
Ok(full_state
.into_iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| self.parse_compressed_state_event(compressed).ok()))
.and_then(|compressed| {
self.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
@ -232,6 +237,13 @@ impl Rooms {
state_key: Option<&str>,
content: &serde_json::Value,
) -> Result<StateMap<Arc<PduEvent>>> {
let shortstatehash =
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
current_shortstatehash
} else {
return Ok(HashMap::new());
};
let auth_events = state_res::auth_types_for_event(
kind,
sender,
@ -239,19 +251,30 @@ impl Rooms {
content.clone(),
);
let mut events = StateMap::new();
for (event_type, state_key) in auth_events {
if let Some(pdu) = self.room_state_get(room_id, &event_type, &state_key)? {
events.insert((event_type, state_key), pdu);
} else {
// This is okay because when creating a new room some events were not created yet
debug!(
"{:?}: Could not find {} {:?} in state",
content, event_type, state_key
);
}
}
Ok(events)
let mut sauthevents = auth_events
.into_iter()
.filter_map(|(event_type, state_key)| {
self.get_shortstatekey(&event_type, &state_key)
.ok()
.flatten()
.map(|s| (s, (event_type, state_key)))
})
.collect::<HashMap<_, _>>();
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.filter_map(|compressed| self.parse_compressed_state_event(compressed).ok())
.filter_map(|(shortstatekey, event_id)| {
sauthevents.remove(&shortstatekey).map(|k| (k, event_id))
})
.filter_map(|(k, event_id)| self.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu)))
.collect())
}
/// Generate a new StateHash.
@ -306,32 +329,19 @@ impl Rooms {
/// Force the creation of a new StateHash and insert it into the db.
///
/// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
#[tracing::instrument(skip(self, new_state, db))]
#[tracing::instrument(skip(self, new_state_ids_compressed, db))]
pub fn force_state(
&self,
room_id: &RoomId,
new_state: HashMap<(EventType, String), EventId>,
new_state_ids_compressed: HashSet<CompressedStateEvent>,
db: &Database,
) -> Result<()> {
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let new_state_ids_compressed = new_state
.iter()
.filter_map(|((event_type, state_key), event_id)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, &db.globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, event_id, &db.globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash(
&new_state
.values()
.map(|event_id| event_id.as_bytes())
&new_state_ids_compressed
.iter()
.map(|bytes| &bytes[..])
.collect::<Vec<_>>(),
);
@ -373,10 +383,11 @@ impl Rooms {
)?;
};
for event_id in statediffnew
.into_iter()
.filter_map(|new| self.parse_compressed_state_event(new).ok())
{
for event_id in statediffnew.into_iter().filter_map(|new| {
self.parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) {
if let Some(pdu) = self.get_pdu_json(&event_id)? {
if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") {
if let Ok(pdu) = serde_json::from_value::<PduEvent>(
@ -504,15 +515,20 @@ impl Rooms {
Ok(v.try_into().expect("we checked the size above"))
}
/// Returns shortstatekey, event id
#[tracing::instrument(skip(self, compressed_event))]
pub fn parse_compressed_state_event(
&self,
compressed_event: CompressedStateEvent,
) -> Result<EventId> {
self.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
) -> Result<(u64, EventId)> {
Ok((
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()])
.expect("bytes have right length"),
)
self.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
.expect("bytes have right length"),
)?,
))
}
/// Creates a new shortstatehash that often is just a diff to an already existing
@ -805,6 +821,8 @@ impl Rooms {
let shortstatekey = globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey)?;
shortstatekey
}
};
@ -833,11 +851,10 @@ impl Rooms {
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id =
EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?;
let event_id = EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
self.shorteventid_cache
.lock()
@ -847,6 +864,48 @@ impl Rooms {
Ok(event_id)
}
#[tracing::instrument(skip(self))]
pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(EventType, String)> {
if let Some(id) = self
.shortstatekey_cache
.lock()
.unwrap()
.get_mut(&shortstatekey)
{
return Ok(id.clone());
}
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xff);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type =
EventType::try_from(utils::string_from_bytes(&eventtype_bytes).map_err(|_| {
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
let state_key = utils::string_from_bytes(&statekey_bytes).map_err(|_| {
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
})?;
let result = (event_type, state_key);
self.shortstatekey_cache
.lock()
.unwrap()
.insert(shortstatekey, result.clone());
Ok(result)
}
/// Returns the full room state.
#[tracing::instrument(skip(self))]
pub fn room_state_full(
@ -1106,6 +1165,17 @@ impl Rooms {
.collect()
}
#[tracing::instrument(skip(self, room_id, event_ids))]
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
Ok(())
}
/// Replace the leaves of a room.
///
/// The provided `event_ids` become the new leaves, this allows a room to have multiple
@ -1202,12 +1272,7 @@ impl Rooms {
}
// We must keep track of all events that have been referenced.
for prev in &pdu.prev_events {
let mut key = pdu.room_id().as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
self.replace_pdu_leaves(&pdu.room_id, leaves)?;
let mutex_insert = Arc::clone(
@ -1565,35 +1630,22 @@ impl Rooms {
///
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state, globals))]
#[tracing::instrument(skip(self, state_ids_compressed, globals))]
pub fn set_event_state(
&self,
event_id: &EventId,
room_id: &RoomId,
state: &StateMap<Arc<PduEvent>>,
state_ids_compressed: HashSet<CompressedStateEvent>,
globals: &super::globals::Globals,
) -> Result<()> {
let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let state_ids_compressed = state
.iter()
.filter_map(|((event_type, state_key), pdu)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, &pdu.event_id, globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash(
&state
.values()
.map(|pdu| pdu.event_id.as_bytes())
&state_ids_compressed
.iter()
.map(|s| &s[..])
.collect::<Vec<_>>(),
);
@ -1857,8 +1909,8 @@ impl Rooms {
&room_version,
&Arc::new(pdu.clone()),
create_prev_event,
&auth_events,
None, // TODO: third_party_invite
|k, s| auth_events.get(&(k.clone(), s.to_owned())).map(Arc::clone),
)
.map_err(|e| {
error!("{:?}", e);