use std::{fmt::Debug, hash::Hash, sync::Arc}; use tokio::sync::OwnedMutexGuard as Omg; /// Map of Mutexes pub struct MutexMap<Key, Val> { map: Map<Key, Val>, } pub struct Guard<Key, Val> { map: Map<Key, Val>, val: Omg<Val>, } type Map<Key, Val> = Arc<MapMutex<Key, Val>>; type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>; type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>; type Value<Val> = Arc<tokio::sync::Mutex<Val>>; impl<Key, Val> MutexMap<Key, Val> where Key: Send + Hash + Eq + Clone, Val: Send + Default, { #[must_use] pub fn new() -> Self { Self { map: Map::new(MapMutex::new(HashMap::new())), } } #[tracing::instrument(skip(self), level = "debug")] pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val> where K: ?Sized + Send + Sync + Debug, Key: for<'a> From<&'a K>, { let val = self .map .lock() .expect("locked") .entry(k.into()) .or_default() .clone(); Guard::<Key, Val> { map: Arc::clone(&self.map), val: val.lock_owned().await, } } #[must_use] pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } #[must_use] pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } #[must_use] pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } } impl<Key, Val> Default for MutexMap<Key, Val> where Key: Send + Hash + Eq + Clone, Val: Send + Default, { fn default() -> Self { Self::new() } } impl<Key, Val> Drop for Guard<Key, Val> { fn drop(&mut self) { if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { self.map.lock().expect("locked").retain(|_, val| { !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2 }); } } }