diff --git a/src/database/map/get.rs b/src/database/map/get.rs index 72382e36..2f7df031 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -1,8 +1,8 @@ use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; use arrayvec::ArrayVec; -use conduit::{err, implement, Result}; -use futures::future::ready; +use conduit::{err, implement, utils::IterStream, Result}; +use futures::{future::ready, Stream}; use rocksdb::DBPinnableSlice; use serde::Serialize; @@ -50,6 +50,7 @@ where /// Fetch a value from the database into cache, returning a reference-handle /// asynchronously. The key is referenced directly to perform the query. #[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] pub fn get(&self, key: &K) -> impl Future>> + Send where K: AsRef<[u8]> + ?Sized + Debug, @@ -61,10 +62,9 @@ where /// The key is referenced directly to perform the query. This is a thread- /// blocking call. #[implement(super::Map)] -#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] pub fn get_blocking(&self, key: &K) -> Result> where - K: AsRef<[u8]> + ?Sized + Debug, + K: AsRef<[u8]> + ?Sized, { let res = self .db @@ -76,10 +76,19 @@ where #[implement(super::Map)] #[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] -pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> Vec>> +pub fn get_batch<'a, I, K>(&self, keys: I) -> impl Stream>> where I: Iterator + ExactSizeIterator + Send + Debug, - K: AsRef<[u8]> + Sized + Debug + 'a, + K: AsRef<[u8]> + Send + Sync + Sized + Debug + 'a, +{ + self.get_batch_blocking(keys).stream() +} + +#[implement(super::Map)] +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> impl Iterator>> +where + I: Iterator + ExactSizeIterator + Send, + K: AsRef<[u8]> + Sized + 'a, { // Optimization can be `true` if key vector is pre-sorted **by the column // comparator**. @@ -91,7 +100,6 @@ where .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) .into_iter() .map(into_result_handle) - .collect() } fn into_result_handle(result: RocksdbResult<'_>) -> Result> { diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index c22732c2..cabb6f0c 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; -use futures::Stream; +use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId}; use self::data::Data; @@ -69,15 +69,15 @@ impl Service { const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); let started = std::time::Instant::now(); - let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, &short) in self + let mut starting_ids = self .services .short .multi_get_or_create_shorteventid(starting_events) - .await - .iter() .enumerate() - { + .boxed(); + + let mut buckets = [BUCKET; NUM_BUCKETS]; + while let Some((i, short)) = starting_ids.next().await { let bucket: usize = short.try_into()?; let bucket: usize = validated!(bucket % NUM_BUCKETS); buckets[bucket].insert((short, starting_events[i])); diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index e8b00d9b..703df796 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -3,6 +3,7 @@ use std::{mem::size_of_val, sync::Arc}; pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId}; use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{globals, Dep}; @@ -71,11 +72,12 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEvent } #[implement(Service)] -pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec { +pub fn multi_get_or_create_shorteventid<'a>( + &'a self, event_ids: &'a [&EventId], +) -> impl Stream + Send + 'a { self.db .eventid_shorteventid - .get_batch_blocking(event_ids.iter()) - .into_iter() + .get_batch(event_ids.iter()) .enumerate() .map(|(i, result)| match result { Ok(ref short) => utils::u64_from_u8(short), @@ -95,7 +97,6 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> short }, }) - .collect() } #[implement(Service)] @@ -163,10 +164,10 @@ pub async fn multi_get_eventid_from_short(&self, shorteventid: &[ShortEventId]) self.db .shorteventid_eventid - .get_batch_blocking(keys.iter()) - .into_iter() + .get_batch(keys.iter()) .map(Deserialized::deserialized) .collect() + .await } #[implement(Service)]