#[cfg(unix)] use std::fs::Permissions; // not unix specific, just only for UNIX sockets stuff and *nix container checks #[cfg(unix)] use std::os::unix::fs::PermissionsExt as _; /* not unix specific, just only for UNIX sockets stuff and *nix * container checks */ use std::{io, net::SocketAddr, sync::atomic, time::Duration}; use axum::{ extract::{DefaultBodyLimit, MatchedPath}, response::IntoResponse, Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; #[cfg(feature = "axum_dual_protocol")] use axum_server_dual_protocol::ServerExt; pub use conduit::*; // Re-export everything from the library crate use http::{ header::{self, HeaderName}, Method, StatusCode, }; #[cfg(unix)] use hyperlocal::SocketIncoming; use ruma::api::client::{ error::{Error as RumaError, ErrorBody, ErrorKind}, uiaa::UiaaResponse, }; use tokio::{ signal, sync::oneshot::{self, Sender}, task::JoinSet, }; use tower::ServiceBuilder; use tower_http::{ cors::{self, CorsLayer}, trace::{DefaultOnFailure, TraceLayer}, ServiceBuilderExt as _, }; use tracing::{debug, error, info, warn, Level}; use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry}; mod routes; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))] #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[cfg(all( not(target_env = "msvc"), not(target_os = "macos"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc") ))] #[global_allocator] static GLOBAL: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc; struct Server { config: Config, runtime: tokio::runtime::Runtime, tracing_reload_handle: reload::Handle, #[cfg(feature = "sentry_telemetry")] _sentry_guard: Option, } fn main() -> Result<(), Error> { let args = clap::parse(); let conduwuit: Server = init(args)?; conduwuit .runtime .block_on(async { async_main(&conduwuit).await }) } async fn async_main(server: &Server) -> Result<(), Error> { if let Err(error) = start(server).await { error!("Critical error starting server: {error}"); return Err(Error::Error(format!("{error}"))); } if let Err(error) = run(server).await { error!("Critical error running server: {error}"); return Err(Error::Error(format!("{error}"))); }; if let Err(error) = stop(server).await { error!("Critical error stopping server: {error}"); return Err(Error::Error(format!("{error}"))); } Ok(()) } async fn run(server: &Server) -> io::Result<()> { let app = build(server).await?; let (tx, rx) = oneshot::channel::<()>(); let handle = ServerHandle::new(); tokio::spawn(shutdown(handle.clone(), tx)); #[cfg(unix)] if server.config.unix_socket_path.is_some() { return run_unix_socket_server(server, app, rx).await; } let addrs = server.config.get_bind_addrs(); if server.config.tls.is_some() { return run_tls_server(server, app, handle, addrs).await; } let mut join_set = JoinSet::new(); for addr in &addrs { join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); } #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental #[cfg(feature = "systemd")] let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); info!("Listening on {:?}", addrs); join_set.join_next().await; Ok(()) } async fn run_tls_server( server: &Server, app: axum::routing::IntoMakeService, handle: ServerHandle, addrs: Vec, ) -> io::Result<()> { let tls = server.config.tls.as_ref().unwrap(); debug!( "Using direct TLS. Certificate path {} and certificate private key path {}", &tls.certs, &tls.key ); info!( "Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS." ); let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; if cfg!(feature = "axum_dual_protocol") { info!( "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \ take affect if `dual_protocol` is enabled in `[global.tls]`" ); } let mut join_set = JoinSet::new(); if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { #[cfg(feature = "axum_dual_protocol")] for addr in &addrs { join_set.spawn( axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) .set_upgrade(false) .handle(handle.clone()) .serve(app.clone()), ); } } else { for addr in &addrs { join_set.spawn( bind_rustls(*addr, conf.clone()) .handle(handle.clone()) .serve(app.clone()), ); } } #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental #[cfg(feature = "systemd")] let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { warn!( "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)", addrs, &tls.certs ); } else { info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs); } join_set.join_next().await; Ok(()) } #[cfg(unix)] async fn run_unix_socket_server( server: &Server, app: axum::routing::IntoMakeService, rx: oneshot::Receiver<()>, ) -> io::Result<()> { let path = server.config.unix_socket_path.as_ref().unwrap(); if path.exists() { warn!( "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", path.display() ); tokio::fs::remove_file(&path).await?; } tokio::fs::create_dir_all(path.parent().unwrap()).await?; let socket_perms = server.config.unix_socket_perms.to_string(); let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); let listener = tokio::net::UnixListener::bind(path.clone())?; tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)) .await .unwrap(); let socket = SocketIncoming::from_listener(listener); #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental #[cfg(feature = "systemd")] let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); info!("Listening at {:?}", path); let server = hyper::Server::builder(socket).serve(app); let graceful = server.with_graceful_shutdown(async { rx.await.ok(); }); if let Err(e) = graceful.await { error!("Server error: {:?}", e); } Ok(()) } async fn shutdown(handle: ServerHandle, tx: Sender<()>) -> Result<()> { let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install SIGTERM handler") .recv() .await; }; let sig: &str; #[cfg(unix)] tokio::select! { () = ctrl_c => { sig = "Ctrl+C"; }, () = terminate => { sig = "SIGTERM"; }, } #[cfg(not(unix))] tokio::select! { _ = ctrl_c => { sig = "Ctrl+C"; }, } warn!("Received {}, shutting down...", sig); handle.graceful_shutdown(Some(Duration::from_secs(180))); services().globals.shutdown(); #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental #[cfg(feature = "systemd")] let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); tx.send(()).expect( "failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your \ system may not be in an okay/ideal state.)", ); Ok(()) } async fn stop(_server: &Server) -> io::Result<()> { info!("Shutdown complete."); Ok(()) } /// Async initializations async fn start(server: &Server) -> Result<(), Error> { KeyValueDatabase::load_or_create(server.config.clone(), server.tracing_reload_handle.clone()).await?; Ok(()) } async fn build(server: &Server) -> io::Result> { let base_middlewares = ServiceBuilder::new(); #[cfg(feature = "sentry_telemetry")] let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::>::new_from_top()); let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); let middlewares = base_middlewares .sensitive_headers([header::AUTHORIZATION]) .sensitive_request_headers([x_forwarded_for].into()) .layer(axum::middleware::from_fn(request_spawn)) .layer( TraceLayer::new_for_http() .make_span_with(tracing_span::<_>) .on_failure(DefaultOnFailure::new().level(Level::INFO)), ) .layer(axum::middleware::from_fn(request_handler)) .layer(cors_layer(server)) .layer(DefaultBodyLimit::max( server .config .max_request_size .try_into() .expect("failed to convert max request size"), )); #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] { Ok(routes::routes(&server.config) .layer(compression_layer(server)) .layer(middlewares) .into_make_service()) } #[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))] { Ok(routes::routes().layer(middlewares).into_make_service()) } } async fn request_spawn( req: http::Request, next: axum::middleware::Next, ) -> Result { if services().globals.shutdown.load(atomic::Ordering::Relaxed) { return Err(StatusCode::SERVICE_UNAVAILABLE); } tokio::spawn(next.run(req)) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } async fn request_handler( req: http::Request, next: axum::middleware::Next, ) -> Result { let method = req.method().clone(); let uri = req.uri().clone(); let inner = next.run(req).await; if inner.status() == StatusCode::METHOD_NOT_ALLOWED { if uri.path().contains("_matrix/") { warn!("Method not allowed: {method} {uri}"); } else { info!("Method not allowed: {method} {uri}"); } return Ok(RumaResponse(UiaaResponse::MatrixError(RumaError { body: ErrorBody::Standard { kind: ErrorKind::Unrecognized, message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), }, status_code: StatusCode::METHOD_NOT_ALLOWED, })) .into_response()); } Ok(inner) } fn cors_layer(_server: &Server) -> CorsLayer { let methods = [ Method::GET, Method::HEAD, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS, ]; let headers = [ header::ORIGIN, HeaderName::from_static("x-requested-with"), header::CONTENT_TYPE, header::ACCEPT, header::AUTHORIZATION, ]; CorsLayer::new() .allow_origin(cors::Any) .allow_methods(methods) .allow_headers(headers) .max_age(Duration::from_secs(86400)) } #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer { let mut compression_layer = tower_http::compression::CompressionLayer::new(); #[cfg(feature = "zstd_compression")] { if server.config.zstd_compression { compression_layer = compression_layer.zstd(true); } else { compression_layer = compression_layer.no_zstd(); }; }; #[cfg(feature = "gzip_compression")] { if server.config.gzip_compression { compression_layer = compression_layer.gzip(true); } else { compression_layer = compression_layer.no_gzip(); }; }; #[cfg(feature = "brotli_compression")] { if server.config.brotli_compression { compression_layer = compression_layer.br(true); } else { compression_layer = compression_layer.no_br(); }; }; compression_layer } fn tracing_span(request: &http::Request) -> tracing::Span { let path = if let Some(path) = request.extensions().get::() { path.as_str() } else { request.uri().path() }; tracing::info_span!("handle", %path) } /// Non-async initializations fn init(args: clap::Args) -> Result { let config = Config::new(args.config)?; #[cfg(feature = "sentry_telemetry")] let sentry_guard = if config.sentry { Some(init_sentry(&config)) } else { None }; let tracing_reload_handle; #[cfg(feature = "perf_measurements")] { tracing_reload_handle = if config.allow_jaeger { init_tracing_jaeger(&config) } else if config.tracing_flame { #[cfg(feature = "perf_measurements")] init_tracing_flame(&config) } else { init_tracing_sub(&config) }; }; #[cfg(not(feature = "perf_measurements"))] { tracing_reload_handle = init_tracing_sub(&config); }; info!( server_name = ?config.server_name, database_path = ?config.database_path, log_levels = ?config.log, "{}", env!("CARGO_PKG_VERSION"), ); #[cfg(unix)] maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); Ok(Server { config, runtime: tokio::runtime::Builder::new_multi_thread() .enable_io() .enable_time() .thread_name("conduwuit:worker") .worker_threads(num_cpus::get_physical()) .build() .unwrap(), tracing_reload_handle, #[cfg(feature = "sentry_telemetry")] _sentry_guard: sentry_guard, }) } #[cfg(feature = "sentry_telemetry")] fn init_sentry(config: &Config) -> sentry::ClientInitGuard { sentry::init(( "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536", sentry::ClientOptions { release: sentry::release_name!(), traces_sample_rate: config.sentry_traces_sample_rate, server_name: if config.sentry_send_server_name { Some(config.server_name.to_string().into()) } else { None }, ..Default::default() }, )) } fn init_tracing_sub(config: &Config) -> reload::Handle { let registry = Registry::default(); let fmt_layer = tracing_subscriber::fmt::Layer::new(); let filter_layer = match EnvFilter::try_new(&config.log) { Ok(s) => s, Err(e) => { eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); EnvFilter::try_new("warn").unwrap() }, }; let (reload_filter, reload_handle) = reload::Layer::new(filter_layer); #[cfg(feature = "sentry_telemetry")] let sentry_layer = sentry_tracing::layer(); let subscriber; #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental #[cfg(feature = "sentry_telemetry")] { subscriber = registry .with(reload_filter) .with(fmt_layer) .with(sentry_layer); }; #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental #[cfg(not(feature = "sentry_telemetry"))] { subscriber = registry.with(reload_filter).with(fmt_layer); }; tracing::subscriber::set_global_default(subscriber).unwrap(); reload_handle } #[cfg(feature = "perf_measurements")] fn init_tracing_jaeger(config: &Config) -> reload::Handle { opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); let tracer = opentelemetry_jaeger::new_agent_pipeline() .with_auto_split_batch(true) .with_service_name("conduwuit") .install_batch(opentelemetry_sdk::runtime::Tokio) .unwrap(); let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); let filter_layer = match EnvFilter::try_new(&config.log) { Ok(s) => s, Err(e) => { eprintln!("It looks like your log config is invalid. The following error occurred: {e}"); EnvFilter::try_new("warn").unwrap() }, }; let (reload_filter, reload_handle) = reload::Layer::new(filter_layer); let subscriber = Registry::default().with(reload_filter).with(telemetry); tracing::subscriber::set_global_default(subscriber).unwrap(); reload_handle } #[cfg(feature = "perf_measurements")] fn init_tracing_flame(_config: &Config) -> reload::Handle { let registry = Registry::default(); let (flame_layer, _guard) = tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); let flame_layer = flame_layer.with_empty_samples(false); let filter_layer = EnvFilter::new("trace,h2=off"); let (reload_filter, reload_handle) = reload::Layer::new(filter_layer); let subscriber = registry.with(reload_filter).with(flame_layer); tracing::subscriber::set_global_default(subscriber).unwrap(); reload_handle } // This is needed for opening lots of file descriptors, which tends to // happen more often when using RocksDB and making lots of federation // connections at startup. The soft limit is usually 1024, and the hard // limit is usually 512000; I've personally seen it hit >2000. // // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 #[cfg(unix)] fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE}; let (soft_limit, hard_limit) = getrlimit(NOFILE)?; if soft_limit < hard_limit { setrlimit(NOFILE, hard_limit, hard_limit)?; assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit"); debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",); } Ok(()) }