diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index 171e9bbe..ced2bde5 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -1,76 +1,76 @@ -pub(crate) mod account; -pub(crate) mod alias; -pub(crate) mod backup; -pub(crate) mod capabilities; -pub(crate) mod config; -pub(crate) mod context; -pub(crate) mod device; -pub(crate) mod directory; -pub(crate) mod filter; -pub(crate) mod keys; -pub(crate) mod media; -pub(crate) mod membership; -pub(crate) mod message; -pub(crate) mod presence; -pub(crate) mod profile; -pub(crate) mod push; -pub(crate) mod read_marker; -pub(crate) mod redact; -pub(crate) mod relations; -pub(crate) mod report; -pub(crate) mod room; -pub(crate) mod search; -pub(crate) mod session; -pub(crate) mod space; -pub(crate) mod state; -pub(crate) mod sync; -pub(crate) mod tag; -pub(crate) mod thirdparty; -pub(crate) mod threads; -pub(crate) mod to_device; -pub(crate) mod typing; -pub(crate) mod unstable; -pub(crate) mod unversioned; -pub(crate) mod user_directory; -pub(crate) mod voip; +pub(super) mod account; +pub(super) mod alias; +pub(super) mod backup; +pub(super) mod capabilities; +pub(super) mod config; +pub(super) mod context; +pub(super) mod device; +pub(super) mod directory; +pub(super) mod filter; +pub(super) mod keys; +pub(super) mod media; +pub(super) mod membership; +pub(super) mod message; +pub(super) mod presence; +pub(super) mod profile; +pub(super) mod push; +pub(super) mod read_marker; +pub(super) mod redact; +pub(super) mod relations; +pub(super) mod report; +pub(super) mod room; +pub(super) mod search; +pub(super) mod session; +pub(super) mod space; +pub(super) mod state; +pub(super) mod sync; +pub(super) mod tag; +pub(super) mod thirdparty; +pub(super) mod threads; +pub(super) mod to_device; +pub(super) mod typing; +pub(super) mod unstable; +pub(super) mod unversioned; +pub(super) mod user_directory; +pub(super) mod voip; -pub(crate) use account::*; +pub(super) use account::*; pub use alias::get_alias_helper; -pub(crate) use alias::*; -pub(crate) use backup::*; -pub(crate) use capabilities::*; -pub(crate) use config::*; -pub(crate) use context::*; -pub(crate) use device::*; -pub(crate) use directory::*; -pub(crate) use filter::*; -pub(crate) use keys::*; -pub(crate) use media::*; -pub(crate) use membership::*; +pub(super) use alias::*; +pub(super) use backup::*; +pub(super) use capabilities::*; +pub(super) use config::*; +pub(super) use context::*; +pub(super) use device::*; +pub(super) use directory::*; +pub(super) use filter::*; +pub(super) use keys::*; +pub(super) use media::*; +pub(super) use membership::*; pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room}; -pub(crate) use message::*; -pub(crate) use presence::*; -pub(crate) use profile::*; -pub(crate) use push::*; -pub(crate) use read_marker::*; -pub(crate) use redact::*; -pub(crate) use relations::*; -pub(crate) use report::*; -pub(crate) use room::*; -pub(crate) use search::*; -pub(crate) use session::*; -pub(crate) use space::*; -pub(crate) use state::*; -pub(crate) use sync::*; -pub(crate) use tag::*; -pub(crate) use thirdparty::*; -pub(crate) use threads::*; -pub(crate) use to_device::*; -pub(crate) use typing::*; -pub(crate) use unstable::*; -pub(crate) use unversioned::*; -pub(crate) use user_directory::*; -pub(crate) use voip::*; +pub(super) use message::*; +pub(super) use presence::*; +pub(super) use profile::*; +pub(super) use push::*; +pub(super) use read_marker::*; +pub(super) use redact::*; +pub(super) use relations::*; +pub(super) use report::*; +pub(super) use room::*; +pub(super) use search::*; +pub(super) use session::*; +pub(super) use space::*; +pub(super) use state::*; +pub(super) use sync::*; +pub(super) use tag::*; +pub(super) use thirdparty::*; +pub(super) use threads::*; +pub(super) use to_device::*; +pub(super) use typing::*; +pub(super) use unstable::*; +pub(super) use unversioned::*; +pub(super) use user_directory::*; +pub(super) use voip::*; /// generated device ID length const DEVICE_ID_LENGTH: usize = 10; diff --git a/src/api/mod.rs b/src/api/mod.rs index 956bcdf7..b835f536 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -7,8 +7,8 @@ extern crate conduit_core as conduit; extern crate conduit_service as service; pub use client_server::membership::{join_room_by_id_helper, leave_all_rooms}; -pub(crate) use conduit::{debug_error, debug_info, debug_warn, error::RumaResponse, utils, Error, Result}; -pub(crate) use ruma_wrapper::Ruma; +pub(crate) use conduit::{debug_error, debug_info, debug_warn, utils, Error, Result}; +pub(crate) use ruma_wrapper::{Ruma, RumaResponse}; pub(crate) use service::{pdu::PduEvent, services, user_is_local}; conduit::mod_ctor! {} diff --git a/src/api/router.rs b/src/api/router.rs index 6081a089..90d69873 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -1,15 +1,13 @@ -use std::future::Future; - use axum::{ response::IntoResponse, - routing::{any, get, on, post, MethodFilter}, + routing::{any, get, post}, Router, }; -use conduit::{Error, Result, Server}; -use http::{Method, Uri}; -use ruma::api::{client::error::ErrorKind, IncomingRequest}; +use conduit::{Error, Server}; +use http::Uri; +use ruma::api::client::error::ErrorKind; -use crate::{client_server, server_server, Ruma, RumaResponse}; +use crate::{client_server, ruma_wrapper::RouterExt, server_server}; pub fn build(router: Router, server: &Server) -> Router { let config = &server.config; @@ -234,66 +232,3 @@ async fn initial_sync(_uri: Uri) -> impl IntoResponse { } async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") } - -trait RouterExt { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static; -} - -impl RouterExt for Router { - #[inline(always)] - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static, - { - handler.add_routes(self) - } -} - -trait RumaHandler { - fn add_routes(&self, router: Router) -> Router; - - fn add_route(&self, router: Router, path: &str) -> Router; -} - -impl RumaHandler> for F -where - Req: IncomingRequest + Send + 'static, - F: FnOnce(Ruma) -> Fut + Clone + Send + Sync + 'static, - Fut: Future> + Send, - E: IntoResponse, -{ - #[inline(always)] - fn add_routes(&self, router: Router) -> Router { - Req::METADATA - .history - .all_paths() - .fold(router, |router, path| self.add_route(router, path)) - } - - #[inline(always)] - fn add_route(&self, router: Router, path: &str) -> Router { - let handle = self.clone(); - let method = method_to_filter(Req::METADATA.method); - let action = |req| async { handle(req).await.map(RumaResponse) }; - router.route(path, on(method, action)) - } -} - -#[inline] -fn method_to_filter(method: Method) -> MethodFilter { - match method { - Method::DELETE => MethodFilter::DELETE, - Method::GET => MethodFilter::GET, - Method::HEAD => MethodFilter::HEAD, - Method::OPTIONS => MethodFilter::OPTIONS, - Method::PATCH => MethodFilter::PATCH, - Method::POST => MethodFilter::POST, - Method::PUT => MethodFilter::PUT, - Method::TRACE => MethodFilter::TRACE, - m => panic!("Unsupported HTTP method: {m:?}"), - } -} diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index 1e12995f..a130ddd9 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -1,11 +1,14 @@ mod auth; mod request; +mod router; mod xmatrix; use std::ops::Deref; +pub(super) use conduit::error::RumaResponse; use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId}; +pub(super) use self::router::RouterExt; use crate::service::appservice::RegistrationInfo; /// Extractor for Ruma request structs diff --git a/src/api/ruma_wrapper/router.rs b/src/api/ruma_wrapper/router.rs new file mode 100644 index 00000000..f769d2bb --- /dev/null +++ b/src/api/ruma_wrapper/router.rs @@ -0,0 +1,69 @@ +use std::future::Future; + +use axum::{ + response::IntoResponse, + routing::{on, MethodFilter}, + Router, +}; +use conduit::Result; +use http::Method; +use ruma::api::IncomingRequest; + +use super::{Ruma, RumaResponse}; + +pub(in super::super) trait RouterExt { + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler; +} + +impl RouterExt for Router { + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + { + handler.add_routes(self) + } +} + +pub(in super::super) trait RumaHandler { + fn add_routes(&self, router: Router) -> Router; + + fn add_route(&self, router: Router, path: &str) -> Router; +} + +impl RumaHandler> for F +where + Req: IncomingRequest + Send + 'static, + F: FnOnce(Ruma) -> Fut + Clone + Send + Sync + 'static, + Fut: Future> + Send, + E: IntoResponse, +{ + fn add_routes(&self, router: Router) -> Router { + Req::METADATA + .history + .all_paths() + .fold(router, |router, path| self.add_route(router, path)) + } + + fn add_route(&self, router: Router, path: &str) -> Router { + let handle = self.clone(); + let method = method_to_filter(&Req::METADATA.method); + let action = |req| async { handle(req).await.map(RumaResponse) }; + router.route(path, on(method, action)) + } +} + +const fn method_to_filter(method: &Method) -> MethodFilter { + match *method { + Method::DELETE => MethodFilter::DELETE, + Method::GET => MethodFilter::GET, + Method::HEAD => MethodFilter::HEAD, + Method::OPTIONS => MethodFilter::OPTIONS, + Method::PATCH => MethodFilter::PATCH, + Method::POST => MethodFilter::POST, + Method::PUT => MethodFilter::PUT, + Method::TRACE => MethodFilter::TRACE, + _ => panic!("Unsupported HTTP method"), + } +}