Port from Rocket to axum

This commit is contained in:
Jonas Platte 2022-01-20 11:51:31 +01:00
parent 8709c3ae7b
commit 1f7b3fa4ac
No known key found for this signature in database
GPG key ID: 7D261D771D915378
52 changed files with 1064 additions and 1885 deletions

View file

@ -7,24 +7,37 @@
#![allow(clippy::suspicious_else_formatting)]
#![deny(clippy::dbg_macro)]
use std::sync::Arc;
use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
use maplit::hashset;
use opentelemetry::trace::{FutureExt, Tracer};
use rocket::{
catch, catchers,
figment::{
providers::{Env, Format, Toml},
Figment,
},
routes, Request,
use axum::{
extract::{FromRequest, MatchedPath},
handler::Handler,
routing::{get, on, MethodFilter},
Router,
};
use figment::{
providers::{Env, Format, Toml},
Figment,
};
use http::{
header::{self, HeaderName},
Method,
};
use opentelemetry::trace::{FutureExt, Tracer};
use ruma::{
api::{IncomingRequest, Metadata},
Outgoing,
};
use tokio::{signal, sync::RwLock};
use tower::ServiceBuilder;
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
ServiceBuilderExt as _,
};
use ruma::api::client::error::ErrorKind;
use tokio::sync::RwLock;
use tracing_subscriber::{prelude::*, EnvFilter};
pub use conduit::*; // Re-export everything from the library crate
pub use rocket::State;
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
use tikv_jemallocator::Jemalloc;
@ -33,160 +46,10 @@ use tikv_jemallocator::Jemalloc;
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
fn setup_rocket(config: Figment, data: Arc<RwLock<Database>>) -> rocket::Rocket<rocket::Build> {
rocket::custom(config)
.manage(data)
.mount(
"/",
routes![
client_server::get_supported_versions_route,
client_server::get_register_available_route,
client_server::register_route,
client_server::get_login_types_route,
client_server::login_route,
client_server::whoami_route,
client_server::logout_route,
client_server::logout_all_route,
client_server::change_password_route,
client_server::deactivate_route,
client_server::third_party_route,
client_server::get_capabilities_route,
client_server::get_pushrules_all_route,
client_server::set_pushrule_route,
client_server::get_pushrule_route,
client_server::set_pushrule_enabled_route,
client_server::get_pushrule_enabled_route,
client_server::get_pushrule_actions_route,
client_server::set_pushrule_actions_route,
client_server::delete_pushrule_route,
client_server::get_room_event_route,
client_server::get_room_aliases_route,
client_server::get_filter_route,
client_server::create_filter_route,
client_server::set_global_account_data_route,
client_server::set_room_account_data_route,
client_server::get_global_account_data_route,
client_server::get_room_account_data_route,
client_server::set_displayname_route,
client_server::get_displayname_route,
client_server::set_avatar_url_route,
client_server::get_avatar_url_route,
client_server::get_profile_route,
client_server::set_presence_route,
client_server::get_presence_route,
client_server::upload_keys_route,
client_server::get_keys_route,
client_server::claim_keys_route,
client_server::create_backup_route,
client_server::update_backup_route,
client_server::delete_backup_route,
client_server::get_latest_backup_route,
client_server::get_backup_route,
client_server::add_backup_key_sessions_route,
client_server::add_backup_keys_route,
client_server::delete_backup_key_session_route,
client_server::delete_backup_key_sessions_route,
client_server::delete_backup_keys_route,
client_server::get_backup_key_session_route,
client_server::get_backup_key_sessions_route,
client_server::get_backup_keys_route,
client_server::set_read_marker_route,
client_server::create_receipt_route,
client_server::create_typing_event_route,
client_server::create_room_route,
client_server::redact_event_route,
client_server::report_event_route,
client_server::create_alias_route,
client_server::delete_alias_route,
client_server::get_alias_route,
client_server::join_room_by_id_route,
client_server::join_room_by_id_or_alias_route,
client_server::joined_members_route,
client_server::leave_room_route,
client_server::forget_room_route,
client_server::joined_rooms_route,
client_server::kick_user_route,
client_server::ban_user_route,
client_server::unban_user_route,
client_server::invite_user_route,
client_server::set_room_visibility_route,
client_server::get_room_visibility_route,
client_server::get_public_rooms_route,
client_server::get_public_rooms_filtered_route,
client_server::search_users_route,
client_server::get_member_events_route,
client_server::get_protocols_route,
client_server::send_message_event_route,
client_server::send_state_event_for_key_route,
client_server::send_state_event_for_empty_key_route,
client_server::get_state_events_route,
client_server::get_state_events_for_key_route,
client_server::get_state_events_for_empty_key_route,
client_server::sync_events_route,
client_server::get_context_route,
client_server::get_message_events_route,
client_server::search_events_route,
client_server::turn_server_route,
client_server::send_event_to_device_route,
client_server::get_media_config_route,
client_server::create_content_route,
client_server::get_content_as_filename_route,
client_server::get_content_route,
client_server::get_content_thumbnail_route,
client_server::get_devices_route,
client_server::get_device_route,
client_server::update_device_route,
client_server::delete_device_route,
client_server::delete_devices_route,
client_server::get_tags_route,
client_server::update_tag_route,
client_server::delete_tag_route,
client_server::options_route,
client_server::upload_signing_keys_route,
client_server::upload_signatures_route,
client_server::get_key_changes_route,
client_server::get_pushers_route,
client_server::set_pushers_route,
// client_server::third_party_route,
client_server::upgrade_room_route,
server_server::get_server_version_route,
server_server::get_server_keys_route,
server_server::get_server_keys_deprecated_route,
server_server::get_public_rooms_route,
server_server::get_public_rooms_filtered_route,
server_server::send_transaction_message_route,
server_server::get_event_route,
server_server::get_missing_events_route,
server_server::get_event_authorization_route,
server_server::get_room_state_route,
server_server::get_room_state_ids_route,
server_server::create_join_event_template_route,
server_server::create_join_event_v1_route,
server_server::create_join_event_v2_route,
server_server::create_invite_route,
server_server::get_devices_route,
server_server::get_room_information_route,
server_server::get_profile_information_route,
server_server::get_keys_route,
server_server::claim_keys_route,
],
)
.register(
"/",
catchers![
not_found_catcher,
forbidden_catcher,
unknown_token_catcher,
missing_token_catcher,
bad_json_catcher
],
)
}
#[rocket::main]
#[tokio::main]
async fn main() {
let raw_config =
Figment::from(default_config())
Figment::new()
.merge(
Toml::file(Env::var("CONDUIT_CONFIG").expect(
"The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml",
@ -217,14 +80,7 @@ async fn main() {
}
};
let rocket = setup_rocket(raw_config, Arc::clone(&db))
.ignite()
.await
.unwrap();
Database::start_on_shutdown_tasks(db, rocket.shutdown()).await;
rocket.launch().await.unwrap();
run_server(&config, db).await.unwrap();
};
if config.allow_jaeger {
@ -264,55 +120,282 @@ async fn main() {
}
}
#[catch(404)]
fn not_found_catcher(_: &Request<'_>) -> String {
"404 Not Found".to_owned()
async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> hyper::Result<()> {
let listen_addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with");
let middlewares = ServiceBuilder::new()
.sensitive_headers([header::AUTHORIZATION])
.layer(
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
path.as_str()
} else {
request.uri().path()
};
tracing::info_span!("http_request", %path)
}),
)
.compression()
.layer(
CorsLayer::new()
.allow_origin(cors::any())
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers([
header::ORIGIN,
x_requested_with,
header::CONTENT_TYPE,
header::ACCEPT,
header::AUTHORIZATION,
])
.max_age(Duration::from_secs(86400)),
)
.add_extension(db.clone());
axum::Server::bind(&listen_addr)
.serve(routes().layer(middlewares).into_make_service())
.with_graceful_shutdown(shutdown_signal())
.await?;
// After serve exits and before exiting, shutdown the DB
Database::on_shutdown(db).await;
Ok(())
}
#[catch(580)]
fn forbidden_catcher() -> Result<()> {
Err(Error::BadRequest(ErrorKind::Forbidden, "Forbidden."))
fn routes() -> Router {
Router::new()
.ruma_route(client_server::get_supported_versions_route)
.ruma_route(client_server::get_register_available_route)
.ruma_route(client_server::register_route)
.ruma_route(client_server::get_login_types_route)
.ruma_route(client_server::login_route)
.ruma_route(client_server::whoami_route)
.ruma_route(client_server::logout_route)
.ruma_route(client_server::logout_all_route)
.ruma_route(client_server::change_password_route)
.ruma_route(client_server::deactivate_route)
.ruma_route(client_server::third_party_route)
.ruma_route(client_server::get_capabilities_route)
.ruma_route(client_server::get_pushrules_all_route)
.ruma_route(client_server::set_pushrule_route)
.ruma_route(client_server::get_pushrule_route)
.ruma_route(client_server::set_pushrule_enabled_route)
.ruma_route(client_server::get_pushrule_enabled_route)
.ruma_route(client_server::get_pushrule_actions_route)
.ruma_route(client_server::set_pushrule_actions_route)
.ruma_route(client_server::delete_pushrule_route)
.ruma_route(client_server::get_room_event_route)
.ruma_route(client_server::get_room_aliases_route)
.ruma_route(client_server::get_filter_route)
.ruma_route(client_server::create_filter_route)
.ruma_route(client_server::set_global_account_data_route)
.ruma_route(client_server::set_room_account_data_route)
.ruma_route(client_server::get_global_account_data_route)
.ruma_route(client_server::get_room_account_data_route)
.ruma_route(client_server::set_displayname_route)
.ruma_route(client_server::get_displayname_route)
.ruma_route(client_server::set_avatar_url_route)
.ruma_route(client_server::get_avatar_url_route)
.ruma_route(client_server::get_profile_route)
.ruma_route(client_server::set_presence_route)
.ruma_route(client_server::get_presence_route)
.ruma_route(client_server::upload_keys_route)
.ruma_route(client_server::get_keys_route)
.ruma_route(client_server::claim_keys_route)
.ruma_route(client_server::create_backup_route)
.ruma_route(client_server::update_backup_route)
.ruma_route(client_server::delete_backup_route)
.ruma_route(client_server::get_latest_backup_route)
.ruma_route(client_server::get_backup_route)
.ruma_route(client_server::add_backup_key_sessions_route)
.ruma_route(client_server::add_backup_keys_route)
.ruma_route(client_server::delete_backup_key_session_route)
.ruma_route(client_server::delete_backup_key_sessions_route)
.ruma_route(client_server::delete_backup_keys_route)
.ruma_route(client_server::get_backup_key_session_route)
.ruma_route(client_server::get_backup_key_sessions_route)
.ruma_route(client_server::get_backup_keys_route)
.ruma_route(client_server::set_read_marker_route)
.ruma_route(client_server::create_receipt_route)
.ruma_route(client_server::create_typing_event_route)
.ruma_route(client_server::create_room_route)
.ruma_route(client_server::redact_event_route)
.ruma_route(client_server::report_event_route)
.ruma_route(client_server::create_alias_route)
.ruma_route(client_server::delete_alias_route)
.ruma_route(client_server::get_alias_route)
.ruma_route(client_server::join_room_by_id_route)
.ruma_route(client_server::join_room_by_id_or_alias_route)
.ruma_route(client_server::joined_members_route)
.ruma_route(client_server::leave_room_route)
.ruma_route(client_server::forget_room_route)
.ruma_route(client_server::joined_rooms_route)
.ruma_route(client_server::kick_user_route)
.ruma_route(client_server::ban_user_route)
.ruma_route(client_server::unban_user_route)
.ruma_route(client_server::invite_user_route)
.ruma_route(client_server::set_room_visibility_route)
.ruma_route(client_server::get_room_visibility_route)
.ruma_route(client_server::get_public_rooms_route)
.ruma_route(client_server::get_public_rooms_filtered_route)
.ruma_route(client_server::search_users_route)
.ruma_route(client_server::get_member_events_route)
.ruma_route(client_server::get_protocols_route)
.ruma_route(client_server::send_message_event_route)
.ruma_route(client_server::send_state_event_for_key_route)
.ruma_route(client_server::send_state_event_for_empty_key_route)
.ruma_route(client_server::get_state_events_route)
.ruma_route(client_server::get_state_events_for_key_route)
.ruma_route(client_server::get_state_events_for_empty_key_route)
.route(
"/_matrix/client/r0/sync",
get(client_server::sync_events_route),
)
.ruma_route(client_server::get_context_route)
.ruma_route(client_server::get_message_events_route)
.ruma_route(client_server::search_events_route)
.ruma_route(client_server::turn_server_route)
.ruma_route(client_server::send_event_to_device_route)
.ruma_route(client_server::get_media_config_route)
.ruma_route(client_server::create_content_route)
.ruma_route(client_server::get_content_route)
.ruma_route(client_server::get_content_as_filename_route)
.ruma_route(client_server::get_content_thumbnail_route)
.ruma_route(client_server::get_devices_route)
.ruma_route(client_server::get_device_route)
.ruma_route(client_server::update_device_route)
.ruma_route(client_server::delete_device_route)
.ruma_route(client_server::delete_devices_route)
.ruma_route(client_server::get_tags_route)
.ruma_route(client_server::update_tag_route)
.ruma_route(client_server::delete_tag_route)
.ruma_route(client_server::upload_signing_keys_route)
.ruma_route(client_server::upload_signatures_route)
.ruma_route(client_server::get_key_changes_route)
.ruma_route(client_server::get_pushers_route)
.ruma_route(client_server::set_pushers_route)
// .ruma_route(client_server::third_party_route)
.ruma_route(client_server::upgrade_room_route)
.ruma_route(server_server::get_server_version_route)
.route(
"/_matrix/key/v2/server",
get(server_server::get_server_keys_route),
)
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),
)
.ruma_route(server_server::get_public_rooms_route)
.ruma_route(server_server::get_public_rooms_filtered_route)
.ruma_route(server_server::send_transaction_message_route)
.ruma_route(server_server::get_event_route)
.ruma_route(server_server::get_missing_events_route)
.ruma_route(server_server::get_event_authorization_route)
.ruma_route(server_server::get_room_state_route)
.ruma_route(server_server::get_room_state_ids_route)
.ruma_route(server_server::create_join_event_template_route)
.ruma_route(server_server::create_join_event_v1_route)
.ruma_route(server_server::create_join_event_v2_route)
.ruma_route(server_server::create_invite_route)
.ruma_route(server_server::get_devices_route)
.ruma_route(server_server::get_room_information_route)
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
}
#[catch(581)]
fn unknown_token_catcher() -> Result<()> {
Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false },
"Unknown token.",
))
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[catch(582)]
fn missing_token_catcher() -> Result<()> {
Err(Error::BadRequest(ErrorKind::MissingToken, "Missing token."))
}
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[catch(583)]
fn bad_json_catcher() -> Result<()> {
Err(Error::BadRequest(ErrorKind::BadJson, "Bad json."))
}
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
fn default_config() -> rocket::Config {
use rocket::config::{LogLevel, Shutdown, Sig};
rocket::Config {
// Disable rocket's logging to get only tracing-subscriber's log output
log_level: LogLevel::Off,
shutdown: Shutdown {
// Once shutdown is triggered, this is the amount of seconds before rocket
// will forcefully start shutting down connections, this gives enough time to /sync
// requests and the like (which havent gotten the memo, somehow) to still complete gracefully.
grace: 35,
// After the grace period, rocket starts shutting down connections, and waits at least this
// many seconds before forcefully shutting all of them down.
mercy: 10,
#[cfg(unix)]
signals: hashset![Sig::Term, Sig::Int],
..Shutdown::default()
},
..rocket::Config::release_default()
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
}
trait RouterExt {
fn ruma_route<H, T>(self, handler: H) -> Self
where
H: RumaHandler<T>,
T: 'static;
}
impl RouterExt for Router {
fn ruma_route<H, T>(self, handler: H) -> Self
where
H: RumaHandler<T>,
T: 'static,
{
let meta = H::METADATA;
let method_filter = match meta.method {
Method::DELETE => MethodFilter::DELETE,
Method::GET => MethodFilter::GET,
Method::HEAD => MethodFilter::HEAD,
Method::OPTIONS => MethodFilter::OPTIONS,
Method::PATCH => MethodFilter::PATCH,
Method::POST => MethodFilter::POST,
Method::PUT => MethodFilter::PUT,
Method::TRACE => MethodFilter::TRACE,
_ => panic!(""),
};
self.route(meta.path, on(method_filter, handler))
}
}
pub trait RumaHandler<T>: Handler<T> {
const METADATA: Metadata;
}
macro_rules! impl_ruma_handler {
( $($ty:ident),* $(,)? ) => {
#[axum::async_trait]
#[allow(non_snake_case)]
impl<Req, F, Fut, $($ty,)*> RumaHandler<($($ty,)* Ruma<Req>,)> for F
where
Req: Outgoing,
Req::Incoming: IncomingRequest + Send,
F: FnOnce($($ty,)* Ruma<Req>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = ConduitResult<
<Req::Incoming as IncomingRequest>::OutgoingResponse
>> + Send,
$( $ty: FromRequest<axum::body::Body> + Send, )*
{
const METADATA: Metadata = Req::Incoming::METADATA;
}
};
}
impl_ruma_handler!();
impl_ruma_handler!(T1);
impl_ruma_handler!(T1, T2);
impl_ruma_handler!(T1, T2, T3);
impl_ruma_handler!(T1, T2, T3, T4);
impl_ruma_handler!(T1, T2, T3, T4, T5);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);