use std::{fmt::Debug, hash::Hash, sync::Arc}; use tokio::sync::OwnedMutexGuard as Omg; /// Map of Mutexes pub struct MutexMap { map: Map, } pub struct Guard { map: Map, val: Omg, } type Map = Arc>; type MapMutex = std::sync::Mutex>; type HashMap = std::collections::HashMap>; type Value = Arc>; impl MutexMap where Key: Send + Hash + Eq + Clone, Val: Send + Default, { #[must_use] pub fn new() -> Self { Self { map: Map::new(MapMutex::new(HashMap::new())), } } #[tracing::instrument(skip(self), level = "debug")] pub async fn lock(&self, k: &K) -> Guard where K: ?Sized + Send + Sync + Debug, Key: for<'a> From<&'a K>, { let val = self .map .lock() .expect("locked") .entry(k.into()) .or_default() .clone(); Guard:: { map: Arc::clone(&self.map), val: val.lock_owned().await, } } #[must_use] pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } #[must_use] pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } #[must_use] pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } } impl Default for MutexMap where Key: Send + Hash + Eq + Clone, Val: Send + Default, { fn default() -> Self { Self::new() } } impl Drop for Guard { fn drop(&mut self) { if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { self.map .lock() .expect("locked") .retain(|_, val| !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2); } } }