diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index dae9d14c..4b197d0a 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -66,6 +66,8 @@ bytes.workspace = true clap.workspace = true http-body-util.workspace = true http.workspace = true +hyper.workspace = true +hyper-util.workspace = true regex.workspace = true ruma.workspace = true sentry.optional = true diff --git a/src/router/run.rs b/src/router/run.rs index 2603c04a..e6149fa2 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; use axum_server::Handle as ServerHandle; use tokio::{ signal, - sync::oneshot::{self, Sender}, + sync::broadcast::{self, Sender}, }; use tracing::{debug, info, warn}; @@ -40,14 +40,16 @@ pub(crate) async fn run(server: Arc) -> Result<(), Error> { .insert(handle.clone()); server.interrupt.store(false, Ordering::Release); - let (tx, rx) = oneshot::channel::<()>(); - let sigs = server.runtime().spawn(sighandle(server.clone(), tx)); + let (tx, _) = broadcast::channel::<()>(1); + let sigs = server + .runtime() + .spawn(sighandle(server.clone(), tx.clone())); // Prepare to serve http clients let res; // Serve clients if cfg!(unix) && config.unix_socket_path.is_some() { - res = serve::unix_socket(&server, app, rx).await; + res = serve::unix_socket(&server, app, tx.subscribe()).await; } else if config.tls.is_some() { res = serve::tls(&server, app, handle.clone(), addrs).await; } else { @@ -66,7 +68,7 @@ pub(crate) async fn run(server: Arc) -> Result<(), Error> { _ = services().admin.handle.lock().await.take(); debug_info!("Finished"); - Ok(res?) + res } /// Async initializations diff --git a/src/router/serve.rs b/src/router/serve.rs index 37ed9902..bddd9a2d 100644 --- a/src/router/serve.rs +++ b/src/router/serve.rs @@ -1,27 +1,31 @@ -#[cfg(unix)] -use std::fs::Permissions; // only for UNIX sockets stuff and *nix container checks -#[cfg(unix)] -use std::os::unix::fs::PermissionsExt as _; use std::{ - io, net::SocketAddr, + path::Path, sync::{atomic::Ordering, Arc}, }; -use axum::Router; +use axum::{extract::Request, routing::IntoMakeService, Router}; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; #[cfg(feature = "axum_dual_protocol")] use axum_server_dual_protocol::ServerExt; -use conduit::{debug_info, Server}; +use conduit::{debug_error, debug_info, utils, Error, Result, Server}; +use hyper::{body::Incoming, service::service_fn}; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server, +}; use tokio::{ - sync::oneshot::{self}, + fs, + sync::broadcast::{self}, task::JoinSet, }; +use tower::{Service, ServiceExt}; use tracing::{debug, info, warn}; +use utils::unwrap_infallible; pub(crate) async fn plain( - server: &Arc, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, -) -> io::Result<()> { + server: &Arc, app: IntoMakeService, handle: ServerHandle, addrs: Vec, +) -> Result<()> { let mut join_set = JoinSet::new(); for addr in &addrs { join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime()); @@ -48,8 +52,8 @@ pub(crate) async fn plain( } pub(crate) async fn tls( - server: &Arc, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, -) -> io::Result<()> { + server: &Arc, app: IntoMakeService, handle: ServerHandle, addrs: Vec, +) -> Result<()> { let config = &server.config; let tls = config.tls.as_ref().expect("TLS configuration"); @@ -107,31 +111,92 @@ pub(crate) async fn tls( } #[cfg(unix)] -#[allow(unused_variables)] pub(crate) async fn unix_socket( - server: &Arc, app: axum::routing::IntoMakeService, rx: oneshot::Receiver<()>, -) -> io::Result<()> { - let config = &server.config; - let path = config.unix_socket_path.as_ref().unwrap(); - - if path.exists() { - warn!( - "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", - path.display() - ); - tokio::fs::remove_file(&path).await?; + server: &Arc, app: IntoMakeService, mut shutdown: broadcast::Receiver<()>, +) -> Result<()> { + let mut tasks = JoinSet::<()>::new(); + let executor = TokioExecutor::new(); + let builder = server::conn::auto::Builder::new(executor); + let listener = unix_socket_init(server).await?; + loop { + let app = app.clone(); + let builder = builder.clone(); + tokio::select! { + _sig = shutdown.recv() => break, + accept = listener.accept() => match accept { + Ok(conn) => unix_socket_accept(server, &listener, &mut tasks, app, builder, conn).await, + Err(err) => debug_error!(?listener, "accept error: {err}"), + }, + } } - tokio::fs::create_dir_all(path.parent().unwrap()).await?; - - let socket_perms = config.unix_socket_perms.to_string(); - let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); - tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms)) - .await - .unwrap(); - - let bind = tokio::net::UnixListener::bind(path)?; - info!("Listening at {:?}", path); + drop(listener); + tasks.shutdown().await; Ok(()) } + +#[cfg(unix)] +async fn unix_socket_accept( + server: &Arc, listener: &tokio::net::UnixListener, tasks: &mut JoinSet<()>, + mut app: IntoMakeService, builder: server::conn::auto::Builder, + conn: (tokio::net::UnixStream, tokio::net::unix::SocketAddr), +) { + let (socket, remote) = conn; + let socket = TokioIo::new(socket); + debug!(?listener, ?socket, ?remote, "accepted"); + + let called = unwrap_infallible(app.call(()).await); + let handler = service_fn(move |req: Request| called.clone().oneshot(req)); + + let task = async move { + builder + .serve_connection(socket, handler) + .await + .map_err(|e| debug_error!(?remote, "connection error: {e}")) + .expect("connection error"); + }; + + _ = tasks.spawn_on(task, server.runtime()); + while tasks.try_join_next().is_some() {} +} + +#[cfg(unix)] +async fn unix_socket_init(server: &Arc) -> Result { + use std::os::unix::fs::PermissionsExt; + + let config = &server.config; + let path = config + .unix_socket_path + .as_ref() + .expect("failed to extract configured unix socket path"); + + if path.exists() { + warn!("Removing existing UNIX socket {:#?} (unclean shutdown?)...", path.display()); + fs::remove_file(&path) + .await + .map_err(|e| warn!("Failed to remove existing UNIX socket: {e}")) + .unwrap(); + } + + let dir = path.parent().unwrap_or_else(|| Path::new("/")); + if let Err(e) = fs::create_dir_all(dir).await { + return Err(Error::Err(format!("Failed to create {dir:?} for socket {path:?}: {e}"))); + } + + let listener = tokio::net::UnixListener::bind(path); + if let Err(e) = listener { + return Err(Error::Err(format!("Failed to bind listener {path:?}: {e}"))); + } + + let socket_perms = config.unix_socket_perms.to_string(); + let octal_perms = u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions"); + let perms = std::fs::Permissions::from_mode(octal_perms); + if let Err(e) = fs::set_permissions(&path, perms).await { + return Err(Error::Err(format!("Failed to set socket {path:?} permissions: {e}"))); + } + + info!("Listening at {:?}", path); + + Ok(listener.unwrap()) +}