Implement UNIX sockets

Initial implementation done in https://gitlab.com/famedly/conduit/-/merge_requests/507,
*substantially* reworked, corrected, improved by infamous <ehuff007@gmail.com>,
and few parts done by me.

Co-authored-by: infamous <ehuff007@gmail.com>
Signed-off-by: girlbossceo <june@girlboss.ceo>
This commit is contained in:
girlbossceo 2023-07-29 21:57:41 +00:00
parent 3bfdae795d
commit 42efc9deaf
8 changed files with 186 additions and 40 deletions

View file

@ -8,7 +8,10 @@
#![allow(clippy::suspicious_else_formatting)]
#![deny(clippy::dbg_macro)]
use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration};
use std::{
fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt,
sync::atomic, time::Duration,
};
use axum::{
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
@ -26,6 +29,8 @@ use http::{
header::{self, HeaderName},
Method, StatusCode, Uri,
};
use hyper::Server;
use hyperlocal::SocketIncoming;
use ruma::api::{
client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
@ -33,7 +38,7 @@ use ruma::api::{
},
IncomingRequest,
};
use tokio::signal;
use tokio::{net::UnixListener, signal, sync::oneshot};
use tower::ServiceBuilder;
use tower_http::{
cors::{self, CorsLayer},
@ -43,6 +48,8 @@ use tower_http::{
use tracing::{debug, error, info, warn};
use tracing_subscriber::{prelude::*, EnvFilter};
use tokio::sync::oneshot::Sender;
pub use conduit::*; // Re-export everything from the library crate
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
@ -69,12 +76,10 @@ async fn main() {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occurred: {e}");
std::process::exit(1);
return;
}
};
config.warn_deprecated();
let log = format!("{},ruma_state_res=error,_=off,sled=off", config.log);
if config.allow_jaeger {
@ -135,11 +140,15 @@ async fn main() {
#[cfg(unix)]
maximize_fd_limit().expect("should be able to increase the soft limit to the hard limit");
config.warn_deprecated();
if let Err(_) = config.error_dual_listening(raw_config) {
return;
};
info!("Loading database");
if let Err(error) = KeyValueDatabase::load_or_create(config).await {
error!(?error, "The database couldn't be loaded or created");
std::process::exit(1);
return;
};
let config = &services().globals.config;
@ -200,26 +209,57 @@ async fn run_server() -> io::Result<()> {
let app = routes().layer(middlewares).into_make_service();
let handle = ServerHandle::new();
let (tx, rx) = oneshot::channel::<()>();
tokio::spawn(shutdown_signal(handle.clone()));
tokio::spawn(shutdown_signal(handle.clone(), tx));
match &config.tls {
Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
server.await?
if let Some(path) = &config.unix_socket_path {
if path.exists() {
warn!(
"UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.",
path.display()
);
tokio::fs::remove_file(&path).await?;
}
None => {
let server = bind(addr).handle(handle).serve(app);
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
tokio::fs::create_dir_all(path.parent().unwrap()).await?;
server.await?
let socket_perms = config.unix_socket_perms.to_string();
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
let listener = UnixListener::bind(path.clone()).unwrap();
tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms))
.await
.unwrap();
let socket = SocketIncoming::from_listener(listener);
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
let server = Server::builder(socket).serve(app);
let graceful = server.with_graceful_shutdown(async {
rx.await.ok();
});
if let Err(e) = graceful.await {
error!("Server error: {:?}", e);
}
} else {
match &config.tls {
Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
server.await?
}
None => {
let server = bind(addr).handle(handle).serve(app);
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
server.await?
}
}
}
@ -439,7 +479,7 @@ fn routes() -> Router {
.fallback(not_found)
}
async fn shutdown_signal(handle: ServerHandle) {
async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
let ctrl_c = async {
signal::ctrl_c()
.await
@ -471,6 +511,9 @@ async fn shutdown_signal(handle: ServerHandle) {
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);
tx.send(()).unwrap();
Ok(())
}
async fn not_found(uri: Uri) -> impl IntoResponse {