Hot-Reloading Refactor

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-09 15:59:08 -07:00 committed by June 🍓🦴
parent ae1a4fd283
commit 6c1434c165
212 changed files with 5679 additions and 4206 deletions

85
src/router/Cargo.toml Normal file
View file

@ -0,0 +1,85 @@
[package]
name = "conduit_router"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"systemd",
"sentry_telemetry",
"gzip_compression",
"zstd_compression",
"brotli_compression",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sentry_telemetry = [
"dep:sentry",
"dep:sentry-tracing",
"dep:sentry-tower",
]
zstd_compression = [
"tower-http/compression-zstd",
]
gzip_compression = [
"tower-http/compression-gzip",
]
brotli_compression = [
"tower-http/compression-br",
]
systemd = [
"dep:sd-notify",
]
axum_dual_protocol = [
"dep:axum-server-dual-protocol"
]
[dependencies]
axum-server-dual-protocol.optional = true
axum-server-dual-protocol.workspace = true
axum-server.workspace = true
axum.workspace = true
conduit-admin.workspace = true
conduit-api.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
conduit-service.workspace = true
log.workspace = true
tokio.workspace = true
tower.workspace = true
tracing.workspace = true
bytes.workspace = true
clap.workspace = true
http-body-util.workspace = true
http.workspace = true
regex.workspace = true
ruma.workspace = true
sentry.optional = true
sentry-tower.optional = true
sentry-tower.workspace = true
sentry-tracing.optional = true
sentry-tracing.workspace = true
sentry.workspace = true
serde_json.workspace = true
tower-http.workspace = true
[target.'cfg(unix)'.dependencies]
sd-notify.workspace = true
sd-notify.optional = true
[lints]
workspace = true

190
src/router/layers.rs Normal file
View file

@ -0,0 +1,190 @@
use std::{any::Any, io, sync::Arc, time::Duration};
use axum::{
extract::{DefaultBodyLimit, MatchedPath},
Router,
};
use conduit::Server;
use http::{
header::{self, HeaderName},
HeaderValue, Method, StatusCode,
};
use tower::ServiceBuilder;
use tower_http::{
catch_panic::CatchPanicLayer,
cors::{self, CorsLayer},
set_header::SetResponseHeaderLayer,
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
ServiceBuilderExt as _,
};
use tracing::Level;
use crate::{request, router};
pub(crate) fn build(server: &Arc<Server>) -> io::Result<axum::routing::IntoMakeService<Router>> {
let layers = ServiceBuilder::new();
#[cfg(feature = "sentry_telemetry")]
let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
let layers = layers.layer(compression_layer(server));
let layers = layers
.sensitive_headers([header::AUTHORIZATION])
.sensitive_request_headers([HeaderName::from_static("x-forwarded-for")].into())
.layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::spawn))
.layer(
TraceLayer::new_for_http()
.make_span_with(tracing_span::<_>)
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
.on_request(DefaultOnRequest::new().level(Level::TRACE))
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
)
.layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::handle))
.layer(SetResponseHeaderLayer::if_not_present(
HeaderName::from_static("origin-agent-cluster"), // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
HeaderValue::from_static("?1"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_XSS_PROTECTION,
HeaderValue::from_static("0"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_FRAME_OPTIONS,
HeaderValue::from_static("DENY"),
))
.layer(SetResponseHeaderLayer::if_not_present(
HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("interest-cohort=(),browsing-topics=()"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static(
"sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \
style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';",
),
))
.layer(cors_layer(server))
.layer(body_limit_layer(server))
.layer(CatchPanicLayer::custom(catch_panic));
let routes = router::build(server);
let layers = routes.layer(layers);
Ok(layers.into_make_service())
}
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
let mut compression_layer = tower_http::compression::CompressionLayer::new();
#[cfg(feature = "zstd_compression")]
{
if server.config.zstd_compression {
compression_layer = compression_layer.zstd(true);
} else {
compression_layer = compression_layer.no_zstd();
};
};
#[cfg(feature = "gzip_compression")]
{
if server.config.gzip_compression {
compression_layer = compression_layer.gzip(true);
} else {
compression_layer = compression_layer.no_gzip();
};
};
#[cfg(feature = "brotli_compression")]
{
if server.config.brotli_compression {
compression_layer = compression_layer.br(true);
} else {
compression_layer = compression_layer.no_br();
};
};
compression_layer
}
fn cors_layer(_server: &Server) -> CorsLayer {
const METHODS: [Method; 7] = [
Method::GET,
Method::HEAD,
Method::PATCH,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
];
let headers: [HeaderName; 5] = [
header::ORIGIN,
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
header::CONTENT_TYPE,
header::ACCEPT,
header::AUTHORIZATION,
];
CorsLayer::new()
.allow_origin(cors::Any)
.allow_methods(METHODS)
.allow_headers(headers)
.max_age(Duration::from_secs(86400))
}
fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
DefaultBodyLimit::max(
server
.config
.max_request_size
.try_into()
.expect("failed to convert max request size"),
)
}
#[allow(clippy::needless_pass_by_value)]
#[tracing::instrument(skip_all)]
fn catch_panic(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
conduit_service::services()
.server
.requests_panic
.fetch_add(1, std::sync::atomic::Ordering::Release);
let details = if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = err.downcast_ref::<&str>() {
s.to_string()
} else {
"Unknown internal server error occurred.".to_owned()
};
let body = serde_json::json!({
"errcode": "M_UNKNOWN",
"error": "M_UNKNOWN: Internal server error occurred",
"details": details,
})
.to_string();
http::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "application/json")
.body(http_body_util::Full::from(body))
.expect("Failed to create response for our panic catcher?")
}
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
path.as_str()
} else {
request.uri().path()
};
tracing::info_span!("router:", %path)
}

