cleanup on drop for utils::mutex_map.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-09 20:04:43 +00:00
parent 01b2928d55
commit 2d251eb19c
10 changed files with 131 additions and 54 deletions

View file

@ -20,7 +20,7 @@ pub use debug::slice_truncated as debug_slice_truncated;
pub use hash::calculate_hash;
pub use html::Escape as HtmlEscape;
pub use json::{deserialize_from_str, to_canonical_object};
pub use mutex_map::MutexMap;
pub use mutex_map::{Guard as MutexMapGuard, MutexMap};
pub use rand::string as random_string;
pub use string::{str_from_bytes, string_from_bytes};
pub use sys::available_parallelism;

View file

@ -1,20 +1,22 @@
use std::{hash::Hash, sync::Arc};
use std::{fmt::Debug, hash::Hash, sync::Arc};
type Value<Val> = tokio::sync::Mutex<Val>;
type ArcMutex<Val> = Arc<Value<Val>>;
type HashMap<Key, Val> = std::collections::HashMap<Key, ArcMutex<Val>>;
type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
type Map<Key, Val> = MapMutex<Key, Val>;
use tokio::sync::OwnedMutexGuard as Omg;
/// Map of Mutexes
pub struct MutexMap<Key, Val> {
map: Map<Key, Val>,
}
pub struct Guard<Val> {
_guard: tokio::sync::OwnedMutexGuard<Val>,
pub struct Guard<Key, Val> {
map: Map<Key, Val>,
val: Omg<Val>,
}
type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
impl<Key, Val> MutexMap<Key, Val>
where
Key: Send + Hash + Eq + Clone,
@ -23,28 +25,38 @@ where
#[must_use]
pub fn new() -> Self {
Self {
map: Map::<Key, Val>::new(HashMap::<Key, Val>::new()),
map: Map::new(MapMutex::new(HashMap::new())),
}
}
pub async fn lock<K>(&self, k: &K) -> Guard<Val>
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val>
where
K: ?Sized + Send + Sync,
K: ?Sized + Send + Sync + Debug,
Key: for<'a> From<&'a K>,
{
let val = self
.map
.lock()
.expect("map mutex locked")
.expect("locked")
.entry(k.into())
.or_default()
.clone();
let guard = val.lock_owned().await;
Guard::<Val> {
_guard: guard,
Guard::<Key, Val> {
map: Arc::clone(&self.map),
val: val.lock_owned().await,
}
}
#[must_use]
pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) }
#[must_use]
pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() }
#[must_use]
pub fn len(&self) -> usize { self.map.lock().expect("locked").len() }
}
impl<Key, Val> Default for MutexMap<Key, Val>
@ -54,3 +66,14 @@ where
{
fn default() -> Self { Self::new() }
}
impl<Key, Val> Drop for Guard<Key, Val> {
fn drop(&mut self) {
if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
self.map
.lock()
.expect("locked")
.retain(|_, val| !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2);
}
}
}

View file

@ -81,3 +81,56 @@ fn checked_add_overflow() {
let res = checked!(a + 1).expect("overflow");
assert_eq!(res, 0);
}
#[tokio::test]
async fn mutex_map_cleanup() {
use crate::utils::MutexMap;
let map = MutexMap::<String, ()>::new();
let lock = map.lock("foo").await;
assert!(!map.is_empty(), "map must not be empty");
drop(lock);
assert!(map.is_empty(), "map must be empty");
}
#[tokio::test]
async fn mutex_map_contend() {
use std::sync::Arc;
use tokio::sync::Barrier;
use crate::utils::MutexMap;
let map = Arc::new(MutexMap::<String, ()>::new());
let seq = Arc::new([Barrier::new(2), Barrier::new(2)]);
let str = "foo".to_owned();
let seq_ = seq.clone();
let map_ = map.clone();
let str_ = str.clone();
let join_a = tokio::spawn(async move {
let _lock = map_.lock(&str_).await;
assert!(!map_.is_empty(), "A0 must not be empty");
seq_[0].wait().await;
assert!(map_.contains(&str_), "A1 must contain key");
});
let seq_ = seq.clone();
let map_ = map.clone();
let str_ = str.clone();
let join_b = tokio::spawn(async move {
let _lock = map_.lock(&str_).await;
assert!(!map_.is_empty(), "B0 must not be empty");
seq_[1].wait().await;
assert!(map_.contains(&str_), "B1 must contain key");
});
seq[0].wait().await;
assert!(map.contains(&str), "Must contain key");
seq[1].wait().await;
tokio::try_join!(join_b, join_a).expect("joined");
assert!(map.is_empty(), "Must be empty");
}