add back unix socket listener.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-29 16:59:20 +00:00 committed by June 🍓🦴
parent faa2b95c84
commit 0baa57f5d9
3 changed files with 108 additions and 39 deletions

View file

@ -66,6 +66,8 @@ bytes.workspace = true
clap.workspace = true clap.workspace = true
http-body-util.workspace = true http-body-util.workspace = true
http.workspace = true http.workspace = true
hyper.workspace = true
hyper-util.workspace = true
regex.workspace = true regex.workspace = true
ruma.workspace = true ruma.workspace = true
sentry.optional = true sentry.optional = true

View file

@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use axum_server::Handle as ServerHandle; use axum_server::Handle as ServerHandle;
use tokio::{ use tokio::{
signal, signal,
sync::oneshot::{self, Sender}, sync::broadcast::{self, Sender},
}; };
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
@ -40,14 +40,16 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> {
.insert(handle.clone()); .insert(handle.clone());
server.interrupt.store(false, Ordering::Release); server.interrupt.store(false, Ordering::Release);
let (tx, rx) = oneshot::channel::<()>(); let (tx, _) = broadcast::channel::<()>(1);
let sigs = server.runtime().spawn(sighandle(server.clone(), tx)); let sigs = server
.runtime()
.spawn(sighandle(server.clone(), tx.clone()));
// Prepare to serve http clients // Prepare to serve http clients
let res; let res;
// Serve clients // Serve clients
if cfg!(unix) && config.unix_socket_path.is_some() { if cfg!(unix) && config.unix_socket_path.is_some() {
res = serve::unix_socket(&server, app, rx).await; res = serve::unix_socket(&server, app, tx.subscribe()).await;
} else if config.tls.is_some() { } else if config.tls.is_some() {
res = serve::tls(&server, app, handle.clone(), addrs).await; res = serve::tls(&server, app, handle.clone(), addrs).await;
} else { } else {
@ -66,7 +68,7 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> {
_ = services().admin.handle.lock().await.take(); _ = services().admin.handle.lock().await.take();
debug_info!("Finished"); debug_info!("Finished");
Ok(res?) res
} }
/// Async initializations /// Async initializations

View file

@ -1,27 +1,31 @@
#[cfg(unix)]
use std::fs::Permissions; // only for UNIX sockets stuff and *nix container checks
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt as _;
use std::{ use std::{
io,
net::SocketAddr, net::SocketAddr,
path::Path,
sync::{atomic::Ordering, Arc}, sync::{atomic::Ordering, Arc},
}; };
use axum::Router; use axum::{extract::Request, routing::IntoMakeService, Router};
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
#[cfg(feature = "axum_dual_protocol")] #[cfg(feature = "axum_dual_protocol")]
use axum_server_dual_protocol::ServerExt; use axum_server_dual_protocol::ServerExt;
use conduit::{debug_info, Server}; use conduit::{debug_error, debug_info, utils, Error, Result, Server};
use hyper::{body::Incoming, service::service_fn};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server,
};
use tokio::{ use tokio::{
sync::oneshot::{self}, fs,
sync::broadcast::{self},
task::JoinSet, task::JoinSet,
}; };
use tower::{Service, ServiceExt};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use utils::unwrap_infallible;
pub(crate) async fn plain( pub(crate) async fn plain(
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>, server: &Arc<Server>, app: IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
) -> io::Result<()> { ) -> Result<()> {
let mut join_set = JoinSet::new(); let mut join_set = JoinSet::new();
for addr in &addrs { for addr in &addrs {
join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime()); join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime());
@ -48,8 +52,8 @@ pub(crate) async fn plain(
} }
pub(crate) async fn tls( pub(crate) async fn tls(
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>, server: &Arc<Server>, app: IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
) -> io::Result<()> { ) -> Result<()> {
let config = &server.config; let config = &server.config;
let tls = config.tls.as_ref().expect("TLS configuration"); let tls = config.tls.as_ref().expect("TLS configuration");
@ -107,31 +111,92 @@ pub(crate) async fn tls(
} }
#[cfg(unix)] #[cfg(unix)]
#[allow(unused_variables)]
pub(crate) async fn unix_socket( pub(crate) async fn unix_socket(
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, rx: oneshot::Receiver<()>, server: &Arc<Server>, app: IntoMakeService<Router>, mut shutdown: broadcast::Receiver<()>,
) -> io::Result<()> { ) -> Result<()> {
let config = &server.config; let mut tasks = JoinSet::<()>::new();
let path = config.unix_socket_path.as_ref().unwrap(); let executor = TokioExecutor::new();
let builder = server::conn::auto::Builder::new(executor);
if path.exists() { let listener = unix_socket_init(server).await?;
warn!( loop {
"UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", let app = app.clone();
path.display() let builder = builder.clone();
); tokio::select! {
tokio::fs::remove_file(&path).await?; _sig = shutdown.recv() => break,
accept = listener.accept() => match accept {
Ok(conn) => unix_socket_accept(server, &listener, &mut tasks, app, builder, conn).await,
Err(err) => debug_error!(?listener, "accept error: {err}"),
},
}
} }
tokio::fs::create_dir_all(path.parent().unwrap()).await?; drop(listener);
tasks.shutdown().await;
let socket_perms = config.unix_socket_perms.to_string();
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms))
.await
.unwrap();
let bind = tokio::net::UnixListener::bind(path)?;
info!("Listening at {:?}", path);
Ok(()) Ok(())
} }
#[cfg(unix)]
async fn unix_socket_accept(
server: &Arc<Server>, listener: &tokio::net::UnixListener, tasks: &mut JoinSet<()>,
mut app: IntoMakeService<Router>, builder: server::conn::auto::Builder<TokioExecutor>,
conn: (tokio::net::UnixStream, tokio::net::unix::SocketAddr),
) {
let (socket, remote) = conn;
let socket = TokioIo::new(socket);
debug!(?listener, ?socket, ?remote, "accepted");
let called = unwrap_infallible(app.call(()).await);
let handler = service_fn(move |req: Request<Incoming>| called.clone().oneshot(req));
let task = async move {
builder
.serve_connection(socket, handler)
.await
.map_err(|e| debug_error!(?remote, "connection error: {e}"))
.expect("connection error");
};
_ = tasks.spawn_on(task, server.runtime());
while tasks.try_join_next().is_some() {}
}
#[cfg(unix)]
async fn unix_socket_init(server: &Arc<Server>) -> Result<tokio::net::UnixListener> {
use std::os::unix::fs::PermissionsExt;
let config = &server.config;
let path = config
.unix_socket_path
.as_ref()
.expect("failed to extract configured unix socket path");
if path.exists() {
warn!("Removing existing UNIX socket {:#?} (unclean shutdown?)...", path.display());
fs::remove_file(&path)
.await
.map_err(|e| warn!("Failed to remove existing UNIX socket: {e}"))
.unwrap();
}
let dir = path.parent().unwrap_or_else(|| Path::new("/"));
if let Err(e) = fs::create_dir_all(dir).await {
return Err(Error::Err(format!("Failed to create {dir:?} for socket {path:?}: {e}")));
}
let listener = tokio::net::UnixListener::bind(path);
if let Err(e) = listener {
return Err(Error::Err(format!("Failed to bind listener {path:?}: {e}")));
}
let socket_perms = config.unix_socket_perms.to_string();
let octal_perms = u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions");
let perms = std::fs::Permissions::from_mode(octal_perms);
if let Err(e) = fs::set_permissions(&path, perms).await {
return Err(Error::Err(format!("Failed to set socket {path:?} permissions: {e}")));
}
info!("Listening at {:?}", path);
Ok(listener.unwrap())
}