return stream from multi_get_eventid_from_short

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-12-03 09:42:26 +00:00
parent 48703173bc
commit 784ccd6bad
5 changed files with 39 additions and 56 deletions

View file

@ -6,7 +6,11 @@ use std::{
sync::Arc, sync::Arc,
}; };
use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; use conduit::{
debug, debug_error, trace,
utils::{stream::ReadyExt, IterStream},
validated, warn, Err, Result,
};
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use ruma::{EventId, RoomId}; use ruma::{EventId, RoomId};
@ -61,11 +65,10 @@ impl Service {
let event_ids = self let event_ids = self
.services .services
.short .short
.multi_get_eventid_from_short(chain.into_iter()) .multi_get_eventid_from_short(chain.iter())
.await .ready_filter_map(Result::ok)
.into_iter() .collect()
.filter_map(Result::ok) .await;
.collect();
Ok(event_ids) Ok(event_ids)
} }

View file

@ -139,7 +139,7 @@ pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &s
#[implement(Service)] #[implement(Service)]
pub async fn get_eventid_from_short<Id>(&self, shorteventid: ShortEventId) -> Result<Id> pub async fn get_eventid_from_short<Id>(&self, shorteventid: ShortEventId) -> Result<Id>
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
{ {
const BUFSIZE: usize = size_of::<ShortEventId>(); const BUFSIZE: usize = size_of::<ShortEventId>();
@ -153,22 +153,18 @@ where
} }
#[implement(Service)] #[implement(Service)]
pub async fn multi_get_eventid_from_short<Id, I>(&self, shorteventid: I) -> Vec<Result<Id>> pub fn multi_get_eventid_from_short<'a, Id, I>(&'a self, shorteventid: I) -> impl Stream<Item = Result<Id>> + Send + 'a
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, I: Iterator<Item = &'a ShortEventId> + Send + 'a,
Id: for<'de> Deserialize<'de> + Sized + ToOwned + 'a,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
I: Iterator<Item = ShortEventId> + Send,
{ {
const BUFSIZE: usize = size_of::<ShortEventId>(); const BUFSIZE: usize = size_of::<ShortEventId>();
let keys: Vec<[u8; BUFSIZE]> = shorteventid.map(u64::to_be_bytes).collect();
self.db self.db
.shorteventid_eventid .shorteventid_eventid
.get_batch(keys.iter()) .aqry_batch::<BUFSIZE, _, _>(shorteventid)
.map(Deserialized::deserialized) .map(Deserialized::deserialized)
.collect()
.await
} }
#[implement(Service)] #[implement(Service)]

View file

@ -6,7 +6,7 @@ use std::{
}; };
use conduit::{ use conduit::{
at, err, err,
result::FlatOk, result::FlatOk,
utils::{ utils::{
calculate_hash, calculate_hash,
@ -420,7 +420,7 @@ impl Service {
.collect() .collect()
.await; .await;
let auth_state: Vec<_> = self let (state_keys, event_ids): (Vec<_>, Vec<_>) = self
.services .services
.state_accessor .state_accessor
.state_full_shortids(shortstatehash) .state_full_shortids(shortstatehash)
@ -432,16 +432,13 @@ impl Service {
.remove(&shortstatekey) .remove(&shortstatekey)
.map(|(event_type, state_key)| ((event_type, state_key), shorteventid)) .map(|(event_type, state_key)| ((event_type, state_key), shorteventid))
}) })
.collect(); .unzip();
let auth_pdus = self let auth_pdus = self
.services .services
.short .short
.multi_get_eventid_from_short(auth_state.iter().map(at!(1))) .multi_get_eventid_from_short(event_ids.iter())
.await .zip(state_keys.into_iter().stream())
.into_iter()
.stream()
.zip(auth_state.into_iter().stream().map(at!(0)))
.ready_filter_map(|(event_id, tsk)| Some((tsk, event_id.ok()?))) .ready_filter_map(|(event_id, tsk)| Some((tsk, event_id.ok()?)))
.broad_filter_map(|(tsk, event_id): (_, OwnedEventId)| async move { .broad_filter_map(|(tsk, event_id): (_, OwnedEventId)| async move {
self.services self.services

View file

@ -1,12 +1,12 @@
use std::{borrow::Borrow, collections::HashMap, sync::Arc}; use std::{borrow::Borrow, collections::HashMap, sync::Arc};
use conduit::{ use conduit::{
at, err, at, err, ref_at,
utils::stream::{BroadbandExt, IterStream}, utils::stream::{BroadbandExt, IterStream, ReadyExt},
PduEvent, Result, PduEvent, Result,
}; };
use database::{Deserialized, Map}; use database::{Deserialized, Map};
use futures::{StreamExt, TryFutureExt}; use futures::{FutureExt, StreamExt, TryFutureExt};
use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId}; use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId};
use serde::Deserialize; use serde::Deserialize;
@ -59,23 +59,13 @@ impl Data {
} }
pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> { pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> {
let short_ids = self let short_ids = self.state_full_shortids(shortstatehash).await?;
.state_full_shortids(shortstatehash)
.await?
.into_iter()
.map(at!(1));
let event_ids = self let full_pdus = self
.services .services
.short .short
.multi_get_eventid_from_short(short_ids) .multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1)))
.await .ready_filter_map(Result::ok)
.into_iter()
.filter_map(Result::ok);
let full_pdus = event_ids
.into_iter()
.stream()
.broad_filter_map( .broad_filter_map(
|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() }, |event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() },
) )
@ -92,19 +82,16 @@ impl Data {
{ {
let short_ids = self.state_full_shortids(shortstatehash).await?; let short_ids = self.state_full_shortids(shortstatehash).await?;
let event_ids = self let full_ids = self
.services .services
.short .short
.multi_get_eventid_from_short(short_ids.iter().map(at!(1))) .multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1)))
.zip(short_ids.iter().stream().map(at!(0)))
.ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?)))
.collect()
.boxed()
.await; .await;
let full_ids = short_ids
.into_iter()
.map(at!(0))
.zip(event_ids.into_iter())
.filter_map(|(shortstatekey, event_id)| Some((shortstatekey, event_id.ok()?)))
.collect();
Ok(full_ids) Ok(full_ids)
} }
@ -134,7 +121,7 @@ impl Data {
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Id> ) -> Result<Id>
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
{ {
let shortstatekey = self let shortstatekey = self
@ -219,7 +206,7 @@ impl Data {
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Id> ) -> Result<Id>
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
{ {
self.services self.services

View file

@ -102,7 +102,7 @@ impl Service {
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Id>> pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
@ -130,7 +130,7 @@ impl Service {
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Id> ) -> Result<Id>
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
{ {
self.db self.db
@ -154,7 +154,7 @@ impl Service {
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<T> ) -> Result<T>
where where
T: for<'de> Deserialize<'de> + Send, T: for<'de> Deserialize<'de>,
{ {
self.state_get(shortstatehash, event_type, state_key) self.state_get(shortstatehash, event_type, state_key)
.await .await
@ -337,7 +337,7 @@ impl Service {
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Id> ) -> Result<Id>
where where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
{ {
self.db self.db
@ -359,7 +359,7 @@ impl Service {
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<T> ) -> Result<T>
where where
T: for<'de> Deserialize<'de> + Send, T: for<'de> Deserialize<'de>,
{ {
self.room_state_get(room_id, event_type, state_key) self.room_state_get(room_id, event_type, state_key)
.await .await