use std::{
	convert::identity,
	mem::take,
	sync::{Arc, Mutex},
	thread::JoinHandle,
};

use async_channel::{bounded, Receiver, Sender};
use conduit::{debug, defer, err, implement, Result};
use futures::channel::oneshot;

use crate::{keyval::KeyBuf, Handle, Map};

pub(crate) struct Pool {
	workers: Mutex<Vec<JoinHandle<()>>>,
	recv: Receiver<Cmd>,
	send: Sender<Cmd>,
}

#[derive(Default)]
pub(crate) struct Opts {
	queue_size: Option<usize>,
	worker_num: Option<usize>,
}

const WORKER_THREAD_NAME: &str = "conduwuit:db";
const DEFAULT_QUEUE_SIZE: usize = 1024;
const DEFAULT_WORKER_NUM: usize = 32;

#[derive(Debug)]
pub(crate) enum Cmd {
	Get(Get),
}

#[derive(Debug)]
pub(crate) struct Get {
	pub(crate) map: Arc<Map>,
	pub(crate) key: KeyBuf,
	pub(crate) res: Option<ResultSender>,
}

type ResultSender = oneshot::Sender<Result<Handle<'static>>>;

#[implement(Pool)]
pub(crate) fn new(opts: &Opts) -> Result<Arc<Self>> {
	let queue_size = opts.queue_size.unwrap_or(DEFAULT_QUEUE_SIZE);
	let (send, recv) = bounded(queue_size);
	let pool = Arc::new(Self {
		workers: Vec::new().into(),
		recv,
		send,
	});

	let worker_num = opts.worker_num.unwrap_or(DEFAULT_WORKER_NUM);
	pool.spawn_until(worker_num)?;

	Ok(pool)
}

#[implement(Pool)]
fn spawn_until(self: &Arc<Self>, max: usize) -> Result {
	let mut workers = self.workers.lock()?;
	while workers.len() < max {
		self.clone().spawn_one(&mut workers)?;
	}

	Ok(())
}

#[implement(Pool)]
fn spawn_one(self: Arc<Self>, workers: &mut Vec<JoinHandle<()>>) -> Result<usize> {
	use std::thread::Builder;

	let id = workers.len();

	debug!(?id, "spawning {WORKER_THREAD_NAME}...");

	let thread = Builder::new()
		.name(WORKER_THREAD_NAME.into())
		.spawn(move || self.worker(id))?;

	workers.push(thread);

	Ok(id)
}

#[implement(Pool)]
pub(crate) fn close(self: &Arc<Self>) {
	debug!(
		senders = %self.send.sender_count(),
		receivers = %self.send.receiver_count(),
		"Closing pool channel"
	);

	let closing = self.send.close();
	debug_assert!(closing, "channel is not closing");

	debug!("Shutting down pool...");
	let mut workers = self.workers.lock().expect("locked");

	debug!(
		workers = %workers.len(),
		"Waiting for workers to join..."
	);

	take(&mut *workers)
		.into_iter()
		.map(JoinHandle::join)
		.try_for_each(identity)
		.expect("failed to join worker threads");

	debug_assert!(self.send.is_empty(), "channel is not empty");
}

#[implement(Pool)]
#[tracing::instrument(skip(self, cmd), level = "trace")]
pub(crate) async fn execute(&self, mut cmd: Cmd) -> Result<Handle<'_>> {
	let (send, recv) = oneshot::channel();
	Self::prepare(&mut cmd, send);

	self.send
		.send(cmd)
		.await
		.map_err(|e| err!(error!("send failed {e:?}")))?;

	recv.await
		.map(into_recv_result)
		.map_err(|e| err!(error!("recv failed {e:?}")))?
}

#[implement(Pool)]
fn prepare(cmd: &mut Cmd, send: ResultSender) {
	match cmd {
		Cmd::Get(ref mut cmd) => {
			_ = cmd.res.insert(send);
		},
	};
}

#[implement(Pool)]
#[tracing::instrument(skip(self))]
fn worker(self: Arc<Self>, id: usize) {
	debug!(?id, "worker spawned");
	defer! {{ debug!(?id, "worker finished"); }}

	self.worker_loop(id);
}

#[implement(Pool)]
fn worker_loop(&self, id: usize) {
	while let Ok(mut cmd) = self.recv.recv_blocking() {
		self.handle(id, &mut cmd);
	}
}

#[implement(Pool)]
fn handle(&self, id: usize, cmd: &mut Cmd) {
	match cmd {
		Cmd::Get(get) => self.handle_get(id, get),
	}
}

#[implement(Pool)]
#[tracing::instrument(skip(self, cmd), fields(%cmd.map), level = "trace")]
fn handle_get(&self, id: usize, cmd: &mut Get) {
	debug_assert!(!cmd.key.is_empty(), "querying for empty key");

	// Obtain the result channel.
	let chan = cmd.res.take().expect("missing result channel");

	// It is worth checking if the future was dropped while the command was queued
	// so we can bail without paying for any query.
	if chan.is_canceled() {
		return;
	}

	// Perform the actual database query. We reuse our database::Map interface but
	// limited to the blocking calls, rather than creating another surface directly
	// with rocksdb here.
	let result = cmd.map.get_blocking(&cmd.key);

	// Send the result back to the submitter.
	let chan_result = chan.send(into_send_result(result));

	// If the future was dropped during the query this will fail acceptably.
	let _chan_sent = chan_result.is_ok();
}

fn into_send_result(result: Result<Handle<'_>>) -> Result<Handle<'static>> {
	// SAFETY: Necessary to send the Handle (rust_rocksdb::PinnableSlice) through
	// the channel. The lifetime on the handle is a device by rust-rocksdb to
	// associate a database lifetime with its assets. The Handle must be dropped
	// before the database is dropped. The handle must pass through
	// into_recv_result() on the other end of the channel.
	unsafe { std::mem::transmute(result) }
}

fn into_recv_result(result: Result<Handle<'static>>) -> Result<Handle<'_>> {
	// SAFETY: This is to receive the Handle from the channel. Previously it had
	// passed through into_send_result().
	unsafe { std::mem::transmute(result) }
}
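
// Usage sketch: one way a crate-internal caller might drive the pool end-to-end,
// assuming an opened `Arc<Map>` (`map`) and an already-serialized `KeyBuf` (`key`)
// are available in an async, fallible context. The `res` field is left as `None`
// because execute()/prepare() install the oneshot sender before the command is
// queued.
//
//	let pool = Pool::new(&Opts::default())?;
//	let cmd = Cmd::Get(Get { map: map.clone(), key, res: None });
//	let handle = pool.execute(cmd).await?;
//	// ... read from the returned Handle, dropping it before the database ...
//	pool.close();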