View file

@ -1,258 +1,29 @@
use std::{any::Any, io, sync::atomic, time::Duration};
pub(crate) mod layers;
pub(crate) mod request;
pub(crate) mod router;
pub(crate) mod run;
pub(crate) mod serve;
use axum::{
extract::{DefaultBodyLimit, MatchedPath},
response::IntoResponse,
Router,
};
use http::{
header::{self, HeaderName, HeaderValue},
Method, StatusCode, Uri,
};
use ruma::api::client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
uiaa::UiaaResponse,
};
use tower::ServiceBuilder;
use tower_http::{
catch_panic::CatchPanicLayer,
cors::{self, CorsLayer},
set_header::SetResponseHeaderLayer,
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
ServiceBuilderExt as _,
};
use tracing::{debug, error, trace, Level};
extern crate conduit_core as conduit;
use super::{api::ruma_wrapper::RumaResponse, debug_error, services, utils::error::Result, Server};
use std::{future::Future, pin::Pin, sync::Arc};
mod routes;
use conduit::{Result, Server};
pub(crate) async fn build(server: &Server) -> io::Result<axum::routing::IntoMakeService<Router>> {
let base_middlewares = ServiceBuilder::new();
#[cfg(feature = "sentry_telemetry")]
let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
conduit::mod_ctor! {}
conduit::mod_dtor! {}
let x_forwarded_for = HeaderName::from_static("x-forwarded-for");
let permissions_policy = HeaderName::from_static("permissions-policy");
let origin_agent_cluster = HeaderName::from_static("origin-agent-cluster"); // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
let middlewares = base_middlewares
.sensitive_headers([header::AUTHORIZATION])
.sensitive_request_headers([x_forwarded_for].into())
.layer(axum::middleware::from_fn(request_spawn))
.layer(
TraceLayer::new_for_http()
.make_span_with(tracing_span::<_>)
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
.on_request(DefaultOnRequest::new().level(Level::TRACE))
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
)
.layer(axum::middleware::from_fn(request_handle))
.layer(SetResponseHeaderLayer::if_not_present(
origin_agent_cluster,
HeaderValue::from_static("?1"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_XSS_PROTECTION,
HeaderValue::from_static("0"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_FRAME_OPTIONS,
HeaderValue::from_static("DENY"),
))
.layer(SetResponseHeaderLayer::if_not_present(
permissions_policy,
HeaderValue::from_static("interest-cohort=(),browsing-topics=()"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static(
"sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \
style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';",
),
))
.layer(cors_layer(server))
.layer(DefaultBodyLimit::max(
server
.config
.max_request_size
.try_into()
.expect("failed to convert max request size"),
))
.layer(CatchPanicLayer::custom(catch_panic_layer));
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
{
Ok(routes::routes(&server.config)
.layer(compression_layer(server))
.layer(middlewares)
.into_make_service())
}
#[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))]
{
Ok(routes::routes().layer(middlewares).into_make_service())
}
#[no_mangle]
pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> {
Box::pin(run::start(server.clone()))
}
#[tracing::instrument(skip_all, name = "spawn")]
async fn request_spawn(
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
let fut = next.run(req);
let task = tokio::spawn(fut);
task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
#[no_mangle]
pub extern "Rust" fn stop(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> {
Box::pin(run::stop(server.clone()))
}
#[tracing::instrument(skip_all, name = "handle")]
async fn request_handle(
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
let method = req.method().clone();
let uri = req.uri().clone();
let result = next.run(req).await;
request_result(&method, &uri, result)
}
fn request_result(
method: &Method, uri: &Uri, result: axum::response::Response,
) -> Result<axum::response::Response, StatusCode> {
request_result_log(method, uri, &result);
match result.status() {
StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result),
_ => Ok(result),
}
}
#[allow(clippy::unnecessary_wraps)]
fn request_result_403(
_method: &Method, _uri: &Uri, result: &axum::response::Response,
) -> Result<axum::response::Response, StatusCode> {
let error = UiaaResponse::MatrixError(RumaError {
status_code: result.status(),
body: ErrorBody::Standard {
kind: ErrorKind::Unrecognized,
message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(),
},
});
Ok(RumaResponse(error).into_response())
}
fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {
let status = result.status();
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
let code = status.as_u16();
if status.is_server_error() {
error!(method = ?method, uri = ?uri, "{code} {reason}");
} else if status.is_client_error() {
debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
} else if status.is_redirection() {
debug!(method = ?method, uri = ?uri, "{code} {reason}");
} else {
trace!(method = ?method, uri = ?uri, "{code} {reason}");
}
}
/// Cross-Origin-Resource-Sharing header as defined by spec:
/// <https://spec.matrix.org/latest/client-server-api/#web-browser-clients>
fn cors_layer(_server: &Server) -> CorsLayer {
const METHODS: [Method; 7] = [
Method::GET,
Method::HEAD,
Method::PATCH,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
];
let headers: [HeaderName; 5] = [
header::ORIGIN,
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
header::CONTENT_TYPE,
header::ACCEPT,
header::AUTHORIZATION,
];
CorsLayer::new()
.allow_origin(cors::Any)
.allow_methods(METHODS)
.allow_headers(headers)
.max_age(Duration::from_secs(86400))
}
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
let mut compression_layer = tower_http::compression::CompressionLayer::new();
#[cfg(feature = "zstd_compression")]
{
if server.config.zstd_compression {
compression_layer = compression_layer.zstd(true);
} else {
compression_layer = compression_layer.no_zstd();
};
};
#[cfg(feature = "gzip_compression")]
{
if server.config.gzip_compression {
compression_layer = compression_layer.gzip(true);
} else {
compression_layer = compression_layer.no_gzip();
};
};
#[cfg(feature = "brotli_compression")]
{
if server.config.brotli_compression {
compression_layer = compression_layer.br(true);
} else {
compression_layer = compression_layer.no_br();
};
};
compression_layer
}
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
path.as_str()
} else {
request.uri().path()
};
tracing::info_span!("router:", %path)
}
#[allow(clippy::needless_pass_by_value)]
fn catch_panic_layer(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
let details = if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = err.downcast_ref::<&str>() {
s.to_string()
} else {
"Unknown internal server error occurred.".to_owned()
};
let body = serde_json::json!({
"errcode": "M_UNKNOWN",
"error": "M_UNKNOWN: Internal server error occurred",
"details": details,
})
.to_string();
http::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "application/json")
.body(http_body_util::Full::from(body))
.expect("Failed to create response for our panic catcher?")
#[no_mangle]
pub extern "Rust" fn run(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> {
Box::pin(run::run(server.clone()))
}

102
src/router/request.rs Normal file
View file

@ -0,0 +1,102 @@
use std::sync::{atomic::Ordering, Arc};
use axum::{extract::State, response::IntoResponse};
use conduit::{debug_error, debug_warn, defer, Result, RumaResponse, Server};
use http::{Method, StatusCode, Uri};
use ruma::api::client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
uiaa::UiaaResponse,
};
use tracing::{debug, error, trace};
#[tracing::instrument(skip_all)]
pub(crate) async fn spawn(
State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
if server.interrupt.load(Ordering::Relaxed) {
debug_warn!("unavailable pending shutdown");
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
let active = server.requests_spawn_active.fetch_add(1, Ordering::Relaxed);
trace!(active, "enter");
defer! {{
let active = server.requests_spawn_active.fetch_sub(1, Ordering::Relaxed);
let finished = server.requests_spawn_finished.fetch_add(1, Ordering::Relaxed);
trace!(active, finished, "leave");
}};
let fut = next.run(req);
let task = server.runtime().spawn(fut);
task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[tracing::instrument(skip_all, name = "handle")]
pub(crate) async fn handle(
State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next,
) -> Result<axum::response::Response, StatusCode> {
if server.interrupt.load(Ordering::Relaxed) {
debug_warn!(
method = %req.method(),
uri = %req.uri(),
"unavailable pending shutdown"
);
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
let active = server
.requests_handle_active
.fetch_add(1, Ordering::Relaxed);
trace!(active, "enter");
defer! {{
let active = server.requests_handle_active.fetch_sub(1, Ordering::Relaxed);
let finished = server.requests_handle_finished.fetch_add(1, Ordering::Relaxed);
trace!(active, finished, "leave");
}};
let method = req.method().clone();
let uri = req.uri().clone();
let result = next.run(req).await;
handle_result(&method, &uri, result)
}
fn handle_result(
method: &Method, uri: &Uri, result: axum::response::Response,
) -> Result<axum::response::Response, StatusCode> {
handle_result_log(method, uri, &result);
match result.status() {
StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result),
_ => Ok(result),
}
}
#[allow(clippy::unnecessary_wraps)]
fn handle_result_403(
_method: &Method, _uri: &Uri, result: &axum::response::Response,
) -> Result<axum::response::Response, StatusCode> {
let error = UiaaResponse::MatrixError(RumaError {
status_code: result.status(),
body: ErrorBody::Standard {
kind: ErrorKind::Unrecognized,
message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(),
},
});
Ok(RumaResponse(error).into_response())
}
fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {
let status = result.status();
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
let code = status.as_u16();
if status.is_server_error() {
error!(method = ?method, uri = ?uri, "{code} {reason}");
} else if status.is_client_error() {
debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
} else if status.is_redirection() {
debug!(method = ?method, uri = ?uri, "{code} {reason}");
} else {
trace!(method = ?method, uri = ?uri, "{code} {reason}");
}
}

20
src/router/router.rs Normal file
View file

@ -0,0 +1,20 @@
use std::sync::Arc;
use axum::{response::IntoResponse, routing::get, Router};
use conduit::{Error, Server};
use http::Uri;
use ruma::api::client::error::ErrorKind;
extern crate conduit_api as api;
pub(crate) fn build(server: &Arc<Server>) -> Router {
let router = Router::new().fallback(not_found).route("/", get(it_works));
api::router::build(router, server)
}
async fn not_found(_uri: Uri) -> impl IntoResponse {
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
}
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }

View file

@ -1,322 +0,0 @@
use std::future::Future;
use axum::{
extract::FromRequestParts,
response::IntoResponse,
routing::{any, get, on, post, MethodFilter},
Router,
};
use http::{Method, Uri};
use ruma::api::{client::error::ErrorKind, IncomingRequest};
use crate::{
api::{client_server, server_server},
Config, Error, Result, Ruma, RumaResponse,
};
pub(crate) fn routes(config: &Config) -> Router {
let router = Router::new()
.ruma_route(client_server::get_supported_versions_route)
.ruma_route(client_server::get_register_available_route)
.ruma_route(client_server::register_route)
.ruma_route(client_server::get_login_types_route)
.ruma_route(client_server::login_route)
.ruma_route(client_server::whoami_route)
.ruma_route(client_server::logout_route)
.ruma_route(client_server::logout_all_route)
.ruma_route(client_server::change_password_route)
.ruma_route(client_server::deactivate_route)
.ruma_route(client_server::third_party_route)
.ruma_route(client_server::request_3pid_management_token_via_email_route)
.ruma_route(client_server::request_3pid_management_token_via_msisdn_route)
.ruma_route(client_server::get_capabilities_route)
.ruma_route(client_server::get_pushrules_all_route)
.ruma_route(client_server::set_pushrule_route)
.ruma_route(client_server::get_pushrule_route)
.ruma_route(client_server::set_pushrule_enabled_route)
.ruma_route(client_server::get_pushrule_enabled_route)
.ruma_route(client_server::get_pushrule_actions_route)
.ruma_route(client_server::set_pushrule_actions_route)
.ruma_route(client_server::delete_pushrule_route)
.ruma_route(client_server::get_room_event_route)
.ruma_route(client_server::get_room_aliases_route)
.ruma_route(client_server::get_filter_route)
.ruma_route(client_server::create_filter_route)
.ruma_route(client_server::set_global_account_data_route)
.ruma_route(client_server::set_room_account_data_route)
.ruma_route(client_server::get_global_account_data_route)
.ruma_route(client_server::get_room_account_data_route)
.ruma_route(client_server::set_displayname_route)
.ruma_route(client_server::get_displayname_route)
.ruma_route(client_server::set_avatar_url_route)
.ruma_route(client_server::get_avatar_url_route)
.ruma_route(client_server::get_profile_route)
.ruma_route(client_server::set_presence_route)
.ruma_route(client_server::get_presence_route)
.ruma_route(client_server::upload_keys_route)
.ruma_route(client_server::get_keys_route)
.ruma_route(client_server::claim_keys_route)
.ruma_route(client_server::create_backup_version_route)
.ruma_route(client_server::update_backup_version_route)
.ruma_route(client_server::delete_backup_version_route)
.ruma_route(client_server::get_latest_backup_info_route)
.ruma_route(client_server::get_backup_info_route)
.ruma_route(client_server::add_backup_keys_route)
.ruma_route(client_server::add_backup_keys_for_room_route)
.ruma_route(client_server::add_backup_keys_for_session_route)
.ruma_route(client_server::delete_backup_keys_for_room_route)
.ruma_route(client_server::delete_backup_keys_for_session_route)
.ruma_route(client_server::delete_backup_keys_route)
.ruma_route(client_server::get_backup_keys_for_room_route)
.ruma_route(client_server::get_backup_keys_for_session_route)
.ruma_route(client_server::get_backup_keys_route)
.ruma_route(client_server::set_read_marker_route)
.ruma_route(client_server::create_receipt_route)
.ruma_route(client_server::create_typing_event_route)
.ruma_route(client_server::create_room_route)
.ruma_route(client_server::redact_event_route)
.ruma_route(client_server::report_event_route)
.ruma_route(client_server::create_alias_route)
.ruma_route(client_server::delete_alias_route)
.ruma_route(client_server::get_alias_route)
.ruma_route(client_server::join_room_by_id_route)
.ruma_route(client_server::join_room_by_id_or_alias_route)
.ruma_route(client_server::joined_members_route)
.ruma_route(client_server::leave_room_route)
.ruma_route(client_server::forget_room_route)
.ruma_route(client_server::joined_rooms_route)
.ruma_route(client_server::kick_user_route)
.ruma_route(client_server::ban_user_route)
.ruma_route(client_server::unban_user_route)
.ruma_route(client_server::invite_user_route)
.ruma_route(client_server::set_room_visibility_route)
.ruma_route(client_server::get_room_visibility_route)
.ruma_route(client_server::get_public_rooms_route)
.ruma_route(client_server::get_public_rooms_filtered_route)
.ruma_route(client_server::search_users_route)
.ruma_route(client_server::get_member_events_route)
.ruma_route(client_server::get_protocols_route)
.ruma_route(client_server::send_message_event_route)
.ruma_route(client_server::send_state_event_for_key_route)
.ruma_route(client_server::get_state_events_route)
.ruma_route(client_server::get_state_events_for_key_route)
// Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes
// share one Ruma request / response type pair with {get,send}_state_event_for_key_route
.route(
"/_matrix/client/r0/rooms/:room_id/state/:event_type",
get(client_server::get_state_events_for_empty_key_route)
.put(client_server::send_state_event_for_empty_key_route),
)
.route(
"/_matrix/client/v3/rooms/:room_id/state/:event_type",
get(client_server::get_state_events_for_empty_key_route)
.put(client_server::send_state_event_for_empty_key_route),
)
// These two endpoints allow trailing slashes
.route(
"/_matrix/client/r0/rooms/:room_id/state/:event_type/",
get(client_server::get_state_events_for_empty_key_route)
.put(client_server::send_state_event_for_empty_key_route),
)
.route(
"/_matrix/client/v3/rooms/:room_id/state/:event_type/",
get(client_server::get_state_events_for_empty_key_route)
.put(client_server::send_state_event_for_empty_key_route),
)
.ruma_route(client_server::sync_events_route)
.ruma_route(client_server::sync_events_v4_route)
.ruma_route(client_server::get_context_route)
.ruma_route(client_server::get_message_events_route)
.ruma_route(client_server::search_events_route)
.ruma_route(client_server::turn_server_route)
.ruma_route(client_server::send_event_to_device_route)
.ruma_route(client_server::get_media_config_route)
.ruma_route(client_server::get_media_preview_route)
.ruma_route(client_server::create_content_route)
// legacy v1 media routes
.route(
"/_matrix/media/v1/preview_url",
get(client_server::get_media_preview_v1_route)
)
.route(
"/_matrix/media/v1/config",
get(client_server::get_media_config_v1_route)
)
.route(
"/_matrix/media/v1/upload",
post(client_server::create_content_v1_route)
)
.route(
"/_matrix/media/v1/download/:server_name/:media_id",
get(client_server::get_content_v1_route)
)
.route(
"/_matrix/media/v1/download/:server_name/:media_id/:file_name",
get(client_server::get_content_as_filename_v1_route)
)
.route(
"/_matrix/media/v1/thumbnail/:server_name/:media_id",
get(client_server::get_content_thumbnail_v1_route)
)
.ruma_route(client_server::get_content_route)
.ruma_route(client_server::get_content_as_filename_route)
.ruma_route(client_server::get_content_thumbnail_route)
.ruma_route(client_server::get_devices_route)
.ruma_route(client_server::get_device_route)
.ruma_route(client_server::update_device_route)
.ruma_route(client_server::delete_device_route)
.ruma_route(client_server::delete_devices_route)
.ruma_route(client_server::get_tags_route)
.ruma_route(client_server::update_tag_route)
.ruma_route(client_server::delete_tag_route)
.ruma_route(client_server::upload_signing_keys_route)
.ruma_route(client_server::upload_signatures_route)
.ruma_route(client_server::get_key_changes_route)
.ruma_route(client_server::get_pushers_route)
.ruma_route(client_server::set_pushers_route)
// .ruma_route(client_server::third_party_route)
.ruma_route(client_server::upgrade_room_route)
.ruma_route(client_server::get_threads_route)
.ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route)
.ruma_route(client_server::get_relating_events_with_rel_type_route)
.ruma_route(client_server::get_relating_events_route)
.ruma_route(client_server::get_hierarchy_route)
.ruma_route(client_server::get_mutual_rooms_route)
.ruma_route(client_server::well_known_support)
.ruma_route(client_server::well_known_client)
.route("/_conduwuit/server_version", get(client_server::conduwuit_server_version))
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
.route("/client/server.json", get(client_server::syncv3_client_server_json))
.route("/", get(it_works))
.fallback(not_found);
if config.allow_federation {
router
.ruma_route(server_server::get_server_version_route)
.route("/_matrix/key/v2/server", get(server_server::get_server_keys_route))
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
.ruma_route(server_server::get_backfill_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)
.ruma_route(server_server::get_room_state_ids_route)
.ruma_route(server_server::create_leave_event_template_route)
.ruma_route(server_server::create_leave_event_v1_route)
.ruma_route(server_server::create_leave_event_v2_route)
.ruma_route(server_server::create_join_event_template_route)
.ruma_route(server_server::create_join_event_v1_route)
.ruma_route(server_server::create_join_event_v2_route)
.ruma_route(server_server::create_invite_route)
.ruma_route(server_server::get_devices_route)
.ruma_route(server_server::get_room_information_route)
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
.ruma_route(server_server::get_hierarchy_route)
.ruma_route(server_server::well_known_server)
} else {
router
.route("/_matrix/federation/*path", any(federation_disabled))
.route("/.well-known/matrix/server", any(federation_disabled))
.route("/_matrix/key/*path", any(federation_disabled))
}
}
async fn not_found(_uri: Uri) -> impl IntoResponse {
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
}
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented")
}
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }
async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") }
trait RouterExt {
fn ruma_route<H, T>(self, handler: H) -> Self
where
H: RumaHandler<T>,
T: 'static;
}
impl RouterExt for Router {
fn ruma_route<H, T>(self, handler: H) -> Self
where
H: RumaHandler<T>,
T: 'static,
{
handler.add_to_router(self)
}
}
pub(crate) trait RumaHandler<T> {
// Can't transform to a handler without boxing or relying on the nightly-only
// impl-trait-in-traits feature. Moving a small amount of extra logic into the
// trait allows bypassing both.
fn add_to_router(self, router: Router) -> Router;
}
macro_rules! impl_ruma_handler {
( $($ty:ident),* $(,)? ) => {
#[axum::async_trait]
#[allow(non_snake_case)]
impl<Req, E, F, Fut, $($ty,)*> RumaHandler<($($ty,)* Ruma<Req>,)> for F
where
Req: IncomingRequest + Send + 'static,
F: FnOnce($($ty,)* Ruma<Req>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Result<Req::OutgoingResponse, E>>
+ Send,
E: IntoResponse,
$( $ty: FromRequestParts<()> + Send + 'static, )*
{
fn add_to_router(self, mut router: Router) -> Router {
let meta = Req::METADATA;
let method_filter = method_to_filter(meta.method);
for path in meta.history.all_paths() {
let handler = self.clone();
router = router.route(path, on(method_filter, |$( $ty: $ty, )* req| async move {
handler($($ty,)* req).await.map(RumaResponse)
}))
}
router
}
}
};
}
impl_ruma_handler!();
impl_ruma_handler!(T1);
impl_ruma_handler!(T1, T2);
impl_ruma_handler!(T1, T2, T3);
impl_ruma_handler!(T1, T2, T3, T4);
impl_ruma_handler!(T1, T2, T3, T4, T5);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
fn method_to_filter(method: Method) -> MethodFilter {
match method {
Method::DELETE => MethodFilter::DELETE,
Method::GET => MethodFilter::GET,
Method::HEAD => MethodFilter::HEAD,
Method::OPTIONS => MethodFilter::OPTIONS,
Method::PATCH => MethodFilter::PATCH,
Method::POST => MethodFilter::POST,
Method::PUT => MethodFilter::PUT,
Method::TRACE => MethodFilter::TRACE,
m => panic!("Unsupported HTTP method: {m:?}"),
}
}

185
src/router/run.rs Normal file
View file

@ -0,0 +1,185 @@
use std::{sync::Arc, time::Duration};
use axum_server::Handle as ServerHandle;
use tokio::{
signal,
sync::oneshot::{self, Sender},
};
use tracing::{debug, info, warn};
extern crate conduit_admin as admin;
extern crate conduit_core as conduit;
extern crate conduit_database as database;
extern crate conduit_service as service;
use std::sync::atomic::Ordering;
use conduit::{debug_info, trace, Error, Result, Server};
use database::KeyValueDatabase;
use service::{services, Services};
use crate::{layers, serve};
/// Main loop base
#[tracing::instrument(skip_all)]
pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> {
let config = &server.config;
let app = layers::build(&server)?;
let addrs = config.get_bind_addrs();
// Install the admin room callback here for now
_ = services().admin.handle.lock().await.insert(admin::handle);
// Setup shutdown/signal handling
let handle = ServerHandle::new();
_ = server
.shutdown
.lock()
.expect("locked")
.insert(handle.clone());
server.interrupt.store(false, Ordering::Release);
let (tx, rx) = oneshot::channel::<()>();
let sigs = server.runtime().spawn(sighandle(server.clone(), tx));
// Prepare to serve http clients
let res;
// Serve clients
if cfg!(unix) && config.unix_socket_path.is_some() {
res = serve::unix_socket(&server, app, rx).await;
} else if config.tls.is_some() {
res = serve::tls(&server, app, handle.clone(), addrs).await;
} else {
res = serve::plain(&server, app, handle.clone(), addrs).await;
}
// Join the signal handler before we leave.
sigs.abort();
_ = sigs.await;
// Reset the axum handle instance; this should be reusable and might be
// reload-survivable but better to be safe than sorry.
_ = server.shutdown.lock().expect("locked").take();
// Remove the admin room callback
_ = services().admin.handle.lock().await.take();
debug_info!("Finished");
Ok(res?)
}
/// Async initializations
#[tracing::instrument(skip_all)]
pub(crate) async fn start(server: Arc<Server>) -> Result<(), Error> {
debug!("Starting...");
let d = Arc::new(KeyValueDatabase::load_or_create(&server).await?);
let s = Box::new(Services::build(server, d.clone()).await?);
_ = service::SERVICES
.write()
.expect("write locked")
.insert(Box::leak(s));
services().start().await?;
#[cfg(feature = "systemd")]
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
debug!("Started");
Ok(())
}
/// Async destructions
#[tracing::instrument(skip_all)]
pub(crate) async fn stop(_server: Arc<Server>) -> Result<(), Error> {
debug!("Shutting down...");
// Wait for all completions before dropping or we'll lose them to the module
// unload and explode.
services().shutdown().await;
// Deactivate services(). Any further use will panic the caller.
let s = service::SERVICES
.write()
.expect("write locked")
.take()
.unwrap();
let s = std::ptr::from_ref(s) as *mut Services;
//SAFETY: Services was instantiated in start() and leaked into the SERVICES
// global perusing as 'static for the duration of run_server(). Now we reclaim
// it to drop it before unloading the module. If this is not done there will be
// multiple instances after module reload.
let s = unsafe { Box::from_raw(s) };
debug!("Cleaning up...");
// Drop it so we encounter any trouble before the infolog message
drop(s);
#[cfg(feature = "systemd")]
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);
info!("Shutdown complete.");
Ok(())
}
#[tracing::instrument(skip_all)]
async fn sighandle(server: Arc<Server>, tx: Sender<()>) -> Result<(), Error> {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
let reload = cfg!(unix) && cfg!(debug_assertions);
server.reload.store(reload, Ordering::Release);
};
#[cfg(unix)]
let ctrl_bs = async {
signal::unix::signal(signal::unix::SignalKind::quit())
.expect("failed to install Ctrl+\\ handler")
.recv()
.await;
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler")
.recv()
.await;
};
debug!("Installed signal handlers");
let sig: &str;
#[cfg(unix)]
tokio::select! {
() = ctrl_c => { sig = "Ctrl+C"; },
() = ctrl_bs => { sig = "Ctrl+\\"; },
() = terminate => { sig = "SIGTERM"; },
}
#[cfg(not(unix))]
tokio::select! {
_ = ctrl_c => { sig = "Ctrl+C"; },
}
warn!("Received {}", sig);
server.interrupt.store(true, Ordering::Release);
services().globals.rotate.fire();
tx.send(())
.expect("failed sending shutdown transaction to oneshot channel");
if let Some(handle) = server.shutdown.lock().expect("locked").as_ref() {
let pending = server.requests_spawn_active.load(Ordering::Relaxed);
if pending > 0 {
let timeout = Duration::from_secs(36);
trace!(pending, ?timeout, "Notifying for graceful shutdown");
handle.graceful_shutdown(Some(timeout));
} else {
debug!(pending, "Notifying for immediate shutdown");
handle.shutdown();
}
}
Ok(())
}

