minor auth_chain optimizations/cleanup

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-09-25 03:52:28 +00:00 committed by strawberry
parent 4776fe66c4
commit 3f7ec4221d
8 changed files with 125 additions and 118 deletions

View file

@@ -27,33 +27,32 @@ pub(super) async fn echo(&self, message: Vec<String>) -> Result<RoomMessageEvent
#[admin_command]
pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let event_id = Arc::<EventId>::from(event_id);
if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
let room_id_str = event
.get("room_id")
.and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database"))?;
let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else {
return Ok(RoomMessageEventContent::notice_plain("Event not found."));
};
let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let room_id_str = event
.get("room_id")
.and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database"))?;
let start = Instant::now();
let count = self
.services
.rooms
.auth_chain
.event_ids_iter(room_id, vec![event_id])
.await?
.count()
.await;
let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let elapsed = start.elapsed();
Ok(RoomMessageEventContent::text_plain(format!(
"Loaded auth chain with length {count} in {elapsed:?}"
)))
} else {
Ok(RoomMessageEventContent::text_plain("Event not found."))
}
let start = Instant::now();
let count = self
.services
.rooms
.auth_chain
.event_ids_iter(room_id, &[&event_id])
.await?
.count()
.await;
let elapsed = start.elapsed();
Ok(RoomMessageEventContent::text_plain(format!(
"Loaded auth chain with length {count} in {elapsed:?}"
)))
}
#[admin_command]

View file

@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::borrow::Borrow;
use axum::extract::State;
use conduit::{Error, Result};
@@ -57,7 +57,7 @@ pub(crate) async fn get_event_authorization_route(
let auth_chain = services
.rooms
.auth_chain
.event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
.event_ids_iter(room_id, &[body.event_id.borrow()])
.await?
.filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() })
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))

View file

@@ -1,6 +1,6 @@
#![allow(deprecated)]
use std::collections::BTreeMap;
use std::{borrow::Borrow, collections::BTreeMap};
use axum::extract::State;
use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
@@ -11,7 +11,7 @@ use ruma::{
room::member::{MembershipState, RoomMemberEventContent},
StateEventType,
},
CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName,
CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use service::Services;
@@ -196,10 +196,11 @@ async fn create_join_event(
.try_collect()
.await?;
let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect();
let auth_chain = services
.rooms
.auth_chain
.event_ids_iter(room_id, state_ids.values().cloned().collect())
.event_ids_iter(room_id, &starting_events)
.await?
.map(Ok)
.and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await })

View file

@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::borrow::Borrow;
use axum::extract::State;
use conduit::{err, result::LogErr, utils::IterStream, Err, Result};
@@ -63,7 +63,7 @@ pub(crate) async fn get_room_state_route(
let auth_chain = services
.rooms
.auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
.event_ids_iter(&body.room_id, &[body.event_id.borrow()])
.await?
.map(Ok)
.and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await })

View file

@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::borrow::Borrow;
use axum::extract::State;
use conduit::{err, Err};
@@ -55,7 +55,7 @@ pub(crate) async fn get_room_state_ids_route(
let auth_chain_ids = services
.rooms
.auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
.event_ids_iter(&body.room_id, &[body.event_id.borrow()])
.await?
.map(|id| (*id).to_owned())
.collect()

View file

@@ -3,7 +3,7 @@ use std::{
sync::{Arc, Mutex},
};
use conduit::{utils, utils::math::usize_from_f64, Result};
use conduit::{err, utils, utils::math::usize_from_f64, Err, Result};
use database::Map;
use lru_cache::LruCache;
@@ -24,54 +24,63 @@ impl Data {
}
}
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
if let Some(result) = self
.auth_chain_cache
.lock()
.expect("cache locked")
.get_mut(key)
{
return Ok(Arc::clone(result));
}
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>()
});
if let Ok(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
.expect("locked")
.insert(vec![key[0]], Arc::clone(&chain));
return Ok(Some(chain));
}
if key.len() != 1 {
return Err!(Request(NotFound("auth_chain not cached")));
}
Ok(None)
// Check database
let chain = self
.shorteventid_authchain
.qry(&key[0])
.await
.map_err(|_| err!(Request(NotFound("auth_chain not found"))))?;
let chain = chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>();
// Cache in RAM
self.auth_chain_cache
.lock()
.expect("cache locked")
.insert(vec![key[0]], Arc::clone(&chain));
Ok(chain)
}
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
&key[0].to_be_bytes(),
&auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
);
let key = key[0].to_be_bytes();
let val = auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>();
self.shorteventid_authchain.insert(&key, &val);
}
// Cache in RAM
self.auth_chain_cache
.lock()
.expect("locked")
.expect("cache locked")
.insert(key, auth_chain);
Ok(())
}
}

View file

