From aea82183b2deb7318f6c0b1f8e4c2c86be9fbdcc Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 21 Nov 2024 05:51:25 +0000 Subject: [PATCH] add set intersection util for two sorted streams Signed-off-by: Jason Volk --- src/core/utils/set.rs | 32 +++++++++++++++++++++++++++++++- src/core/utils/tests.rs | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/core/utils/set.rs b/src/core/utils/set.rs index 563f9df5..ddcf05ff 100644 --- a/src/core/utils/set.rs +++ b/src/core/utils/set.rs @@ -1,4 +1,10 @@ -use std::cmp::{Eq, Ord}; +use std::{ + cmp::{Eq, Ord}, + pin::Pin, + sync::Arc, +}; + +use futures::{Stream, StreamExt}; use crate::{is_equal_to, is_less_than}; @@ -45,3 +51,27 @@ where }) }) } + +/// Intersection of sets +/// +/// Outputs the set of elements common to both streams. Streams must be sorted. +pub fn intersection_sorted_stream2(a: S, b: S) -> impl Stream + Send +where + S: Stream + Send + Unpin, + Item: Eq + PartialOrd + Send + Sync, +{ + use tokio::sync::Mutex; + + let b = Arc::new(Mutex::new(b.peekable())); + a.map(move |ai| (ai, b.clone())) + .filter_map(|(ai, b)| async move { + let mut lock = b.lock().await; + while let Some(bi) = Pin::new(&mut *lock).next_if(|bi| *bi <= ai).await.as_ref() { + if ai == *bi { + return Some(ai); + } + } + + None + }) +} diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index 84d35936..f4f78b02 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -237,3 +237,42 @@ fn set_intersection_sorted_all() { let r = intersection_sorted(i.into_iter()); assert!(r.eq(["bar", "baz", "foo"].iter())); } + +#[tokio::test] +async fn set_intersection_sorted_stream2() { + use futures::StreamExt; + use utils::{set::intersection_sorted_stream2, IterStream}; + + let a = ["bar"]; + let b = ["bar", "foo"]; + let r = intersection_sorted_stream2(a.iter().stream(), b.iter().stream()) + .collect::>() + .await; + assert!(r.eq(&["bar"])); + + let r = intersection_sorted_stream2(b.iter().stream(), a.iter().stream()) + .collect::>() + .await; + assert!(r.eq(&["bar"])); + + let a = ["aaa", "ccc", "xxx", "yyy"]; + let b = ["hhh", "iii", "jjj", "zzz"]; + let r = intersection_sorted_stream2(a.iter().stream(), b.iter().stream()) + .collect::>() + .await; + assert!(r.is_empty()); + + let a = ["aaa", "ccc", "eee", "ggg"]; + let b = ["aaa", "bbb", "ccc", "ddd", "eee"]; + let r = intersection_sorted_stream2(a.iter().stream(), b.iter().stream()) + .collect::>() + .await; + assert!(r.eq(&["aaa", "ccc", "eee"])); + + let a = ["aaa", "ccc", "eee", "ggg", "hhh", "iii"]; + let b = ["bbb", "ccc", "ddd", "fff", "ggg", "iii"]; + let r = intersection_sorted_stream2(a.iter().stream(), b.iter().stream()) + .collect::>() + .await; + assert!(r.eq(&["ccc", "ggg", "iii"])); +}