137
src/router/serve.rs Normal file
View file

@ -0,0 +1,137 @@
#[cfg(unix)]
use std::fs::Permissions; // only for UNIX sockets stuff and *nix container checks
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt as _;
use std::{
io,
net::SocketAddr,
sync::{atomic::Ordering, Arc},
};
use axum::Router;
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
#[cfg(feature = "axum_dual_protocol")]
use axum_server_dual_protocol::ServerExt;
use conduit::{debug_info, Server};
use tokio::{
sync::oneshot::{self},
task::JoinSet,
};
use tracing::{debug, info, warn};
pub(crate) async fn plain(
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
) -> io::Result<()> {
let mut join_set = JoinSet::new();
for addr in &addrs {
join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime());
}
info!("Listening on {addrs:?}");
while join_set.join_next().await.is_some() {}
let spawn_active = server.requests_spawn_active.load(Ordering::Relaxed);
let handle_active = server.requests_handle_active.load(Ordering::Relaxed);
debug_info!(
spawn_finished = server.requests_spawn_finished.load(Ordering::Relaxed),
handle_finished = server.requests_handle_finished.load(Ordering::Relaxed),
panics = server.requests_panic.load(Ordering::Relaxed),
spawn_active,
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert!(spawn_active == 0, "active request tasks are not joined");
debug_assert!(handle_active == 0, "active request handles still pending");
Ok(())
}
pub(crate) async fn tls(
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
) -> io::Result<()> {
let config = &server.config;
let tls = config.tls.as_ref().expect("TLS configuration");
debug!(
"Using direct TLS. Certificate path {} and certificate private key path {}",
&tls.certs, &tls.key
);
info!(
"Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS."
);
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
if cfg!(feature = "axum_dual_protocol") {
info!(
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \
take effect if `dual_protocol` is enabled in `[global.tls]`"
);
}
let mut join_set = JoinSet::new();
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
#[cfg(feature = "axum_dual_protocol")]
for addr in &addrs {
join_set.spawn_on(
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
.set_upgrade(false)
.handle(handle.clone())
.serve(app.clone()),
server.runtime(),
);
}
} else {
for addr in &addrs {
join_set.spawn_on(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
server.runtime(),
);
}
}
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
warn!(
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)",
addrs, &tls.certs
);
} else {
info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs);
}
while join_set.join_next().await.is_some() {}
Ok(())
}
#[cfg(unix)]
#[allow(unused_variables)]
pub(crate) async fn unix_socket(
server: &Arc<Server>, app: axum::routing::IntoMakeService<Router>, rx: oneshot::Receiver<()>,
) -> io::Result<()> {
let config = &server.config;
let path = config.unix_socket_path.as_ref().unwrap();
if path.exists() {
warn!(
"UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.",
path.display()
);
tokio::fs::remove_file(&path).await?;
}
tokio::fs::create_dir_all(path.parent().unwrap()).await?;
let socket_perms = config.unix_socket_perms.to_string();
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms))
.await
.unwrap();
let bind = tokio::net::UnixListener::bind(path)?;
info!("Listening at {:?}", path);
Ok(())
}