@@ -37,25 +37,18 @@ impl crate::Service for Service {
}
impl Service {
pub async fn event_ids_iter<'a>(
&'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Stream<Item = Arc<EventId>> + Send + 'a> {
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
for starting_event in &starting_events_ {
starting_events.push(starting_event);
}
pub async fn event_ids_iter(
&self, room_id: &RoomId, starting_events: &[&EventId],
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_> {
let chain = self.get_auth_chain(room_id, starting_events).await?;
let iter = chain.into_iter().stream().filter_map(|sid| {
self.services
.short
.get_eventid_from_short(sid)
.map(Result::ok)
});
Ok(self
.get_auth_chain(room_id, &starting_events)
.await?
.into_iter()
.stream()
.filter_map(|sid| {
self.services
.short
.get_eventid_from_short(sid)
.map(Result::ok)
}))
Ok(iter)
}
#[tracing::instrument(skip_all, name = "auth_chain")]
@@ -93,7 +86,7 @@ impl Service {
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? {
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
trace!("Found cache entry for whole chunk");
full_auth_chain.extend(cached.iter().copied());
hits = hits.saturating_add(1);
@@ -104,13 +97,13 @@ impl Service {
let mut misses2: usize = 0;
let mut chunk_cache = Vec::with_capacity(chunk.len());
for (sevent_id, event_id) in chunk {
if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? {
if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await {
trace!(?event_id, "Found cache entry for event");
chunk_cache.extend(cached.iter().copied());
hits2 = hits2.saturating_add(1);
} else {
let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
self.cache_auth_chain(vec![sevent_id], &auth_chain)?;
self.cache_auth_chain(vec![sevent_id], &auth_chain);
chunk_cache.extend(auth_chain.iter());
misses2 = misses2.saturating_add(1);
debug!(
@@ -125,7 +118,7 @@ impl Service {
chunk_cache.sort_unstable();
chunk_cache.dedup();
self.cache_auth_chain_vec(chunk_key, &chunk_cache)?;
self.cache_auth_chain_vec(chunk_key, &chunk_cache);
full_auth_chain.extend(chunk_cache.iter());
misses = misses.saturating_add(1);
debug!(
@@ -163,11 +156,11 @@ impl Service {
Ok(pdu) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(
"auth event {event_id:?} for incorrect room {} which is not {}",
"auth event {event_id:?} for incorrect room {} which is not {room_id}",
pdu.room_id,
room_id
)));
}
for auth_event in &pdu.auth_events {
let sauthevent = self
.services
@@ -187,20 +180,21 @@ impl Service {
Ok(found)
}
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
#[inline]
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
self.db.get_cached_eventid_authchain(key).await
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> {
self.db
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) {
let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
self.db.cache_auth_chain(key, val);
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> {
self.db
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) {
let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
self.db.cache_auth_chain(key, val);
}
pub fn get_cache_usage(&self) -> (usize, usize) {

View file

@@ -1,6 +1,7 @@
mod parse_incoming_pdu;
use std::{
borrow::Borrow,
collections::{hash_map, BTreeMap, HashMap, HashSet},
fmt::Write,
sync::{Arc, RwLock as StdRwLock},
@@ -773,6 +774,7 @@ impl Service {
Ok(pdu_id)
}
#[tracing::instrument(skip_all, name = "resolve")]
pub async fn resolve_state(
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
@@ -793,14 +795,17 @@ impl Service {
let fork_states = [current_state_ids, incoming_state];
let mut auth_chain_sets = Vec::with_capacity(fork_states.len());
for state in &fork_states {
auth_chain_sets.push(
self.services
.auth_chain
.event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect())
.await?
.collect::<HashSet<Arc<EventId>>>()
.await,
);
let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect();
let auth_chain = self
.services
.auth_chain
.event_ids_iter(room_id, &starting_events)
.await?
.collect::<HashSet<Arc<EventId>>>()
.await;
auth_chain_sets.push(auth_chain);
}
debug!("Loading fork states");
@@ -962,12 +967,11 @@ impl Service {
let mut state = StateMap::with_capacity(leaf_state.len());
let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in leaf_state {
for (k, id) in &leaf_state {
if let Ok((ty, st_key)) = self
.services
.short
.get_statekey_from_short(k)
.get_statekey_from_short(*k)
.await
.log_err()
{
@@ -976,18 +980,18 @@ impl Service {
state.insert((ty.to_string().into(), st_key), id.clone());
}
starting_events.push(id);
starting_events.push(id.borrow());
}
auth_chain_sets.push(
self.services
.auth_chain
.event_ids_iter(room_id, starting_events)
.await?
.collect()
.await,
);
let auth_chain = self
.services
.auth_chain
.event_ids_iter(room_id, &starting_events)
.await?
.collect()
.await;
auth_chain_sets.push(auth_chain);
fork_states.push(state);
}