refactor to iterator inputs for auth_chain/short batch functions

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-22 12:25:46 +00:00
parent 5da42fb859
commit 3789d60b6a
10 changed files with 76 additions and 71 deletions

View file

@ -2,6 +2,7 @@ mod data;
use std::{
collections::{BTreeSet, HashSet},
fmt::Debug,
sync::Arc,
};
@ -37,9 +38,12 @@ impl crate::Service for Service {
}
impl Service {
pub async fn event_ids_iter(
&self, room_id: &RoomId, starting_events: &[&EventId],
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_> {
pub async fn event_ids_iter<'a, I>(
&'a self, room_id: &RoomId, starting_events: I,
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
let stream = self
.get_event_ids(room_id, starting_events)
.await?
@ -49,12 +53,15 @@ impl Service {
Ok(stream)
}
pub async fn get_event_ids(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<Arc<EventId>>> {
pub async fn get_event_ids<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<Arc<EventId>>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
let chain = self.get_auth_chain(room_id, starting_events).await?;
let event_ids = self
.services
.short
.multi_get_eventid_from_short(&chain)
.multi_get_eventid_from_short(chain.into_iter())
.await
.into_iter()
.filter_map(Result::ok)
@ -64,7 +71,10 @@ impl Service {
}
#[tracing::instrument(skip_all, name = "auth_chain")]
pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<ShortEventId>> {
pub async fn get_auth_chain<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<ShortEventId>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new();
@ -72,19 +82,19 @@ impl Service {
let mut starting_ids = self
.services
.short
.multi_get_or_create_shorteventid(starting_events)
.enumerate()
.multi_get_or_create_shorteventid(starting_events.clone())
.zip(starting_events.clone().stream())
.boxed();
let mut buckets = [BUCKET; NUM_BUCKETS];
while let Some((i, short)) = starting_ids.next().await {
while let Some((short, starting_event)) = starting_ids.next().await {
let bucket: usize = short.try_into()?;
let bucket: usize = validated!(bucket % NUM_BUCKETS);
buckets[bucket].insert((short, starting_events[i]));
buckets[bucket].insert((short, starting_event));
}
debug!(
starting_events = ?starting_events.len(),
starting_events = ?starting_events.count(),
elapsed = ?started.elapsed(),
"start",
);