minor auth_chain optimizations/cleanup
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
4776fe66c4
commit
3f7ec4221d
8 changed files with 125 additions and 118 deletions
|
@ -27,33 +27,32 @@ pub(super) async fn echo(&self, message: Vec<String>) -> Result<RoomMessageEvent
|
|||
|
||||
#[admin_command]
|
||||
pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
|
||||
let event_id = Arc::<EventId>::from(event_id);
|
||||
if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
|
||||
let room_id_str = event
|
||||
.get("room_id")
|
||||
.and_then(|val| val.as_str())
|
||||
.ok_or_else(|| Error::bad_database("Invalid event in database"))?;
|
||||
let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else {
|
||||
return Ok(RoomMessageEventContent::notice_plain("Event not found."));
|
||||
};
|
||||
|
||||
let room_id = <&RoomId>::try_from(room_id_str)
|
||||
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
|
||||
let room_id_str = event
|
||||
.get("room_id")
|
||||
.and_then(|val| val.as_str())
|
||||
.ok_or_else(|| Error::bad_database("Invalid event in database"))?;
|
||||
|
||||
let start = Instant::now();
|
||||
let count = self
|
||||
.services
|
||||
.rooms
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, vec![event_id])
|
||||
.await?
|
||||
.count()
|
||||
.await;
|
||||
let room_id = <&RoomId>::try_from(room_id_str)
|
||||
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
Ok(RoomMessageEventContent::text_plain(format!(
|
||||
"Loaded auth chain with length {count} in {elapsed:?}"
|
||||
)))
|
||||
} else {
|
||||
Ok(RoomMessageEventContent::text_plain("Event not found."))
|
||||
}
|
||||
let start = Instant::now();
|
||||
let count = self
|
||||
.services
|
||||
.rooms
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, &[&event_id])
|
||||
.await?
|
||||
.count()
|
||||
.await;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
Ok(RoomMessageEventContent::text_plain(format!(
|
||||
"Loaded auth chain with length {count} in {elapsed:?}"
|
||||
)))
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduit::{Error, Result};
|
||||
|
@ -57,7 +57,7 @@ pub(crate) async fn get_event_authorization_route(
|
|||
let auth_chain = services
|
||||
.rooms
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
|
||||
.event_ids_iter(room_id, &[body.event_id.borrow()])
|
||||
.await?
|
||||
.filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() })
|
||||
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#![allow(deprecated)]
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::{borrow::Borrow, collections::BTreeMap};
|
||||
|
||||
use axum::extract::State;
|
||||
use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
|
||||
|
@ -11,7 +11,7 @@ use ruma::{
|
|||
room::member::{MembershipState, RoomMemberEventContent},
|
||||
StateEventType,
|
||||
},
|
||||
CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName,
|
||||
CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName,
|
||||
};
|
||||
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
|
||||
use service::Services;
|
||||
|
@ -196,10 +196,11 @@ async fn create_join_event(
|
|||
.try_collect()
|
||||
.await?;
|
||||
|
||||
let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect();
|
||||
let auth_chain = services
|
||||
.rooms
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, state_ids.values().cloned().collect())
|
||||
.event_ids_iter(room_id, &starting_events)
|
||||
.await?
|
||||
.map(Ok)
|
||||
.and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await })
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduit::{err, result::LogErr, utils::IterStream, Err, Result};
|
||||
|
@ -63,7 +63,7 @@ pub(crate) async fn get_room_state_route(
|
|||
let auth_chain = services
|
||||
.rooms
|
||||
.auth_chain
|
||||
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
|
||||
.event_ids_iter(&body.room_id, &[body.event_id.borrow()])
|
||||
.await?
|
||||
.map(Ok)
|
||||
.and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await })
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use axum::extract::State;
|
||||
use conduit::{err, Err};
|
||||
|
@ -55,7 +55,7 @@ pub(crate) async fn get_room_state_ids_route(
|
|||
let auth_chain_ids = services
|
||||
.rooms
|
||||
.auth_chain
|
||||
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
|
||||
.event_ids_iter(&body.room_id, &[body.event_id.borrow()])
|
||||
.await?
|
||||
.map(|id| (*id).to_owned())
|
||||
.collect()
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::{
|
|||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use conduit::{utils, utils::math::usize_from_f64, Result};
|
||||
use conduit::{err, utils, utils::math::usize_from_f64, Err, Result};
|
||||
use database::Map;
|
||||
use lru_cache::LruCache;
|
||||
|
||||
|
@ -24,54 +24,63 @@ impl Data {
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
|
||||
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
|
||||
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
|
||||
|
||||
// Check RAM cache
|
||||
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
|
||||
return Ok(Some(Arc::clone(result)));
|
||||
if let Some(result) = self
|
||||
.auth_chain_cache
|
||||
.lock()
|
||||
.expect("cache locked")
|
||||
.get_mut(key)
|
||||
{
|
||||
return Ok(Arc::clone(result));
|
||||
}
|
||||
|
||||
// We only save auth chains for single events in the db
|
||||
if key.len() == 1 {
|
||||
// Check DB cache
|
||||
let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| {
|
||||
chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(utils::u64_from_u8)
|
||||
.collect::<Arc<[u64]>>()
|
||||
});
|
||||
|
||||
if let Ok(chain) = chain {
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.insert(vec![key[0]], Arc::clone(&chain));
|
||||
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
if key.len() != 1 {
|
||||
return Err!(Request(NotFound("auth_chain not cached")));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
// Check database
|
||||
let chain = self
|
||||
.shorteventid_authchain
|
||||
.qry(&key[0])
|
||||
.await
|
||||
.map_err(|_| err!(Request(NotFound("auth_chain not found"))))?;
|
||||
|
||||
let chain = chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(utils::u64_from_u8)
|
||||
.collect::<Arc<[u64]>>();
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.expect("cache locked")
|
||||
.insert(vec![key[0]], Arc::clone(&chain));
|
||||
|
||||
Ok(chain)
|
||||
}
|
||||
|
||||
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> {
|
||||
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) {
|
||||
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
|
||||
|
||||
// Only persist single events in db
|
||||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain
|
||||
.iter()
|
||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
||||
.collect::<Vec<u8>>(),
|
||||
);
|
||||
let key = key[0].to_be_bytes();
|
||||
let val = auth_chain
|
||||
.iter()
|
||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
self.shorteventid_authchain.insert(&key, &val);
|
||||
}
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.expect("cache locked")
|
||||
.insert(key, auth_chain);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,25 +37,18 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
impl Service {
|
||||
pub async fn event_ids_iter<'a>(
|
||||
&'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
|
||||
) -> Result<impl Stream<Item = Arc<EventId>> + Send + 'a> {
|
||||
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
|
||||
for starting_event in &starting_events_ {
|
||||
starting_events.push(starting_event);
|
||||
}
|
||||
pub async fn event_ids_iter(
|
||||
&self, room_id: &RoomId, starting_events: &[&EventId],
|
||||
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_> {
|
||||
let chain = self.get_auth_chain(room_id, starting_events).await?;
|
||||
let iter = chain.into_iter().stream().filter_map(|sid| {
|
||||
self.services
|
||||
.short
|
||||
.get_eventid_from_short(sid)
|
||||
.map(Result::ok)
|
||||
});
|
||||
|
||||
Ok(self
|
||||
.get_auth_chain(room_id, &starting_events)
|
||||
.await?
|
||||
.into_iter()
|
||||
.stream()
|
||||
.filter_map(|sid| {
|
||||
self.services
|
||||
.short
|
||||
.get_eventid_from_short(sid)
|
||||
.map(Result::ok)
|
||||
}))
|
||||
Ok(iter)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "auth_chain")]
|
||||
|
@ -93,7 +86,7 @@ impl Service {
|
|||
}
|
||||
|
||||
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
|
||||
if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? {
|
||||
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
|
||||
trace!("Found cache entry for whole chunk");
|
||||
full_auth_chain.extend(cached.iter().copied());
|
||||
hits = hits.saturating_add(1);
|
||||
|
@ -104,13 +97,13 @@ impl Service {
|
|||
let mut misses2: usize = 0;
|
||||
let mut chunk_cache = Vec::with_capacity(chunk.len());
|
||||
for (sevent_id, event_id) in chunk {
|
||||
if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? {
|
||||
if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await {
|
||||
trace!(?event_id, "Found cache entry for event");
|
||||
chunk_cache.extend(cached.iter().copied());
|
||||
hits2 = hits2.saturating_add(1);
|
||||
} else {
|
||||
let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
|
||||
self.cache_auth_chain(vec![sevent_id], &auth_chain)?;
|
||||
self.cache_auth_chain(vec![sevent_id], &auth_chain);
|
||||
chunk_cache.extend(auth_chain.iter());
|
||||
misses2 = misses2.saturating_add(1);
|
||||
debug!(
|
||||
|
@ -125,7 +118,7 @@ impl Service {
|
|||
|
||||
chunk_cache.sort_unstable();
|
||||
chunk_cache.dedup();
|
||||
self.cache_auth_chain_vec(chunk_key, &chunk_cache)?;
|
||||
self.cache_auth_chain_vec(chunk_key, &chunk_cache);
|
||||
full_auth_chain.extend(chunk_cache.iter());
|
||||
misses = misses.saturating_add(1);
|
||||
debug!(
|
||||
|
@ -163,11 +156,11 @@ impl Service {
|
|||
Ok(pdu) => {
|
||||
if pdu.room_id != room_id {
|
||||
return Err!(Request(Forbidden(
|
||||
"auth event {event_id:?} for incorrect room {} which is not {}",
|
||||
"auth event {event_id:?} for incorrect room {} which is not {room_id}",
|
||||
pdu.room_id,
|
||||
room_id
|
||||
)));
|
||||
}
|
||||
|
||||
for auth_event in &pdu.auth_events {
|
||||
let sauthevent = self
|
||||
.services
|
||||
|
@ -187,20 +180,21 @@ impl Service {
|
|||
Ok(found)
|
||||
}
|
||||
|
||||
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
|
||||
#[inline]
|
||||
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
|
||||
self.db.get_cached_eventid_authchain(key).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> {
|
||||
self.db
|
||||
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
|
||||
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) {
|
||||
let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
|
||||
self.db.cache_auth_chain(key, val);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> {
|
||||
self.db
|
||||
.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
|
||||
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) {
|
||||
let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
|
||||
self.db.cache_auth_chain(key, val);
|
||||
}
|
||||
|
||||
pub fn get_cache_usage(&self) -> (usize, usize) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
mod parse_incoming_pdu;
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
||||
fmt::Write,
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
|
@ -773,6 +774,7 @@ impl Service {
|
|||
Ok(pdu_id)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "resolve")]
|
||||
pub async fn resolve_state(
|
||||
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
|
||||
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
|
||||
|
@ -793,14 +795,17 @@ impl Service {
|
|||
let fork_states = [current_state_ids, incoming_state];
|
||||
let mut auth_chain_sets = Vec::with_capacity(fork_states.len());
|
||||
for state in &fork_states {
|
||||
auth_chain_sets.push(
|
||||
self.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect())
|
||||
.await?
|
||||
.collect::<HashSet<Arc<EventId>>>()
|
||||
.await,
|
||||
);
|
||||
let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect();
|
||||
|
||||
let auth_chain = self
|
||||
.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, &starting_events)
|
||||
.await?
|
||||
.collect::<HashSet<Arc<EventId>>>()
|
||||
.await;
|
||||
|
||||
auth_chain_sets.push(auth_chain);
|
||||
}
|
||||
|
||||
debug!("Loading fork states");
|
||||
|
@ -962,12 +967,11 @@ impl Service {
|
|||
|
||||
let mut state = StateMap::with_capacity(leaf_state.len());
|
||||
let mut starting_events = Vec::with_capacity(leaf_state.len());
|
||||
|
||||
for (k, id) in leaf_state {
|
||||
for (k, id) in &leaf_state {
|
||||
if let Ok((ty, st_key)) = self
|
||||
.services
|
||||
.short
|
||||
.get_statekey_from_short(k)
|
||||
.get_statekey_from_short(*k)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
|
@ -976,18 +980,18 @@ impl Service {
|
|||
state.insert((ty.to_string().into(), st_key), id.clone());
|
||||
}
|
||||
|
||||
starting_events.push(id);
|
||||
starting_events.push(id.borrow());
|
||||
}
|
||||
|
||||
auth_chain_sets.push(
|
||||
self.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, starting_events)
|
||||
.await?
|
||||
.collect()
|
||||
.await,
|
||||
);
|
||||
let auth_chain = self
|
||||
.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, &starting_events)
|
||||
.await?
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
auth_chain_sets.push(auth_chain);
|
||||
fork_states.push(state);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue