diff --git a/Cargo.lock b/Cargo.lock index 36da41e0..dd0c0436 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -637,6 +637,7 @@ dependencies = [ "futures-util", "hmac", "http 1.1.0", + "http-body-util", "hyper 1.4.0", "image", "ipaddress", diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 2f8bce54..45cae73d 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -44,6 +44,7 @@ conduit-service.workspace = true futures-util.workspace = true hmac.workspace = true http.workspace = true +http-body-util.workspace = true hyper.workspace = true image.workspace = true ipaddress.workspace = true diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index 77cac0fa..e39db94e 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -1,12 +1,12 @@ use axum_client_ip::InsecureClientIp; -use conduit::{warn, RumaResponse}; +use conduit::warn; use ruma::{ api::client::{error::ErrorKind, membership::mutual_rooms, room::get_summary}, events::room::member::MembershipState, OwnedRoomId, }; -use crate::{services, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma, RumaResponse}; /// # `GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms` /// diff --git a/src/api/mod.rs b/src/api/mod.rs index 8e30a518..0395da7a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,7 +9,8 @@ extern crate conduit_service as service; pub(crate) use conduit::{debug_info, debug_warn, utils, Error, Result}; pub(crate) use service::{pdu::PduEvent, services, user_is_local}; -pub(crate) use crate::router::{Ruma, RumaResponse}; +pub(crate) use self::router::Ruma; +pub use self::router::RumaResponse; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/api/router/mod.rs b/src/api/router/mod.rs index 2c439d65..f167a606 100644 --- a/src/api/router/mod.rs +++ b/src/api/router/mod.rs @@ -1,12 +1,12 @@ mod auth; mod handler; mod request; +mod response; use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; use bytes::{BufMut, BytesMut}; -pub(super) use conduit::error::RumaResponse; use conduit::{debug, debug_warn, trace, warn}; use ruma::{ api::{client::error::ErrorKind, IncomingRequest}, @@ -14,6 +14,7 @@ use ruma::{ }; pub(super) use self::handler::RouterExt; +pub use self::response::RumaResponse; use self::{auth::Auth, request::Request}; use crate::{service::appservice::RegistrationInfo, services, Error, Result}; diff --git a/src/api/router/response.rs b/src/api/router/response.rs new file mode 100644 index 00000000..38e58ba9 --- /dev/null +++ b/src/api/router/response.rs @@ -0,0 +1,22 @@ +use axum::response::{IntoResponse, Response}; +use bytes::BytesMut; +use conduit::Error; +use http::StatusCode; +use http_body_util::Full; +use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; + +#[derive(Clone)] +pub struct RumaResponse(pub T); + +impl From for RumaResponse { + fn from(t: Error) -> Self { Self(t.into()) } +} + +impl IntoResponse for RumaResponse { + fn into_response(self) -> Response { + match self.0.try_into_http_response::() { + Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), + Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } +} diff --git a/src/core/error.rs b/src/core/error.rs index 63729f31..cb74d531 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -1,19 +1,11 @@ use std::{convert::Infallible, fmt}; -use axum::response::{IntoResponse, Response}; use bytes::BytesMut; use http::StatusCode; use http_body_util::Full; use ruma::{ api::{ - client::{ - error::ErrorKind::{ - Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, ThreepidAuthFailed, - ThreepidDenied, TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, - WrongRoomKeysVersion, - }, - uiaa::{UiaaInfo, UiaaResponse}, - }, + client::uiaa::{UiaaInfo, UiaaResponse}, OutgoingResponse, }, OwnedServerName, @@ -57,6 +49,8 @@ pub enum Error { Path(#[from] axum::extract::rejection::PathRejection), // ruma + #[error("uiaa")] + Uiaa(UiaaInfo), #[error("{0}")] Mxid(#[from] ruma::IdParseError), #[error("{0}: {1}")] @@ -81,8 +75,6 @@ pub enum Error { BadServerResponse(&'static str), #[error("{0}")] Conflict(&'static str), // This is only needed for when a room alias already exists - #[error("uiaa")] - Uiaa(UiaaInfo), // unique / untyped #[error("{0}")] @@ -103,11 +95,10 @@ impl Error { /// Returns the Matrix error code / error kind #[inline] pub fn error_code(&self) -> ruma::api::client::error::ErrorKind { - if let Self::Federation(_, error) = self { - return error.error_kind().unwrap_or_else(|| &Unknown).clone(); - } + use ruma::api::client::error::ErrorKind::Unknown; match self { + Self::Federation(_, err) => err.error_kind().unwrap_or(&Unknown).clone(), Self::BadRequest(kind, _) => kind.clone(), _ => Unknown, } @@ -116,12 +107,8 @@ impl Error { /// Sanitizes public-facing errors that can leak sensitive information. pub fn sanitized_error(&self) -> String { match self { - Self::Database { - .. - } => String::from("Database error occurred."), - Self::Io { - .. - } => String::from("I/O error occurred."), + Self::Database(..) => String::from("Database error occurred."), + Self::Io(..) => String::from("I/O error occurred."), _ => self.to_string(), } } @@ -135,6 +122,88 @@ impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } } +impl axum::response::IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let response: UiaaResponse = self.into(); + response.try_into_http_response::().map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} + +impl From for UiaaResponse { + fn from(error: Error) -> Self { + use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind::Unknown}; + + if let Error::Uiaa(uiaainfo) = error { + return Self::AuthResponse(uiaainfo); + } + + let kind = match &error { + Error::Federation(_, ref error) => error.error_kind().unwrap_or(&Unknown), + Error::BadRequest(kind, _) => kind, + _ => &Unknown, + }; + + let status_code = match &error { + Error::Federation(_, ref error) => error.status_code, + Error::BadRequest(ref kind, _) => bad_request_code(kind), + Error::Conflict(_) => StatusCode::CONFLICT, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + let message = if let Error::Federation(ref origin, ref error) = &error { + format!("Answer from {origin}: {error}") + } else { + format!("{error}") + }; + + let body = ErrorBody::Standard { + kind: kind.clone(), + message, + }; + + Self::MatrixError(RumaError { + status_code, + body, + }) + } +} + +fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { + use ruma::api::client::error::ErrorKind::*; + + match kind { + GuestAccessForbidden + | ThreepidAuthFailed + | UserDeactivated + | ThreepidDenied + | WrongRoomKeysVersion { + .. + } + | Forbidden { + .. + } => StatusCode::FORBIDDEN, + + UnknownToken { + .. + } + | MissingToken + | Unauthorized => StatusCode::UNAUTHORIZED, + + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, + + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + + NotFound | Unrecognized => StatusCode::NOT_FOUND, + + _ => StatusCode::BAD_REQUEST, + } +} + #[inline] pub fn log(e: Error) { error!("{e}"); @@ -146,86 +215,3 @@ pub fn debug_log(e: Error) { debug_error!("{e}"); drop(e); } - -#[derive(Clone)] -pub struct RumaResponse(pub T); - -impl From for RumaResponse { - fn from(t: T) -> Self { Self(t) } -} - -impl From for RumaResponse { - fn from(t: Error) -> Self { t.to_response() } -} - -impl Error { - pub fn to_response(&self) -> RumaResponse { - use ruma::api::client::error::{Error as RumaError, ErrorBody}; - - if let Self::Uiaa(uiaainfo) = self { - return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); - } - - if let Self::Federation(origin, error) = self { - let mut error = error.clone(); - error.body = ErrorBody::Standard { - kind: error.error_kind().unwrap_or_else(|| &Unknown).clone(), - message: format!("Answer from {origin}: {error}"), - }; - return RumaResponse(UiaaResponse::MatrixError(error)); - } - - let message = format!("{self}"); - let (kind, status_code) = match self { - Self::BadRequest(kind, _) => ( - kind.clone(), - match kind { - WrongRoomKeysVersion { - .. - } - | Forbidden { - .. - } - | GuestAccessForbidden - | ThreepidAuthFailed - | UserDeactivated - | ThreepidDenied => StatusCode::FORBIDDEN, - Unauthorized - | UnknownToken { - .. - } - | MissingToken => StatusCode::UNAUTHORIZED, - NotFound | Unrecognized => StatusCode::NOT_FOUND, - LimitExceeded { - .. - } => StatusCode::TOO_MANY_REQUESTS, - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - _ => StatusCode::BAD_REQUEST, - }, - ), - Self::Conflict(_) => (Unknown, StatusCode::CONFLICT), - _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR), - }; - - RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { - kind, - message, - }, - status_code, - })) - } -} - -impl ::axum::response::IntoResponse for Error { - fn into_response(self) -> ::axum::response::Response { self.to_response().into_response() } -} - -impl IntoResponse for RumaResponse { - fn into_response(self) -> Response { - match self.0.try_into_http_response::() { - Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), - Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), - } - } -} diff --git a/src/core/mod.rs b/src/core/mod.rs index ec536ee2..de8057fa 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -10,7 +10,7 @@ pub mod utils; pub mod version; pub use config::Config; -pub use error::{Error, RumaResponse}; +pub use error::Error; pub use pdu::{PduBuilder, PduCount, PduEvent}; pub use server::Server; pub use version::version; diff --git a/src/router/request.rs b/src/router/request.rs index 079106ef..567de81f 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -1,13 +1,13 @@ use std::sync::{atomic::Ordering, Arc}; use axum::{extract::State, response::IntoResponse}; -use conduit::{debug_error, debug_warn, defer, Result, RumaResponse, Server}; +use conduit::{debug, debug_error, debug_warn, defer, error, trace, Result, Server}; +use conduit_api::RumaResponse; use http::{Method, StatusCode, Uri}; use ruma::api::client::{ error::{Error as RumaError, ErrorBody, ErrorKind}, uiaa::UiaaResponse, }; -use tracing::{debug, error, trace}; #[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn spawn(