Hot-Reloading Refactor
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
ae1a4fd283
commit
6c1434c165
212 changed files with 5679 additions and 4206 deletions
85
src/router/Cargo.toml
Normal file
85
src/router/Cargo.toml
Normal file
|
@ -0,0 +1,85 @@
|
|||
[package]
|
||||
name = "conduit_router"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "mod.rs"
|
||||
crate-type = [
|
||||
"rlib",
|
||||
# "dylib",
|
||||
]
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"systemd",
|
||||
"sentry_telemetry",
|
||||
"gzip_compression",
|
||||
"zstd_compression",
|
||||
"brotli_compression",
|
||||
"release_max_log_level",
|
||||
]
|
||||
|
||||
dev_release_log_level = []
|
||||
release_max_log_level = [
|
||||
"tracing/max_level_trace",
|
||||
"tracing/release_max_level_info",
|
||||
"log/max_level_trace",
|
||||
"log/release_max_level_info",
|
||||
]
|
||||
sentry_telemetry = [
|
||||
"dep:sentry",
|
||||
"dep:sentry-tracing",
|
||||
"dep:sentry-tower",
|
||||
]
|
||||
zstd_compression = [
|
||||
"tower-http/compression-zstd",
|
||||
]
|
||||
gzip_compression = [
|
||||
"tower-http/compression-gzip",
|
||||
]
|
||||
brotli_compression = [
|
||||
"tower-http/compression-br",
|
||||
]
|
||||
systemd = [
|
||||
"dep:sd-notify",
|
||||
]
|
||||
axum_dual_protocol = [
|
||||
"dep:axum-server-dual-protocol"
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
axum-server-dual-protocol.optional = true
|
||||
axum-server-dual-protocol.workspace = true
|
||||
axum-server.workspace = true
|
||||
axum.workspace = true
|
||||
conduit-admin.workspace = true
|
||||
conduit-api.workspace = true
|
||||
conduit-core.workspace = true
|
||||
conduit-database.workspace = true
|
||||
conduit-service.workspace = true
|
||||
log.workspace = true
|
||||
tokio.workspace = true
|
||||
tower.workspace = true
|
||||
tracing.workspace = true
|
||||
bytes.workspace = true
|
||||
clap.workspace = true
|
||||
http-body-util.workspace = true
|
||||
http.workspace = true
|
||||
regex.workspace = true
|
||||
ruma.workspace = true
|
||||
sentry.optional = true
|
||||
sentry-tower.optional = true
|
||||
sentry-tower.workspace = true
|
||||
sentry-tracing.optional = true
|
||||
sentry-tracing.workspace = true
|
||||
sentry.workspace = true
|
||||
serde_json.workspace = true
|
||||
tower-http.workspace = true
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
sd-notify.workspace = true
|
||||
sd-notify.optional = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
190
src/router/layers.rs
Normal file
190
src/router/layers.rs
Normal file
|
@ -0,0 +1,190 @@
|
|||
use std::{any::Any, io, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, MatchedPath},
|
||||
Router,
|
||||
};
|
||||
use conduit::Server;
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
HeaderValue, Method, StatusCode,
|
||||
};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
catch_panic::CatchPanicLayer,
|
||||
cors::{self, CorsLayer},
|
||||
set_header::SetResponseHeaderLayer,
|
||||
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::Level;
|
||||
|
||||
use crate::{request, router};
|
||||
|
||||
pub(crate) fn build(server: &Arc<Server>) -> io::Result<axum::routing::IntoMakeService<Router>> {
|
||||
let layers = ServiceBuilder::new();
|
||||
|
||||
#[cfg(feature = "sentry_telemetry")]
|
||||
let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
let layers = layers.layer(compression_layer(server));
|
||||
|
||||
let layers = layers
|
||||
.sensitive_headers([header::AUTHORIZATION])
|
||||
.sensitive_request_headers([HeaderName::from_static("x-forwarded-for")].into())
|
||||
.layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::spawn))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(tracing_span::<_>)
|
||||
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
|
||||
.on_request(DefaultOnRequest::new().level(Level::TRACE))
|
||||
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::handle))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
HeaderName::from_static("origin-agent-cluster"), // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
|
||||
HeaderValue::from_static("?1"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::X_CONTENT_TYPE_OPTIONS,
|
||||
HeaderValue::from_static("nosniff"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::X_XSS_PROTECTION,
|
||||
HeaderValue::from_static("0"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::X_FRAME_OPTIONS,
|
||||
HeaderValue::from_static("DENY"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
HeaderName::from_static("permissions-policy"),
|
||||
HeaderValue::from_static("interest-cohort=(),browsing-topics=()"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
HeaderValue::from_static(
|
||||
"sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \
|
||||
style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';",
|
||||
),
|
||||
))
|
||||
.layer(cors_layer(server))
|
||||
.layer(body_limit_layer(server))
|
||||
.layer(CatchPanicLayer::custom(catch_panic));
|
||||
|
||||
let routes = router::build(server);
|
||||
let layers = routes.layer(layers);
|
||||
|
||||
Ok(layers.into_make_service())
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
|
||||
let mut compression_layer = tower_http::compression::CompressionLayer::new();
|
||||
|
||||
#[cfg(feature = "zstd_compression")]
|
||||
{
|
||||
if server.config.zstd_compression {
|
||||
compression_layer = compression_layer.zstd(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_zstd();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "gzip_compression")]
|
||||
{
|
||||
if server.config.gzip_compression {
|
||||
compression_layer = compression_layer.gzip(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_gzip();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "brotli_compression")]
|
||||
{
|
||||
if server.config.brotli_compression {
|
||||
compression_layer = compression_layer.br(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_br();
|
||||
};
|
||||
};
|
||||
|
||||
compression_layer
|
||||
}
|
||||
|
||||
fn cors_layer(_server: &Server) -> CorsLayer {
|
||||
const METHODS: [Method; 7] = [
|
||||
Method::GET,
|
||||
Method::HEAD,
|
||||
Method::PATCH,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
];
|
||||
|
||||
let headers: [HeaderName; 5] = [
|
||||
header::ORIGIN,
|
||||
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
|
||||
header::CONTENT_TYPE,
|
||||
header::ACCEPT,
|
||||
header::AUTHORIZATION,
|
||||
];
|
||||
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors::Any)
|
||||
.allow_methods(METHODS)
|
||||
.allow_headers(headers)
|
||||
.max_age(Duration::from_secs(86400))
|
||||
}
|
||||
|
||||
fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
|
||||
DefaultBodyLimit::max(
|
||||
server
|
||||
.config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn catch_panic(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
|
||||
conduit_service::services()
|
||||
.server
|
||||
.requests_panic
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Release);
|
||||
|
||||
let details = if let Some(s) = err.downcast_ref::<String>() {
|
||||
s.clone()
|
||||
} else if let Some(s) = err.downcast_ref::<&str>() {
|
||||
s.to_string()
|
||||
} else {
|
||||
"Unknown internal server error occurred.".to_owned()
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "M_UNKNOWN: Internal server error occurred",
|
||||
"details": details,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
http::Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(http_body_util::Full::from(body))
|
||||
.expect("Failed to create response for our panic catcher?")
|
||||
}
|
||||
|
||||
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
|
||||
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
||||
path.as_str()
|
||||
} else {
|
||||
request.uri().path()
|
||||
};
|
||||
|
||||
tracing::info_span!("router:", %path)
|
||||
}
|
|
@ -1,258 +1,29 @@
|
|||
use std::{any::Any, io, sync::atomic, time::Duration};
|
||||
pub(crate) mod layers;
|
||||
pub(crate) mod request;
|
||||
pub(crate) mod router;
|
||||
pub(crate) mod run;
|
||||
pub(crate) mod serve;
|
||||
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, MatchedPath},
|
||||
response::IntoResponse,
|
||||
Router,
|
||||
};
|
||||
use http::{
|
||||
header::{self, HeaderName, HeaderValue},
|
||||
Method, StatusCode, Uri,
|
||||
};
|
||||
use ruma::api::client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
uiaa::UiaaResponse,
|
||||
};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
catch_panic::CatchPanicLayer,
|
||||
cors::{self, CorsLayer},
|
||||
set_header::SetResponseHeaderLayer,
|
||||
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::{debug, error, trace, Level};
|
||||
extern crate conduit_core as conduit;
|
||||
|
||||
use super::{api::ruma_wrapper::RumaResponse, debug_error, services, utils::error::Result, Server};
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
mod routes;
|
||||
use conduit::{Result, Server};
|
||||
|
||||
pub(crate) async fn build(server: &Server) -> io::Result<axum::routing::IntoMakeService<Router>> {
|
||||
let base_middlewares = ServiceBuilder::new();
|
||||
#[cfg(feature = "sentry_telemetry")]
|
||||
let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
|
||||
conduit::mod_ctor! {}
|
||||
conduit::mod_dtor! {}
|
||||
|
||||
let x_forwarded_for = HeaderName::from_static("x-forwarded-for");
|
||||
let permissions_policy = HeaderName::from_static("permissions-policy");
|
||||
let origin_agent_cluster = HeaderName::from_static("origin-agent-cluster"); // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
|
||||
|
||||
let middlewares = base_middlewares
|
||||
.sensitive_headers([header::AUTHORIZATION])
|
||||
.sensitive_request_headers([x_forwarded_for].into())
|
||||
.layer(axum::middleware::from_fn(request_spawn))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(tracing_span::<_>)
|
||||
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
|
||||
.on_request(DefaultOnRequest::new().level(Level::TRACE))
|
||||
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
|
||||
)
|
||||
.layer(axum::middleware::from_fn(request_handle))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
origin_agent_cluster,
|
||||
HeaderValue::from_static("?1"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::X_CONTENT_TYPE_OPTIONS,
|
||||
HeaderValue::from_static("nosniff"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::X_XSS_PROTECTION,
|
||||
HeaderValue::from_static("0"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::X_FRAME_OPTIONS,
|
||||
HeaderValue::from_static("DENY"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
permissions_policy,
|
||||
HeaderValue::from_static("interest-cohort=(),browsing-topics=()"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
HeaderValue::from_static(
|
||||
"sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \
|
||||
style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';",
|
||||
),
|
||||
))
|
||||
.layer(cors_layer(server))
|
||||
.layer(DefaultBodyLimit::max(
|
||||
server
|
||||
.config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
))
|
||||
.layer(CatchPanicLayer::custom(catch_panic_layer));
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
{
|
||||
Ok(routes::routes(&server.config)
|
||||
.layer(compression_layer(server))
|
||||
.layer(middlewares)
|
||||
.into_make_service())
|
||||
}
|
||||
#[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))]
|
||||
{
|
||||
Ok(routes::routes().layer(middlewares).into_make_service())
|
||||
}
|
||||
#[no_mangle]
|
||||
pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> {
|
||||
Box::pin(run::start(server.clone()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "spawn")]
|
||||
async fn request_spawn(
|
||||
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
let fut = next.run(req);
|
||||
let task = tokio::spawn(fut);
|
||||
task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
#[no_mangle]
|
||||
pub extern "Rust" fn stop(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> {
|
||||
Box::pin(run::stop(server.clone()))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "handle")]
|
||||
async fn request_handle(
|
||||
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let result = next.run(req).await;
|
||||
request_result(&method, &uri, result)
|
||||
}
|
||||
|
||||
fn request_result(
|
||||
method: &Method, uri: &Uri, result: axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
request_result_log(method, uri, &result);
|
||||
match result.status() {
|
||||
StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result),
|
||||
_ => Ok(result),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn request_result_403(
|
||||
_method: &Method, _uri: &Uri, result: &axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let error = UiaaResponse::MatrixError(RumaError {
|
||||
status_code: result.status(),
|
||||
body: ErrorBody::Standard {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(RumaResponse(error).into_response())
|
||||
}
|
||||
|
||||
fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {
|
||||
let status = result.status();
|
||||
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
|
||||
let code = status.as_u16();
|
||||
if status.is_server_error() {
|
||||
error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_client_error() {
|
||||
debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_redirection() {
|
||||
debug!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else {
|
||||
trace!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Cross-Origin-Resource-Sharing header as defined by spec:
|
||||
/// <https://spec.matrix.org/latest/client-server-api/#web-browser-clients>
|
||||
fn cors_layer(_server: &Server) -> CorsLayer {
|
||||
const METHODS: [Method; 7] = [
|
||||
Method::GET,
|
||||
Method::HEAD,
|
||||
Method::PATCH,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
];
|
||||
|
||||
let headers: [HeaderName; 5] = [
|
||||
header::ORIGIN,
|
||||
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
|
||||
header::CONTENT_TYPE,
|
||||
header::ACCEPT,
|
||||
header::AUTHORIZATION,
|
||||
];
|
||||
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors::Any)
|
||||
.allow_methods(METHODS)
|
||||
.allow_headers(headers)
|
||||
.max_age(Duration::from_secs(86400))
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
|
||||
let mut compression_layer = tower_http::compression::CompressionLayer::new();
|
||||
|
||||
#[cfg(feature = "zstd_compression")]
|
||||
{
|
||||
if server.config.zstd_compression {
|
||||
compression_layer = compression_layer.zstd(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_zstd();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "gzip_compression")]
|
||||
{
|
||||
if server.config.gzip_compression {
|
||||
compression_layer = compression_layer.gzip(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_gzip();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "brotli_compression")]
|
||||
{
|
||||
if server.config.brotli_compression {
|
||||
compression_layer = compression_layer.br(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_br();
|
||||
};
|
||||
};
|
||||
|
||||
compression_layer
|
||||
}
|
||||
|
||||
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
|
||||
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
||||
path.as_str()
|
||||
} else {
|
||||
request.uri().path()
|
||||
};
|
||||
|
||||
tracing::info_span!("router:", %path)
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn catch_panic_layer(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
|
||||
let details = if let Some(s) = err.downcast_ref::<String>() {
|
||||
s.clone()
|
||||
} else if let Some(s) = err.downcast_ref::<&str>() {
|
||||
s.to_string()
|
||||
} else {
|
||||
"Unknown internal server error occurred.".to_owned()
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "M_UNKNOWN: Internal server error occurred",
|
||||
"details": details,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
http::Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(http_body_util::Full::from(body))
|
||||
.expect("Failed to create response for our panic catcher?")
|
||||
#[no_mangle]
|
||||
pub extern "Rust" fn run(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> {
|
||||
Box::pin(run::run(server.clone()))
|
||||
}
|
||||
|
|
102
src/router/request.rs
Normal file
102
src/router/request.rs
Normal file
|
@ -0,0 +1,102 @@
|
|||
use std::sync::{atomic::Ordering, Arc};
|
||||
|
||||
use axum::{extract::State, response::IntoResponse};
|
||||
use conduit::{debug_error, debug_warn, defer, Result, RumaResponse, Server};
|
||||
use http::{Method, StatusCode, Uri};
|
||||
use ruma::api::client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
uiaa::UiaaResponse,
|
||||
};
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn spawn(
|
||||
State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
if server.interrupt.load(Ordering::Relaxed) {
|
||||
debug_warn!("unavailable pending shutdown");
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
let active = server.requests_spawn_active.fetch_add(1, Ordering::Relaxed);
|
||||
trace!(active, "enter");
|
||||
defer! {{
|
||||
let active = server.requests_spawn_active.fetch_sub(1, Ordering::Relaxed);
|
||||
let finished = server.requests_spawn_finished.fetch_add(1, Ordering::Relaxed);
|
||||
trace!(active, finished, "leave");
|
||||
}};
|
||||
|
||||
let fut = next.run(req);
|
||||
let task = server.runtime().spawn(fut);
|
||||
task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "handle")]
|
||||
pub(crate) async fn handle(
|
||||
State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
if server.interrupt.load(Ordering::Relaxed) {
|
||||
debug_warn!(
|
||||
method = %req.method(),
|
||||
uri = %req.uri(),
|
||||
"unavailable pending shutdown"
|
||||
);
|
||||
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
let active = server
|
||||
.requests_handle_active
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
trace!(active, "enter");
|
||||
defer! {{
|
||||
let active = server.requests_handle_active.fetch_sub(1, Ordering::Relaxed);
|
||||
let finished = server.requests_handle_finished.fetch_add(1, Ordering::Relaxed);
|
||||
trace!(active, finished, "leave");
|
||||
}};
|
||||
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let result = next.run(req).await;
|
||||
handle_result(&method, &uri, result)
|
||||
}
|
||||
|
||||
fn handle_result(
|
||||
method: &Method, uri: &Uri, result: axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
handle_result_log(method, uri, &result);
|
||||
match result.status() {
|
||||
StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result),
|
||||
_ => Ok(result),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn handle_result_403(
|
||||
_method: &Method, _uri: &Uri, result: &axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let error = UiaaResponse::MatrixError(RumaError {
|
||||
status_code: result.status(),
|
||||
body: ErrorBody::Standard {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(RumaResponse(error).into_response())
|
||||
}
|
||||
|
||||
fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {
|
||||
let status = result.status();
|
||||
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
|
||||
let code = status.as_u16();
|
||||
if status.is_server_error() {
|
||||
error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_client_error() {
|
||||
debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_redirection() {
|
||||
debug!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else {
|
||||
trace!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
}
|
||||
}
|
20
src/router/router.rs
Normal file
20
src/router/router.rs
Normal file
|
@ -0,0 +1,20 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::{response::IntoResponse, routing::get, Router};
|
||||
use conduit::{Error, Server};
|
||||
use http::Uri;
|
||||
use ruma::api::client::error::ErrorKind;
|
||||
|
||||
extern crate conduit_api as api;
|
||||
|
||||
pub(crate) fn build(server: &Arc<Server>) -> Router {
|
||||
let router = Router::new().fallback(not_found).route("/", get(it_works));
|
||||
|
||||
api::router::build(router, server)
|
||||
}
|
||||
|
||||
async fn not_found(_uri: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
|
||||
}
|
||||
|
||||
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }
|
|
@ -1,322 +0,0 @@
|
|||
use std::future::Future;
|
||||
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
response::IntoResponse,
|
||||
routing::{any, get, on, post, MethodFilter},
|
||||
Router,
|
||||
};
|
||||
use http::{Method, Uri};
|
||||
use ruma::api::{client::error::ErrorKind, IncomingRequest};
|
||||
|
||||
use crate::{
|
||||
api::{client_server, server_server},
|
||||
Config, Error, Result, Ruma, RumaResponse,
|
||||
};
|
||||
|
||||
pub(crate) fn routes(config: &Config) -> Router {
|
||||
let router = Router::new()
|
||||
.ruma_route(client_server::get_supported_versions_route)
|
||||
.ruma_route(client_server::get_register_available_route)
|
||||
.ruma_route(client_server::register_route)
|
||||
.ruma_route(client_server::get_login_types_route)
|
||||
.ruma_route(client_server::login_route)
|
||||
.ruma_route(client_server::whoami_route)
|
||||
.ruma_route(client_server::logout_route)
|
||||
.ruma_route(client_server::logout_all_route)
|
||||
.ruma_route(client_server::change_password_route)
|
||||
.ruma_route(client_server::deactivate_route)
|
||||
.ruma_route(client_server::third_party_route)
|
||||
.ruma_route(client_server::request_3pid_management_token_via_email_route)
|
||||
.ruma_route(client_server::request_3pid_management_token_via_msisdn_route)
|
||||
.ruma_route(client_server::get_capabilities_route)
|
||||
.ruma_route(client_server::get_pushrules_all_route)
|
||||
.ruma_route(client_server::set_pushrule_route)
|
||||
.ruma_route(client_server::get_pushrule_route)
|
||||
.ruma_route(client_server::set_pushrule_enabled_route)
|
||||
.ruma_route(client_server::get_pushrule_enabled_route)
|
||||
.ruma_route(client_server::get_pushrule_actions_route)
|
||||
.ruma_route(client_server::set_pushrule_actions_route)
|
||||
.ruma_route(client_server::delete_pushrule_route)
|
||||
.ruma_route(client_server::get_room_event_route)
|
||||
.ruma_route(client_server::get_room_aliases_route)
|
||||
.ruma_route(client_server::get_filter_route)
|
||||
.ruma_route(client_server::create_filter_route)
|
||||
.ruma_route(client_server::set_global_account_data_route)
|
||||
.ruma_route(client_server::set_room_account_data_route)
|
||||
.ruma_route(client_server::get_global_account_data_route)
|
||||
.ruma_route(client_server::get_room_account_data_route)
|
||||
.ruma_route(client_server::set_displayname_route)
|
||||
.ruma_route(client_server::get_displayname_route)
|
||||
.ruma_route(client_server::set_avatar_url_route)
|
||||
.ruma_route(client_server::get_avatar_url_route)
|
||||
.ruma_route(client_server::get_profile_route)
|
||||
.ruma_route(client_server::set_presence_route)
|
||||
.ruma_route(client_server::get_presence_route)
|
||||
.ruma_route(client_server::upload_keys_route)
|
||||
.ruma_route(client_server::get_keys_route)
|
||||
.ruma_route(client_server::claim_keys_route)
|
||||
.ruma_route(client_server::create_backup_version_route)
|
||||
.ruma_route(client_server::update_backup_version_route)
|
||||
.ruma_route(client_server::delete_backup_version_route)
|
||||
.ruma_route(client_server::get_latest_backup_info_route)
|
||||
.ruma_route(client_server::get_backup_info_route)
|
||||
.ruma_route(client_server::add_backup_keys_route)
|
||||
.ruma_route(client_server::add_backup_keys_for_room_route)
|
||||
.ruma_route(client_server::add_backup_keys_for_session_route)
|
||||
.ruma_route(client_server::delete_backup_keys_for_room_route)
|
||||
.ruma_route(client_server::delete_backup_keys_for_session_route)
|
||||
.ruma_route(client_server::delete_backup_keys_route)
|
||||
.ruma_route(client_server::get_backup_keys_for_room_route)
|
||||
.ruma_route(client_server::get_backup_keys_for_session_route)
|
||||
.ruma_route(client_server::get_backup_keys_route)
|
||||
.ruma_route(client_server::set_read_marker_route)
|
||||
.ruma_route(client_server::create_receipt_route)
|
||||
.ruma_route(client_server::create_typing_event_route)
|
||||
.ruma_route(client_server::create_room_route)
|
||||
.ruma_route(client_server::redact_event_route)
|
||||
.ruma_route(client_server::report_event_route)
|
||||
.ruma_route(client_server::create_alias_route)
|
||||
.ruma_route(client_server::delete_alias_route)
|
||||
.ruma_route(client_server::get_alias_route)
|
||||
.ruma_route(client_server::join_room_by_id_route)
|
||||
.ruma_route(client_server::join_room_by_id_or_alias_route)
|
||||
.ruma_route(client_server::joined_members_route)
|
||||
.ruma_route(client_server::leave_room_route)
|
||||
.ruma_route(client_server::forget_room_route)
|
||||
.ruma_route(client_server::joined_rooms_route)
|
||||
.ruma_route(client_server::kick_user_route)
|
||||
.ruma_route(client_server::ban_user_route)
|
||||
.ruma_route(client_server::unban_user_route)
|
||||
.ruma_route(client_server::invite_user_route)
|
||||
.ruma_route(client_server::set_room_visibility_route)
|
||||
.ruma_route(client_server::get_room_visibility_route)
|
||||
.ruma_route(client_server::get_public_rooms_route)
|
||||
.ruma_route(client_server::get_public_rooms_filtered_route)
|
||||
.ruma_route(client_server::search_users_route)
|
||||
.ruma_route(client_server::get_member_events_route)
|
||||
.ruma_route(client_server::get_protocols_route)
|
||||
.ruma_route(client_server::send_message_event_route)
|
||||
.ruma_route(client_server::send_state_event_for_key_route)
|
||||
.ruma_route(client_server::get_state_events_route)
|
||||
.ruma_route(client_server::get_state_events_for_key_route)
|
||||
// Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes
|
||||
// share one Ruma request / response type pair with {get,send}_state_event_for_key_route
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/state/:event_type",
|
||||
get(client_server::get_state_events_for_empty_key_route)
|
||||
.put(client_server::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/state/:event_type",
|
||||
get(client_server::get_state_events_for_empty_key_route)
|
||||
.put(client_server::send_state_event_for_empty_key_route),
|
||||
)
|
||||
// These two endpoints allow trailing slashes
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/state/:event_type/",
|
||||
get(client_server::get_state_events_for_empty_key_route)
|
||||
.put(client_server::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/state/:event_type/",
|
||||
get(client_server::get_state_events_for_empty_key_route)
|
||||
.put(client_server::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.ruma_route(client_server::sync_events_route)
|
||||
.ruma_route(client_server::sync_events_v4_route)
|
||||
.ruma_route(client_server::get_context_route)
|
||||
.ruma_route(client_server::get_message_events_route)
|
||||
.ruma_route(client_server::search_events_route)
|
||||
.ruma_route(client_server::turn_server_route)
|
||||
.ruma_route(client_server::send_event_to_device_route)
|
||||
.ruma_route(client_server::get_media_config_route)
|
||||
.ruma_route(client_server::get_media_preview_route)
|
||||
.ruma_route(client_server::create_content_route)
|
||||
// legacy v1 media routes
|
||||
.route(
|
||||
"/_matrix/media/v1/preview_url",
|
||||
get(client_server::get_media_preview_v1_route)
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v1/config",
|
||||
get(client_server::get_media_config_v1_route)
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v1/upload",
|
||||
post(client_server::create_content_v1_route)
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v1/download/:server_name/:media_id",
|
||||
get(client_server::get_content_v1_route)
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v1/download/:server_name/:media_id/:file_name",
|
||||
get(client_server::get_content_as_filename_v1_route)
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v1/thumbnail/:server_name/:media_id",
|
||||
get(client_server::get_content_thumbnail_v1_route)
|
||||
)
|
||||
.ruma_route(client_server::get_content_route)
|
||||
.ruma_route(client_server::get_content_as_filename_route)
|
||||
.ruma_route(client_server::get_content_thumbnail_route)
|
||||
.ruma_route(client_server::get_devices_route)
|
||||
.ruma_route(client_server::get_device_route)
|
||||
.ruma_route(client_server::update_device_route)
|
||||
.ruma_route(client_server::delete_device_route)
|
||||
.ruma_route(client_server::delete_devices_route)
|
||||
.ruma_route(client_server::get_tags_route)
|
||||
.ruma_route(client_server::update_tag_route)
|
||||
.ruma_route(client_server::delete_tag_route)
|
||||
.ruma_route(client_server::upload_signing_keys_route)
|
||||
.ruma_route(client_server::upload_signatures_route)
|
||||
.ruma_route(client_server::get_key_changes_route)
|
||||
.ruma_route(client_server::get_pushers_route)
|
||||
.ruma_route(client_server::set_pushers_route)
|
||||
// .ruma_route(client_server::third_party_route)
|
||||
.ruma_route(client_server::upgrade_room_route)
|
||||
.ruma_route(client_server::get_threads_route)
|
||||
.ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route)
|
||||
.ruma_route(client_server::get_relating_events_with_rel_type_route)
|
||||
.ruma_route(client_server::get_relating_events_route)
|
||||
.ruma_route(client_server::get_hierarchy_route)
|
||||
.ruma_route(client_server::get_mutual_rooms_route)
|
||||
.ruma_route(client_server::well_known_support)
|
||||
.ruma_route(client_server::well_known_client)
|
||||
.route("/_conduwuit/server_version", get(client_server::conduwuit_server_version))
|
||||
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
|
||||
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
|
||||
.route("/client/server.json", get(client_server::syncv3_client_server_json))
|
||||
.route("/", get(it_works))
|
||||
.fallback(not_found);
|
||||
|
||||
if config.allow_federation {
|
||||
router
|
||||
.ruma_route(server_server::get_server_version_route)
|
||||
.route("/_matrix/key/v2/server", get(server_server::get_server_keys_route))
|
||||
.route(
|
||||
"/_matrix/key/v2/server/:key_id",
|
||||
get(server_server::get_server_keys_deprecated_route),
|
||||
)
|
||||
.ruma_route(server_server::get_public_rooms_route)
|
||||
.ruma_route(server_server::get_public_rooms_filtered_route)
|
||||
.ruma_route(server_server::send_transaction_message_route)
|
||||
.ruma_route(server_server::get_event_route)
|
||||
.ruma_route(server_server::get_backfill_route)
|
||||
.ruma_route(server_server::get_missing_events_route)
|
||||
.ruma_route(server_server::get_event_authorization_route)
|
||||
.ruma_route(server_server::get_room_state_route)
|
||||
.ruma_route(server_server::get_room_state_ids_route)
|
||||
.ruma_route(server_server::create_leave_event_template_route)
|
||||
.ruma_route(server_server::create_leave_event_v1_route)
|
||||
.ruma_route(server_server::create_leave_event_v2_route)
|
||||
.ruma_route(server_server::create_join_event_template_route)
|
||||
.ruma_route(server_server::create_join_event_v1_route)
|
||||
.ruma_route(server_server::create_join_event_v2_route)
|
||||
.ruma_route(server_server::create_invite_route)
|
||||
.ruma_route(server_server::get_devices_route)
|
||||
.ruma_route(server_server::get_room_information_route)
|
||||
.ruma_route(server_server::get_profile_information_route)
|
||||
.ruma_route(server_server::get_keys_route)
|
||||
.ruma_route(server_server::claim_keys_route)
|
||||
.ruma_route(server_server::get_hierarchy_route)
|
||||
.ruma_route(server_server::well_known_server)
|
||||
} else {
|
||||
router
|
||||
.route("/_matrix/federation/*path", any(federation_disabled))
|
||||
.route("/.well-known/matrix/server", any(federation_disabled))
|
||||
.route("/_matrix/key/*path", any(federation_disabled))
|
||||
}
|
||||
}
|
||||
|
||||
async fn not_found(_uri: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
|
||||
}
|
||||
|
||||
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented")
|
||||
}
|
||||
|
||||
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }
|
||||
|
||||
async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") }
|
||||
|
||||
trait RouterExt {
|
||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: RumaHandler<T>,
|
||||
T: 'static;
|
||||
}
|
||||
|
||||
impl RouterExt for Router {
|
||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: RumaHandler<T>,
|
||||
T: 'static,
|
||||
{
|
||||
handler.add_to_router(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait RumaHandler<T> {
|
||||
// Can't transform to a handler without boxing or relying on the nightly-only
|
||||
// impl-trait-in-traits feature. Moving a small amount of extra logic into the
|
||||
// trait allows bypassing both.
|
||||
fn add_to_router(self, router: Router) -> Router;
|
||||
}
|
||||
|
||||
macro_rules! impl_ruma_handler {
|
||||
( $($ty:ident),* $(,)? ) => {
|
||||
#[axum::async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<Req, E, F, Fut, $($ty,)*> RumaHandler<($($ty,)* Ruma<Req>,)> for F
|
||||
where
|
||||
Req: IncomingRequest + Send + 'static,
|
||||
F: FnOnce($($ty,)* Ruma<Req>) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = Result<Req::OutgoingResponse, E>>
|
||||
+ Send,
|
||||
E: IntoResponse,
|
||||
$( $ty: FromRequestParts<()> + Send + 'static, )*
|
||||
{
|
||||
fn add_to_router(self, mut router: Router) -> Router {
|
||||
let meta = Req::METADATA;
|
||||
let method_filter = method_to_filter(meta.method);
|
||||
|
||||
for path in meta.history.all_paths() {
|
||||
let handler = self.clone();
|
||||
|
||||
router = router.route(path, on(method_filter, |$( $ty: $ty, )* req| async move {
|
||||
handler($($ty,)* req).await.map(RumaResponse)
|
||||
}))
|
||||
}
|
||||
|
||||
router
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_ruma_handler!();
|
||||
impl_ruma_handler!(T1);
|
||||
impl_ruma_handler!(T1, T2);
|
||||
impl_ruma_handler!(T1, T2, T3);
|
||||
impl_ruma_handler!(T1, T2, T3, T4);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
|
||||
|
||||
fn method_to_filter(method: Method) -> MethodFilter {
|
||||
match method {
|
||||
Method::DELETE => MethodFilter::DELETE,
|
||||
Method::GET => MethodFilter::GET,
|
||||
Method::HEAD => MethodFilter::HEAD,
|
||||
Method::OPTIONS => MethodFilter::OPTIONS,
|
||||
Method::PATCH => MethodFilter::PATCH,
|
||||
Method::POST => MethodFilter::POST,
|
||||
Method::PUT => MethodFilter::PUT,
|
||||
Method::TRACE => MethodFilter::TRACE,
|
||||
m => panic!("Unsupported HTTP method: {m:?}"),
|
||||
}
|
||||
}
|
185
src/router/run.rs
Normal file
185
src/router/run.rs
Normal file
|
@ -0,0 +1,185 @@
|
|||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum_server::Handle as ServerHandle;
|
||||
use tokio::{
|
||||
signal,
|
||||
sync::oneshot::{self, Sender},
|
||||
};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
extern crate conduit_admin as admin;
|
||||
extern crate conduit_core as conduit;
|
||||
extern crate conduit_database as database;
|
||||
extern crate conduit_service as service;
|
||||
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use conduit::{debug_info, trace, Error, Result, Server};
|
||||
use database::KeyValueDatabase;
|
||||
use service::{services, Services};
|
||||
|
||||
use crate::{layers, serve};
|
||||
|
||||
/// Main loop base
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> {
|
||||
let config = &server.config;
|
||||
let app = layers::build(&server)?;
|
||||
let addrs = config.get_bind_addrs();
|
||||
|
||||
// Install the admin room callback here for now
|
||||
_ = services().admin.handle.lock().await.insert(admin::handle);
|
||||
|
||||
// Setup shutdown/signal handling
|
||||
let handle = ServerHandle::new();
|
||||
_ = server
|
||||
.shutdown
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.insert(handle.clone());
|
||||
|
||||
server.interrupt.store(false, Ordering::Release);
|
||||
let (tx, rx) = oneshot::channel::<()>();
|
||||
let sigs = server.runtime().spawn(sighandle(server.clone(), tx));
|
||||
|
||||
// Prepare to serve http clients
|
||||
let res;
|
||||
// Serve clients
|
||||
if cfg!(unix) && config.unix_socket_path.is_some() {
|
||||
res = serve::unix_socket(&server, app, rx).await;
|
||||
} else if config.tls.is_some() {
|
||||
res = serve::tls(&server, app, handle.clone(), addrs).await;
|
||||
} else {
|
||||
res = serve::plain(&server, app, handle.clone(), addrs).await;
|
||||
}
|
||||
|
||||
// Join the signal handler before we leave.
|
||||
sigs.abort();
|
||||
_ = sigs.await;
|
||||
|
||||
// Reset the axum handle instance; this should be reusable and might be
|
||||
// reload-survivable but better to be safe than sorry.
|
||||
_ = server.shutdown.lock().expect("locked").take();
|
||||
|
||||
// Remove the admin room callback
|
||||
_ = services().admin.handle.lock().await.take();
|
||||
|
||||
debug_info!("Finished");
|
||||
Ok(res?)
|
||||
}
|
||||
|
||||
/// Async initializations
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn start(server: Arc<Server>) -> Result<(), Error> {
|
||||
debug!("Starting...");
|
||||
let d = Arc::new(KeyValueDatabase::load_or_create(&server).await?);
|
||||
let s = Box::new(Services::build(server, d.clone()).await?);
|
||||
_ = service::SERVICES
|
||||
.write()
|
||||
.expect("write locked")
|
||||
.insert(Box::leak(s));
|
||||
services().start().await?;
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
|
||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
||||
|
||||
debug!("Started");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Async destructions
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn stop(_server: Arc<Server>) -> Result<(), Error> {
|
||||
debug!("Shutting down...");
|
||||
|
||||
// Wait for all completions before dropping or we'll lose them to the module
|
||||
// unload and explode.
|
||||
services().shutdown().await;
|
||||
|
||||
// Deactivate services(). Any further use will panic the caller.
|
||||
let s = service::SERVICES
|
||||
.write()
|
||||
.expect("write locked")
|
||||
.take()
|
||||
.unwrap();
|
||||
|
||||
let s = std::ptr::from_ref(s) as *mut Services;
|
||||
//SAFETY: Services was instantiated in start() and leaked into the SERVICES
|
||||
// global perusing as 'static for the duration of run_server(). Now we reclaim
|
||||
// it to drop it before unloading the module. If this is not done there will be
|
||||
// multiple instances after module reload.
|
||||
let s = unsafe { Box::from_raw(s) };
|
||||
debug!("Cleaning up...");
|
||||
// Drop it so we encounter any trouble before the infolog message
|
||||
drop(s);
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
|
||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);
|
||||
|
||||
info!("Shutdown complete.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn sighandle(server: Arc<Server>, tx: Sender<()>) -> Result<(), Error> {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
|
||||
let reload = cfg!(unix) && cfg!(debug_assertions);
|
||||
server.reload.store(reload, Ordering::Release);
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let ctrl_bs = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::quit())
|
||||
.expect("failed to install Ctrl+\\ handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install SIGTERM handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
debug!("Installed signal handlers");
|
||||
let sig: &str;
|
||||
#[cfg(unix)]
|
||||
tokio::select! {
|
||||
() = ctrl_c => { sig = "Ctrl+C"; },
|
||||
() = ctrl_bs => { sig = "Ctrl+\\"; },
|
||||
() = terminate => { sig = "SIGTERM"; },
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
tokio::select! {
|
||||
_ = ctrl_c => { sig = "Ctrl+C"; },
|
||||
}
|
||||
|
||||
warn!("Received {}", sig);
|
||||
server.interrupt.store(true, Ordering::Release);
|
||||
services().globals.rotate.fire();
|
||||
tx.send(())
|
||||
.expect("failed sending shutdown transaction to oneshot channel");
|
||||
|
||||
if let Some(handle) = server.shutdown.lock().expect("locked").as_ref() {
|
||||
let pending = server.requests_spawn_active.load(Ordering::Relaxed);
|
||||
if pending > 0 {
|
||||
let timeout = Duration::from_secs(36);
|
||||
trace!(pending, ?timeout, "Notifying for graceful shutdown");
|
||||
handle.graceful_shutdown(Some(timeout));
|
||||
} else {
|
||||
debug!(pending, "Notifying for immediate shutdown");
|
||||
handle.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
137
src/router/serve.rs
Normal file
137
src/router/serve.rs
Normal file
|
@ -0,0 +1,137 @@
|
|||
#[cfg(unix)]
|
||||
use std::fs::Permissions; // only for UNIX sockets stuff and *nix container checks
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt as _;
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{atomic::Ordering, Arc},
|
||||
};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
||||
#[cfg(feature = "axum_dual_protocol")]
|
||||
use axum_server_dual_protocol::ServerExt;
|
||||
use conduit::{debug_info, Server};
|
||||
use tokio::{
|
||||
sync::oneshot::{self},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
pub(crate) async fn plain(
|
||||
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
|
||||
) -> io::Result<()> {
|
||||
let mut join_set = JoinSet::new();
|
||||
for addr in &addrs {
|
||||
join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime());
|
||||
}
|
||||
|
||||
info!("Listening on {addrs:?}");
|
||||
while join_set.join_next().await.is_some() {}
|
||||
|
||||
let spawn_active = server.requests_spawn_active.load(Ordering::Relaxed);
|
||||
let handle_active = server.requests_handle_active.load(Ordering::Relaxed);
|
||||
debug_info!(
|
||||
spawn_finished = server.requests_spawn_finished.load(Ordering::Relaxed),
|
||||
handle_finished = server.requests_handle_finished.load(Ordering::Relaxed),
|
||||
panics = server.requests_panic.load(Ordering::Relaxed),
|
||||
spawn_active,
|
||||
handle_active,
|
||||
"Stopped listening on {addrs:?}",
|
||||
);
|
||||
|
||||
debug_assert!(spawn_active == 0, "active request tasks are not joined");
|
||||
debug_assert!(handle_active == 0, "active request handles still pending");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn tls(
|
||||
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
|
||||
) -> io::Result<()> {
|
||||
let config = &server.config;
|
||||
let tls = config.tls.as_ref().expect("TLS configuration");
|
||||
|
||||
debug!(
|
||||
"Using direct TLS. Certificate path {} and certificate private key path {}",
|
||||
&tls.certs, &tls.key
|
||||
);
|
||||
info!(
|
||||
"Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS."
|
||||
);
|
||||
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
|
||||
|
||||
if cfg!(feature = "axum_dual_protocol") {
|
||||
info!(
|
||||
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \
|
||||
take effect if `dual_protocol` is enabled in `[global.tls]`"
|
||||
);
|
||||
}
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
|
||||
#[cfg(feature = "axum_dual_protocol")]
|
||||
for addr in &addrs {
|
||||
join_set.spawn_on(
|
||||
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
|
||||
.set_upgrade(false)
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for addr in &addrs {
|
||||
join_set.spawn_on(
|
||||
bind_rustls(*addr, conf.clone())
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
|
||||
warn!(
|
||||
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)",
|
||||
addrs, &tls.certs
|
||||
);
|
||||
} else {
|
||||
info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs);
|
||||
}
|
||||
|
||||
while join_set.join_next().await.is_some() {}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) async fn unix_socket(
|
||||
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, rx: oneshot::Receiver<()>,
|
||||
) -> io::Result<()> {
|
||||
let config = &server.config;
|
||||
let path = config.unix_socket_path.as_ref().unwrap();
|
||||
|
||||
if path.exists() {
|
||||
warn!(
|
||||
"UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.",
|
||||
path.display()
|
||||
);
|
||||
tokio::fs::remove_file(&path).await?;
|
||||
}
|
||||
|
||||
tokio::fs::create_dir_all(path.parent().unwrap()).await?;
|
||||
|
||||
let socket_perms = config.unix_socket_perms.to_string();
|
||||
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
|
||||
tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let bind = tokio::net::UnixListener::bind(path)?;
|
||||
info!("Listening at {:?}", path);
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue