From 9a9c071e8204604d617b213c6a977f5649ba2fc0 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 30 Nov 2024 03:16:57 +0000 Subject: [PATCH] use tokio for threadpool mgmt Signed-off-by: Jason Volk --- src/database/database.rs | 2 +- src/database/engine.rs | 10 +- src/database/map/get.rs | 30 +++--- src/database/pool.rs | 204 +++++++++++++++++++++++++-------------- 4 files changed, 155 insertions(+), 91 deletions(-) diff --git a/src/database/database.rs b/src/database/database.rs index 40aec312..3df95dce 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -16,7 +16,7 @@ pub struct Database { impl Database { /// Load an existing database or create a new one. pub async fn open(server: &Arc) -> Result> { - let db = Engine::open(server)?; + let db = Engine::open(server).await?; Ok(Arc::new(Self { db: db.clone(), maps: maps::open(&db)?, diff --git a/src/database/engine.rs b/src/database/engine.rs index e700be62..a6ed7d86 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -23,7 +23,7 @@ use crate::{ }; pub struct Engine { - server: Arc, + pub(crate) server: Arc, row_cache: Cache, col_cache: RwLock>, opts: Options, @@ -40,7 +40,7 @@ pub(crate) type Db = DBWithThreadMode; impl Engine { #[tracing::instrument(skip_all)] - pub(crate) fn open(server: &Arc) -> Result> { + pub(crate) async fn open(server: &Arc) -> Result> { let config = &server.config; let cache_capacity_bytes = config.db_cache_capacity_mb * 1024.0 * 1024.0; @@ -119,7 +119,7 @@ impl Engine { corks: AtomicU32::new(0), read_only: config.rocksdb_read_only, secondary: config.rocksdb_secondary, - pool: Pool::new(&pool_opts)?, + pool: Pool::new(server, &pool_opts).await?, })) } @@ -305,7 +305,7 @@ pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> { Ok(()) } -#[tracing::instrument(skip_all, name = "rocksdb", level = "debug")] +#[tracing::instrument(skip(msg), name = "rocksdb", level = "trace")] pub(crate) fn handle_log(level: LogLevel, msg: &str) { let msg = 
msg.trim(); if msg.starts_with("Options") { @@ -325,7 +325,7 @@ impl Drop for Engine { fn drop(&mut self) { const BLOCKING: bool = true; - debug!("Joining request threads..."); + debug!("Shutting down request pool..."); self.pool.close(); debug!("Waiting for background tasks to finish..."); diff --git a/src/database/map/get.rs b/src/database/map/get.rs index befc0b24..4699fec4 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -2,7 +2,7 @@ use std::{convert::AsRef, fmt::Debug, io::Write, sync::Arc}; use arrayvec::ArrayVec; use conduit::{err, implement, utils::IterStream, Err, Result}; -use futures::{future, Future, FutureExt, Stream}; +use futures::{future, Future, FutureExt, Stream, StreamExt}; use rocksdb::DBPinnableSlice; use serde::Serialize; @@ -54,6 +54,18 @@ where self.get(key) } +#[implement(super::Map)] +#[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] +pub fn get_batch<'a, I, K>(self: &'a Arc, keys: I) -> impl Stream>> + Send + 'a +where + I: Iterator + ExactSizeIterator + Debug + Send + 'a, + K: AsRef<[u8]> + Debug + Send + ?Sized + Sync + 'a, +{ + keys.stream() + .map(move |key| self.get(key)) + .buffered(self.db.server.config.db_pool_workers.saturating_mul(2)) +} + /// Fetch a value from the database into cache, returning a reference-handle /// asynchronously. The key is referenced directly to perform the query. 
#[implement(super::Map)] @@ -80,17 +92,8 @@ where } #[implement(super::Map)] -#[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] -pub fn get_batch<'a, I, K>(&self, keys: I) -> impl Stream>> -where - I: Iterator + ExactSizeIterator + Debug + Send, - K: AsRef<[u8]> + Debug + Send + ?Sized + Sync + 'a, -{ - self.get_batch_blocking(keys).stream() -} - -#[implement(super::Map)] -pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> impl Iterator>> +#[tracing::instrument(skip(self, keys), name = "batch_blocking", level = "trace")] +pub(crate) fn get_batch_blocking<'a, I, K>(&self, keys: I) -> impl Iterator>> + Send where I: Iterator + ExactSizeIterator + Debug + Send, K: AsRef<[u8]> + Debug + Send + ?Sized + Sync + 'a, @@ -111,6 +114,7 @@ where /// The key is referenced directly to perform the query. This is a thread- /// blocking call. #[implement(super::Map)] +#[tracing::instrument(skip(self, key), name = "blocking", level = "trace")] pub fn get_blocking(&self, key: &K) -> Result> where K: AsRef<[u8]> + ?Sized, @@ -125,7 +129,7 @@ where /// Fetch a value from the cache without I/O. 
#[implement(super::Map)] -#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +#[tracing::instrument(skip(self, key), name = "cache", level = "trace")] pub(crate) fn get_cached(&self, key: &K) -> Result>> where K: AsRef<[u8]> + Debug + ?Sized, diff --git a/src/database/pool.rs b/src/database/pool.rs index ee3e67dd..a9697625 100644 --- a/src/database/pool.rs +++ b/src/database/pool.rs @@ -1,20 +1,25 @@ use std::{ - convert::identity, mem::take, - sync::{Arc, Mutex}, - thread::JoinHandle, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; -use async_channel::{bounded, Receiver, Sender}; -use conduit::{debug, defer, err, implement, Result}; +use async_channel::{bounded, Receiver, RecvError, Sender}; +use conduit::{debug, debug_warn, defer, err, implement, result::DebugInspect, Result, Server}; use futures::channel::oneshot; +use tokio::{sync::Mutex, task::JoinSet}; use crate::{keyval::KeyBuf, Handle, Map}; pub(crate) struct Pool { - workers: Mutex>>, - recv: Receiver, - send: Sender, + server: Arc, + workers: Mutex>, + queue: Sender, + busy: AtomicUsize, + busy_max: AtomicUsize, + queued_max: AtomicUsize, } pub(crate) struct Opts { @@ -22,10 +27,6 @@ pub(crate) struct Opts { pub(crate) worker_num: usize, } -const QUEUE_LIMIT: (usize, usize) = (1, 8192); -const WORKER_LIMIT: (usize, usize) = (1, 512); -const WORKER_THREAD_NAME: &str = "conduwuit:db"; - #[derive(Debug)] pub(crate) enum Cmd { Get(Get), @@ -40,83 +41,111 @@ pub(crate) struct Get { type ResultSender = oneshot::Sender>>; -#[implement(Pool)] -pub(crate) fn new(opts: &Opts) -> Result> { - let queue_size = opts.queue_size.clamp(QUEUE_LIMIT.0, QUEUE_LIMIT.1); +const QUEUE_LIMIT: (usize, usize) = (1, 3072); +const WORKER_LIMIT: (usize, usize) = (1, 512); +impl Drop for Pool { + fn drop(&mut self) { + debug_assert!(self.queue.is_empty(), "channel must be empty on drop"); + debug_assert!(self.queue.is_closed(), "channel should be closed on drop"); + } +} + +#[implement(Pool)] 
+pub(crate) async fn new(server: &Arc, opts: &Opts) -> Result> { + let queue_size = opts.queue_size.clamp(QUEUE_LIMIT.0, QUEUE_LIMIT.1); let (send, recv) = bounded(queue_size); let pool = Arc::new(Self { - workers: Vec::new().into(), - recv, - send, + server: server.clone(), + workers: JoinSet::new().into(), + queue: send, + busy: AtomicUsize::default(), + busy_max: AtomicUsize::default(), + queued_max: AtomicUsize::default(), }); let worker_num = opts.worker_num.clamp(WORKER_LIMIT.0, WORKER_LIMIT.1); - pool.spawn_until(worker_num)?; + pool.spawn_until(recv, worker_num).await?; Ok(pool) } #[implement(Pool)] -fn spawn_until(self: &Arc, max: usize) -> Result { - let mut workers = self.workers.lock()?; +pub(crate) async fn _shutdown(self: &Arc) { + if !self.queue.is_closed() { + self.close(); + } + let workers = take(&mut *self.workers.lock().await); + debug!(workers = workers.len(), "Waiting for workers to join..."); + + workers.join_all().await; + debug_assert!(self.queue.is_empty(), "channel is not empty"); +} + +#[implement(Pool)] +pub(crate) fn close(&self) { + debug_assert!(!self.queue.is_closed(), "channel already closed"); + debug!( + senders = self.queue.sender_count(), + receivers = self.queue.receiver_count(), + "Closing pool channel" + ); + + let closing = self.queue.close(); + debug_assert!(closing, "channel is not closing"); +} + +#[implement(Pool)] +async fn spawn_until(self: &Arc, recv: Receiver, max: usize) -> Result { + let mut workers = self.workers.lock().await; while workers.len() < max { - self.clone().spawn_one(&mut workers)?; + self.spawn_one(&mut workers, recv.clone())?; } Ok(()) } #[implement(Pool)] -fn spawn_one(self: Arc, workers: &mut Vec>) -> Result { - use std::thread::Builder; - +fn spawn_one(self: &Arc, workers: &mut JoinSet<()>, recv: Receiver) -> Result { let id = workers.len(); - debug!(?id, "spawning {WORKER_THREAD_NAME}..."); - let thread = Builder::new() - .name(WORKER_THREAD_NAME.into()) - .spawn(move || self.worker(id))?; + 
debug!(?id, "spawning"); + let self_ = self.clone(); + let _abort = workers.spawn_blocking_on(move || self_.worker(id, recv), self.server.runtime()); - workers.push(thread); - - Ok(id) + Ok(()) } #[implement(Pool)] -pub(crate) fn close(self: &Arc) { - debug!( - senders = %self.send.sender_count(), - receivers = %self.send.receiver_count(), - "Closing pool channel" - ); - let closing = self.send.close(); - debug_assert!(closing, "channel is not closing"); - - debug!("Shutting down pool..."); - let mut workers = self.workers.lock().expect("locked"); - - debug!( - workers = %workers.len(), - "Waiting for workers to join..." - ); - take(&mut *workers) - .into_iter() - .map(JoinHandle::join) - .try_for_each(identity) - .expect("failed to join worker threads"); - - debug_assert!(self.send.is_empty(), "channel is not empty"); -} - -#[implement(Pool)] -#[tracing::instrument(skip(self, cmd), level = "trace")] +#[tracing::instrument( + level = "trace", + skip(self, cmd), + fields( + task = ?tokio::task::try_id(), + receivers = self.queue.receiver_count(), + senders = self.queue.sender_count(), + queued = self.queue.len(), + queued_max = self.queued_max.load(Ordering::Relaxed), + ), +)] pub(crate) async fn execute(&self, mut cmd: Cmd) -> Result> { let (send, recv) = oneshot::channel(); Self::prepare(&mut cmd, send); - self.send + if cfg!(debug_assertions) { + self.queued_max + .fetch_max(self.queue.len(), Ordering::Relaxed); + } + + if self.queue.is_full() { + debug_warn!( + capacity = ?self.queue.capacity(), + "pool queue is full" + ); + } + + self.queue .send(cmd) .await .map_err(|e| err!(error!("send failed {e:?}")))?; @@ -136,30 +165,61 @@ fn prepare(cmd: &mut Cmd, send: ResultSender) { } #[implement(Pool)] -#[tracing::instrument(skip(self))] -fn worker(self: Arc, id: usize) { - debug!(?id, "worker spawned"); - defer!
{{ debug!(?id, "worker finished"); }} - self.worker_loop(id); +#[tracing::instrument(skip(self, recv))] +fn worker(self: Arc, id: usize, recv: Receiver) { + debug!("worker spawned"); + defer! {{ debug!("worker finished"); }} + + self.worker_loop(&recv); } #[implement(Pool)] -fn worker_loop(&self, id: usize) { - while let Ok(mut cmd) = self.recv.recv_blocking() { - self.worker_handle(id, &mut cmd); +fn worker_loop(&self, recv: &Receiver) { + // initial +1 needed prior to entering wait + self.busy.fetch_add(1, Ordering::Relaxed); + + while let Ok(mut cmd) = self.worker_wait(recv) { + self.worker_handle(&mut cmd); } } #[implement(Pool)] -fn worker_handle(&self, id: usize, cmd: &mut Cmd) { +#[tracing::instrument( + name = "wait", + level = "trace", + skip_all, + fields( + receivers = recv.receiver_count(), + senders = recv.sender_count(), + queued = recv.len(), + busy = self.busy.load(Ordering::Relaxed), + busy_max = self.busy_max.fetch_max( + self.busy.fetch_sub(1, Ordering::Relaxed), + Ordering::Relaxed + ), + ), +)] +fn worker_wait(&self, recv: &Receiver) -> Result { + recv.recv_blocking().debug_inspect(|_| { + self.busy.fetch_add(1, Ordering::Relaxed); + }) +} + +#[implement(Pool)] +fn worker_handle(&self, cmd: &mut Cmd) { match cmd { - Cmd::Get(get) => self.handle_get(id, get), + Cmd::Get(cmd) => self.handle_get(cmd), } } #[implement(Pool)] -#[tracing::instrument(skip(self, cmd), fields(%cmd.map), level = "trace")] -fn handle_get(&self, id: usize, cmd: &mut Get) { +#[tracing::instrument( + name = "get", + level = "trace", + skip_all, + fields(%cmd.map), +)] +fn handle_get(&self, cmd: &mut Get) { debug_assert!(!cmd.key.is_empty(), "querying for empty key"); // Obtain the result channel.