minor auth_chain optimizations/cleanup

Signed-off-by: Jason Volk <jason@zemos.net>
Jason Volk <jason@zemos.net>, 2024-09-25 03:52:28 +00:00, committed by strawberry
commit 3f7ec4221d, parent 4776fe66c4
8 changed files with 125 additions and 118 deletions
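
The common thread across these files: auth_chain::Service::event_ids_iter now borrows its starting events as &[&EventId] instead of taking an owned Vec<Arc<EventId>>, so call sites collect borrowed IDs rather than cloning Arcs, and the cache getters return Result (NotFound on a miss) instead of Result<Option<_>>, with the cache writers becoming infallible. A rough sketch of the new call-site shape, mirroring the touched federation handlers; the helper name, the `state_ids` map, and its exact types are stand-ins for what the real routes already have in scope, and it assumes futures::StreamExt as used elsewhere in the codebase:

    use std::{borrow::Borrow, collections::{HashMap, HashSet}, sync::Arc};

    use conduit::Result;
    use futures::StreamExt;
    use ruma::{EventId, RoomId};
    use service::Services;

    // Hypothetical helper, not part of this commit: load the full auth chain
    // for a set of state events using the new borrowed-slice signature.
    async fn full_auth_chain(
        services: &Services, room_id: &RoomId, state_ids: &HashMap<u64, Arc<EventId>>,
    ) -> Result<HashSet<Arc<EventId>>> {
        // Borrow the starting events instead of cloning every Arc into a Vec.
        let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect();

        let auth_chain = services
            .rooms
            .auth_chain
            .event_ids_iter(room_id, &starting_events)
            .await?
            .collect::<HashSet<Arc<EventId>>>()
            .await;

        Ok(auth_chain)
    }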

View file

@@ -27,8 +27,10 @@ pub(super) async fn echo(&self, message: Vec<String>) -> Result<RoomMessageEvent
 #[admin_command]
 pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
-    let event_id = Arc::<EventId>::from(event_id);
-    if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
+    let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else {
+        return Ok(RoomMessageEventContent::notice_plain("Event not found."));
+    };
+
     let room_id_str = event
         .get("room_id")
         .and_then(|val| val.as_str())
@@ -42,7 +44,7 @@ pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<Room
         .services
         .rooms
         .auth_chain
-        .event_ids_iter(room_id, vec![event_id])
+        .event_ids_iter(room_id, &[&event_id])
         .await?
         .count()
         .await;
@@ -51,9 +53,6 @@ pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<Room
     Ok(RoomMessageEventContent::text_plain(format!(
         "Loaded auth chain with length {count} in {elapsed:?}"
     )))
-    } else {
-        Ok(RoomMessageEventContent::text_plain("Event not found."))
-    }
 }

 #[admin_command]

View file

@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::borrow::Borrow;

 use axum::extract::State;
 use conduit::{Error, Result};
@@ -57,7 +57,7 @@ pub(crate) async fn get_event_authorization_route(
     let auth_chain = services
         .rooms
         .auth_chain
-        .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
+        .event_ids_iter(room_id, &[body.event_id.borrow()])
         .await?
         .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() })
         .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))

View file

@ -1,6 +1,6 @@
#![allow(deprecated)] #![allow(deprecated)]
use std::collections::BTreeMap; use std::{borrow::Borrow, collections::BTreeMap};
use axum::extract::State; use axum::extract::State;
use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
@ -11,7 +11,7 @@ use ruma::{
room::member::{MembershipState, RoomMemberEventContent}, room::member::{MembershipState, RoomMemberEventContent},
StateEventType, StateEventType,
}, },
CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName,
}; };
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use service::Services; use service::Services;
@ -196,10 +196,11 @@ async fn create_join_event(
.try_collect() .try_collect()
.await?; .await?;
let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect();
let auth_chain = services let auth_chain = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, state_ids.values().cloned().collect()) .event_ids_iter(room_id, &starting_events)
.await? .await?
.map(Ok) .map(Ok)
.and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await })

View file

@ -1,4 +1,4 @@
use std::sync::Arc; use std::borrow::Borrow;
use axum::extract::State; use axum::extract::State;
use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; use conduit::{err, result::LogErr, utils::IterStream, Err, Result};
@ -63,7 +63,7 @@ pub(crate) async fn get_room_state_route(
let auth_chain = services let auth_chain = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(&body.room_id, &[body.event_id.borrow()])
.await? .await?
.map(Ok) .map(Ok)
.and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await })

View file

@ -1,4 +1,4 @@
use std::sync::Arc; use std::borrow::Borrow;
use axum::extract::State; use axum::extract::State;
use conduit::{err, Err}; use conduit::{err, Err};
@ -55,7 +55,7 @@ pub(crate) async fn get_room_state_ids_route(
let auth_chain_ids = services let auth_chain_ids = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(&body.room_id, &[body.event_id.borrow()])
.await? .await?
.map(|id| (*id).to_owned()) .map(|id| (*id).to_owned())
.collect() .collect()

View file

@@ -3,7 +3,7 @@ use std::{
     sync::{Arc, Mutex},
 };

-use conduit::{utils, utils::math::usize_from_f64, Result};
+use conduit::{err, utils, utils::math::usize_from_f64, Err, Result};
 use database::Map;
 use lru_cache::LruCache;
@@ -24,54 +24,63 @@ impl Data {
         }
     }

-    pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
+    pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
+        debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
+
         // Check RAM cache
-        if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
-            return Ok(Some(Arc::clone(result)));
+        if let Some(result) = self
+            .auth_chain_cache
+            .lock()
+            .expect("cache locked")
+            .get_mut(key)
+        {
+            return Ok(Arc::clone(result));
         }

         // We only save auth chains for single events in the db
-        if key.len() == 1 {
-            // Check DB cache
-            let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| {
-                chain
-                    .chunks_exact(size_of::<u64>())
-                    .map(utils::u64_from_u8)
-                    .collect::<Arc<[u64]>>()
-            });
-
-            if let Ok(chain) = chain {
-                // Cache in RAM
-                self.auth_chain_cache
-                    .lock()
-                    .expect("locked")
-                    .insert(vec![key[0]], Arc::clone(&chain));
-
-                return Ok(Some(chain));
-            }
-        }
-
-        Ok(None)
-    }
-
-    pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
+        if key.len() != 1 {
+            return Err!(Request(NotFound("auth_chain not cached")));
+        }
+
+        // Check database
+        let chain = self
+            .shorteventid_authchain
+            .qry(&key[0])
+            .await
+            .map_err(|_| err!(Request(NotFound("auth_chain not found"))))?;
+
+        let chain = chain
+            .chunks_exact(size_of::<u64>())
+            .map(utils::u64_from_u8)
+            .collect::<Arc<[u64]>>();
+
+        // Cache in RAM
+        self.auth_chain_cache
+            .lock()
+            .expect("cache locked")
+            .insert(vec![key[0]], Arc::clone(&chain));
+
+        Ok(chain)
+    }
+
+    pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) {
+        debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
+
         // Only persist single events in db
         if key.len() == 1 {
-            self.shorteventid_authchain.insert(
-                &key[0].to_be_bytes(),
-                &auth_chain
-                    .iter()
-                    .flat_map(|s| s.to_be_bytes().to_vec())
-                    .collect::<Vec<u8>>(),
-            );
+            let key = key[0].to_be_bytes();
+            let val = auth_chain
+                .iter()
+                .flat_map(|s| s.to_be_bytes().to_vec())
+                .collect::<Vec<u8>>();
+
+            self.shorteventid_authchain.insert(&key, &val);
         }

         // Cache in RAM
         self.auth_chain_cache
            .lock()
-            .expect("locked")
+            .expect("cache locked")
            .insert(key, auth_chain);
-
-        Ok(())
     }
 }
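
With the data-layer getter now returning Result<Arc<[u64]>>, a cache miss surfaces as a NotFound error rather than Ok(None), and cache_auth_chain no longer returns a Result at all. A small sketch of the caller-side effect; it mirrors the service code in the next file, written as a match purely for illustration, where `sevent_id`, `event_id`, `room_id`, and `chunk_cache` are locals that loop already has:

    // A miss now arrives as Err(NotFound(..)), so the hot path branches on Ok.
    match self.get_cached_eventid_authchain(&[sevent_id]).await {
        Ok(cached) => chunk_cache.extend(cached.iter().copied()),
        Err(_) => {
            let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
            // cache_auth_chain() is infallible now, so no `?` is needed.
            self.cache_auth_chain(vec![sevent_id], &auth_chain);
            chunk_cache.extend(auth_chain.iter());
        },
    }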

View file

@@ -37,25 +37,18 @@ impl crate::Service for Service {
 }

 impl Service {
-    pub async fn event_ids_iter<'a>(
-        &'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
-    ) -> Result<impl Stream<Item = Arc<EventId>> + Send + 'a> {
-        let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
-        for starting_event in &starting_events_ {
-            starting_events.push(starting_event);
-        }
-
-        Ok(self
-            .get_auth_chain(room_id, &starting_events)
-            .await?
-            .into_iter()
-            .stream()
-            .filter_map(|sid| {
+    pub async fn event_ids_iter(
+        &self, room_id: &RoomId, starting_events: &[&EventId],
+    ) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_> {
+        let chain = self.get_auth_chain(room_id, starting_events).await?;
+        let iter = chain.into_iter().stream().filter_map(|sid| {
             self.services
                 .short
                 .get_eventid_from_short(sid)
                 .map(Result::ok)
-            }))
+        });
+
+        Ok(iter)
     }

     #[tracing::instrument(skip_all, name = "auth_chain")]
@@ -93,7 +86,7 @@ impl Service {
             }

             let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
-            if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? {
+            if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
                 trace!("Found cache entry for whole chunk");
                 full_auth_chain.extend(cached.iter().copied());
                 hits = hits.saturating_add(1);
@@ -104,13 +97,13 @@ impl Service {
             let mut misses2: usize = 0;
             let mut chunk_cache = Vec::with_capacity(chunk.len());
             for (sevent_id, event_id) in chunk {
-                if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? {
+                if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await {
                     trace!(?event_id, "Found cache entry for event");
                     chunk_cache.extend(cached.iter().copied());
                     hits2 = hits2.saturating_add(1);
                 } else {
                     let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
-                    self.cache_auth_chain(vec![sevent_id], &auth_chain)?;
+                    self.cache_auth_chain(vec![sevent_id], &auth_chain);
                     chunk_cache.extend(auth_chain.iter());
                     misses2 = misses2.saturating_add(1);
                     debug!(
@@ -125,7 +118,7 @@ impl Service {
             chunk_cache.sort_unstable();
             chunk_cache.dedup();
-            self.cache_auth_chain_vec(chunk_key, &chunk_cache)?;
+            self.cache_auth_chain_vec(chunk_key, &chunk_cache);
             full_auth_chain.extend(chunk_cache.iter());
             misses = misses.saturating_add(1);
             debug!(
@@ -163,11 +156,11 @@ impl Service {
             Ok(pdu) => {
                 if pdu.room_id != room_id {
                     return Err!(Request(Forbidden(
-                        "auth event {event_id:?} for incorrect room {} which is not {}",
+                        "auth event {event_id:?} for incorrect room {} which is not {room_id}",
                         pdu.room_id,
-                        room_id
                     )));
                 }

                 for auth_event in &pdu.auth_events {
                     let sauthevent = self
                         .services
@@ -187,20 +180,21 @@ impl Service {
         Ok(found)
     }

-    pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
+    #[inline]
+    pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
         self.db.get_cached_eventid_authchain(key).await
     }

     #[tracing::instrument(skip(self), level = "debug")]
-    pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> {
-        self.db
-            .cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
+    pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) {
+        let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
+        self.db.cache_auth_chain(key, val);
     }

     #[tracing::instrument(skip(self), level = "debug")]
-    pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> {
-        self.db
-            .cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
+    pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) {
+        let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
+        self.db.cache_auth_chain(key, val);
     }

     pub fn get_cache_usage(&self) -> (usize, usize) {

View file

@ -1,6 +1,7 @@
mod parse_incoming_pdu; mod parse_incoming_pdu;
use std::{ use std::{
borrow::Borrow,
collections::{hash_map, BTreeMap, HashMap, HashSet}, collections::{hash_map, BTreeMap, HashMap, HashSet},
fmt::Write, fmt::Write,
sync::{Arc, RwLock as StdRwLock}, sync::{Arc, RwLock as StdRwLock},
@ -773,6 +774,7 @@ impl Service {
Ok(pdu_id) Ok(pdu_id)
} }
#[tracing::instrument(skip_all, name = "resolve")]
pub async fn resolve_state( pub async fn resolve_state(
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>, &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> { ) -> Result<Arc<HashSet<CompressedStateEvent>>> {
@ -793,14 +795,17 @@ impl Service {
let fork_states = [current_state_ids, incoming_state]; let fork_states = [current_state_ids, incoming_state];
let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); let mut auth_chain_sets = Vec::with_capacity(fork_states.len());
for state in &fork_states { for state in &fork_states {
auth_chain_sets.push( let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect();
self.services
let auth_chain = self
.services
.auth_chain .auth_chain
.event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) .event_ids_iter(room_id, &starting_events)
.await? .await?
.collect::<HashSet<Arc<EventId>>>() .collect::<HashSet<Arc<EventId>>>()
.await, .await;
);
auth_chain_sets.push(auth_chain);
} }
debug!("Loading fork states"); debug!("Loading fork states");
@ -962,12 +967,11 @@ impl Service {
let mut state = StateMap::with_capacity(leaf_state.len()); let mut state = StateMap::with_capacity(leaf_state.len());
let mut starting_events = Vec::with_capacity(leaf_state.len()); let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in &leaf_state {
for (k, id) in leaf_state {
if let Ok((ty, st_key)) = self if let Ok((ty, st_key)) = self
.services .services
.short .short
.get_statekey_from_short(k) .get_statekey_from_short(*k)
.await .await
.log_err() .log_err()
{ {
@ -976,18 +980,18 @@ impl Service {
state.insert((ty.to_string().into(), st_key), id.clone()); state.insert((ty.to_string().into(), st_key), id.clone());
} }
starting_events.push(id); starting_events.push(id.borrow());
} }
auth_chain_sets.push( let auth_chain = self
self.services .services
.auth_chain .auth_chain
.event_ids_iter(room_id, starting_events) .event_ids_iter(room_id, &starting_events)
.await? .await?
.collect() .collect()
.await, .await;
);
auth_chain_sets.push(auth_chain);
fork_states.push(state); fork_states.push(state);
} }