Port from Rocket to axum

This commit is contained in:
Jonas Platte 2022-01-20 11:51:31 +01:00
parent 8709c3ae7b
commit 1f7b3fa4ac
No known key found for this signature in database
GPG key ID: 7D261D771D915378
52 changed files with 1064 additions and 1885 deletions

View file

@ -4,13 +4,11 @@ use crate::{
pdu::EventHash,
utils, ConduitResult, Database, Error, PduEvent, Result, Ruma,
};
use axum::{response::IntoResponse, Json};
use futures_util::{stream::FuturesUnordered, StreamExt};
use get_profile_information::v1::ProfileField;
use http::header::{HeaderValue, AUTHORIZATION};
use regex::Regex;
use rocket::{
futures::{prelude::*, stream::FuturesUnordered},
response::content::Json,
};
use ruma::{
api::{
client::error::{Error as RumaError, ErrorKind},
@ -72,9 +70,6 @@ use std::{
use tokio::sync::{MutexGuard, Semaphore};
use tracing::{debug, error, info, trace, warn};
#[cfg(feature = "conduit_bin")]
use rocket::{get, post, put};
/// Wraps either an literal IP address plus port, or a hostname plus complement
/// (colon-plus-port if it was specified).
///
@ -495,10 +490,10 @@ async fn request_well_known(
/// # `GET /_matrix/federation/v1/version`
///
/// Get version information on this server.
#[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))]
#[tracing::instrument(skip(db))]
pub fn get_server_version_route(
#[tracing::instrument(skip(db, _body))]
pub async fn get_server_version_route(
db: DatabaseGuard,
_body: Ruma<get_server_version::v1::Request>,
) -> ConduitResult<get_server_version::v1::Response> {
if !db.globals.allow_federation() {
return Err(Error::bad_config("Federation is disabled."));
@ -520,12 +515,11 @@ pub fn get_server_version_route(
/// - Matrix does not support invalidating public keys, so the key returned by this will be valid
/// forever.
// Response type for this endpoint is Json because we need to calculate a signature for the response
#[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))]
#[tracing::instrument(skip(db))]
pub fn get_server_keys_route(db: DatabaseGuard) -> Json<String> {
pub async fn get_server_keys_route(db: DatabaseGuard) -> impl IntoResponse {
if !db.globals.allow_federation() {
// TODO: Use proper types
return Json("Federation is disabled.".to_owned());
return Json("Federation is disabled.").into_response();
}
let mut verify_keys: BTreeMap<Box<ServerSigningKeyId>, VerifyKey> = BTreeMap::new();
@ -563,7 +557,7 @@ pub fn get_server_keys_route(db: DatabaseGuard) -> Json<String> {
)
.unwrap();
Json(serde_json::to_string(&response).expect("JSON is canonical"))
Json(response).into_response()
}
/// # `GET /_matrix/key/v2/server/{keyId}`
@ -572,19 +566,14 @@ pub fn get_server_keys_route(db: DatabaseGuard) -> Json<String> {
///
/// - Matrix does not support invalidating public keys, so the key returned by this will be valid
/// forever.
#[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))]
#[tracing::instrument(skip(db))]
pub fn get_server_keys_deprecated_route(db: DatabaseGuard) -> Json<String> {
get_server_keys_route(db)
pub async fn get_server_keys_deprecated_route(db: DatabaseGuard) -> impl IntoResponse {
get_server_keys_route(db).await
}
/// # `POST /_matrix/federation/v1/publicRooms`
///
/// Lists the public rooms on this server.
#[cfg_attr(
feature = "conduit_bin",
post("/_matrix/federation/v1/publicRooms", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn get_public_rooms_filtered_route(
db: DatabaseGuard,
@ -628,10 +617,6 @@ pub async fn get_public_rooms_filtered_route(
/// # `GET /_matrix/federation/v1/publicRooms`
///
/// Lists the public rooms on this server.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/publicRooms", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn get_public_rooms_route(
db: DatabaseGuard,
@ -675,10 +660,6 @@ pub async fn get_public_rooms_route(
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.
#[cfg_attr(
feature = "conduit_bin",
put("/_matrix/federation/v1/send/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn send_transaction_message_route(
db: DatabaseGuard,
@ -2309,12 +2290,8 @@ fn get_auth_chain_inner(
/// Retrieves a single event from the server.
///
/// - Only works if a user of this server is currently invited or joined the room
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/event/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_event_route(
pub async fn get_event_route(
db: DatabaseGuard,
body: Ruma<get_event::v1::Request<'_>>,
) -> ConduitResult<get_event::v1::Response> {
@ -2358,12 +2335,8 @@ pub fn get_event_route(
/// # `POST /_matrix/federation/v1/get_missing_events/{roomId}`
///
/// Retrieves events that the sender is missing.
#[cfg_attr(
feature = "conduit_bin",
post("/_matrix/federation/v1/get_missing_events/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_missing_events_route(
pub async fn get_missing_events_route(
db: DatabaseGuard,
body: Ruma<get_missing_events::v1::Request<'_>>,
) -> ConduitResult<get_missing_events::v1::Response> {
@ -2436,12 +2409,8 @@ pub fn get_missing_events_route(
/// Retrieves the auth chain for a given event.
///
/// - This does not include the event itself
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/event_auth/<_>/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_event_authorization_route(
pub async fn get_event_authorization_route(
db: DatabaseGuard,
body: Ruma<get_event_authorization::v1::Request<'_>>,
) -> ConduitResult<get_event_authorization::v1::Response> {
@ -2490,12 +2459,8 @@ pub fn get_event_authorization_route(
/// # `GET /_matrix/federation/v1/state/{roomId}`
///
/// Retrieves the current state of the room.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/state/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_room_state_route(
pub async fn get_room_state_route(
db: DatabaseGuard,
body: Ruma<get_room_state::v1::Request<'_>>,
) -> ConduitResult<get_room_state::v1::Response> {
@ -2555,12 +2520,8 @@ pub fn get_room_state_route(
/// # `GET /_matrix/federation/v1/state_ids/{roomId}`
///
/// Retrieves the current state of the room.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/state_ids/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_room_state_ids_route(
pub async fn get_room_state_ids_route(
db: DatabaseGuard,
body: Ruma<get_room_state_ids::v1::Request<'_>>,
) -> ConduitResult<get_room_state_ids::v1::Response> {
@ -2609,12 +2570,8 @@ pub fn get_room_state_ids_route(
/// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}`
///
/// Creates a join template.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/make_join/<_>/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn create_join_event_template_route(
pub async fn create_join_event_template_route(
db: DatabaseGuard,
body: Ruma<create_join_event_template::v1::Request<'_>>,
) -> ConduitResult<create_join_event_template::v1::Response> {
@ -2895,10 +2852,6 @@ async fn create_join_event(
/// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}`
///
/// Submits a signed join event.
#[cfg_attr(
feature = "conduit_bin",
put("/_matrix/federation/v1/send_join/<_>/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn create_join_event_v1_route(
db: DatabaseGuard,
@ -2917,10 +2870,6 @@ pub async fn create_join_event_v1_route(
/// # `PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}`
///
/// Submits a signed join event.
#[cfg_attr(
feature = "conduit_bin",
put("/_matrix/federation/v2/send_join/<_>/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn create_join_event_v2_route(
db: DatabaseGuard,
@ -2939,10 +2888,6 @@ pub async fn create_join_event_v2_route(
/// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}`
///
/// Invites a remote user to a room.
#[cfg_attr(
feature = "conduit_bin",
put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn create_invite_route(
db: DatabaseGuard,
@ -3055,12 +3000,8 @@ pub async fn create_invite_route(
/// # `GET /_matrix/federation/v1/user/devices/{userId}`
///
/// Gets information on all devices of the user.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/user/devices/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_devices_route(
pub async fn get_devices_route(
db: DatabaseGuard,
body: Ruma<get_devices::v1::Request<'_>>,
) -> ConduitResult<get_devices::v1::Response> {
@ -3098,12 +3039,8 @@ pub fn get_devices_route(
/// # `GET /_matrix/federation/v1/query/directory`
///
/// Resolve a room alias to a room id.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/query/directory", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_room_information_route(
pub async fn get_room_information_route(
db: DatabaseGuard,
body: Ruma<get_room_information::v1::Request<'_>>,
) -> ConduitResult<get_room_information::v1::Response> {
@ -3129,12 +3066,8 @@ pub fn get_room_information_route(
/// # `GET /_matrix/federation/v1/query/profile`
///
/// Gets information on a profile.
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/query/profile", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn get_profile_information_route(
pub async fn get_profile_information_route(
db: DatabaseGuard,
body: Ruma<get_profile_information::v1::Request<'_>>,
) -> ConduitResult<get_profile_information::v1::Response> {
@ -3172,10 +3105,6 @@ pub fn get_profile_information_route(
/// # `POST /_matrix/federation/v1/user/keys/query`
///
/// Gets devices and identity keys for the given users.
#[cfg_attr(
feature = "conduit_bin",
post("/_matrix/federation/v1/user/keys/query", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn get_keys_route(
db: DatabaseGuard,
@ -3206,10 +3135,6 @@ pub async fn get_keys_route(
/// # `POST /_matrix/federation/v1/user/keys/claim`
///
/// Claims one-time keys.
#[cfg_attr(
feature = "conduit_bin",
post("/_matrix/federation/v1/user/keys/claim", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub async fn claim_keys_route(
db: DatabaseGuard,