From a43c78e55f2127cac83a8ebb525cf69eb692b809 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 8 Jul 2024 18:54:42 +0000 Subject: [PATCH] add RumaError to Error; encapsulate RumaResponse in api Signed-off-by: Jason Volk --- src/api/mod.rs | 3 +-- src/api/router/mod.rs | 3 +-- src/api/router/response.rs | 3 +-- src/core/error.rs | 55 +++++++++++++++++++++++++------------- src/router/request.rs | 16 +++++------ 5 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 0395da7a..6adf2d39 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,8 +9,7 @@ extern crate conduit_service as service; pub(crate) use conduit::{debug_info, debug_warn, utils, Error, Result}; pub(crate) use service::{pdu::PduEvent, services, user_is_local}; -pub(crate) use self::router::Ruma; -pub use self::router::RumaResponse; +pub(crate) use self::router::{Ruma, RumaResponse}; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/api/router/mod.rs b/src/api/router/mod.rs index f167a606..c3e08c5b 100644 --- a/src/api/router/mod.rs +++ b/src/api/router/mod.rs @@ -13,9 +13,8 @@ use ruma::{ CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; -pub(super) use self::handler::RouterExt; -pub use self::response::RumaResponse; use self::{auth::Auth, request::Request}; +pub(super) use self::{handler::RouterExt, response::RumaResponse}; use crate::{service::appservice::RegistrationInfo, services, Error, Result}; /// Extractor for Ruma request structs diff --git a/src/api/router/response.rs b/src/api/router/response.rs index 38e58ba9..9b67f37b 100644 --- a/src/api/router/response.rs +++ b/src/api/router/response.rs @@ -5,8 +5,7 @@ use http::StatusCode; use http_body_util::Full; use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; -#[derive(Clone)] -pub struct RumaResponse(pub T); +pub(crate) struct RumaResponse(pub(crate) T); impl From for RumaResponse { fn from(t: Error) -> Self { Self(t.into()) } diff --git a/src/core/error.rs b/src/core/error.rs index cb74d531..5f4d4798 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -4,17 +4,13 @@ use bytes::BytesMut; use http::StatusCode; use http_body_util::Full; use ruma::{ - api::{ - client::uiaa::{UiaaInfo, UiaaResponse}, - OutgoingResponse, - }, + api::{client::uiaa::UiaaResponse, OutgoingResponse}, OwnedServerName, }; -use thiserror::Error; use crate::{debug_error, error}; -#[derive(Error)] +#[derive(thiserror::Error)] pub enum Error { // std #[error("{0}")] @@ -47,10 +43,16 @@ pub enum Error { Extension(#[from] axum::extract::rejection::ExtensionRejection), #[error("{0}")] Path(#[from] axum::extract::rejection::PathRejection), + #[error("{0}")] + Http(#[from] http::Error), // ruma + #[error("{0}")] + IntoHttpError(#[from] ruma::api::error::IntoHttpError), + #[error("{0}")] + RumaError(#[from] ruma::api::client::error::Error), #[error("uiaa")] - Uiaa(UiaaInfo), + Uiaa(ruma::api::client::uiaa::UiaaInfo), #[error("{0}")] Mxid(#[from] ruma::IdParseError), #[error("{0}: {1}")] @@ -98,7 +100,7 @@ impl Error { use ruma::api::client::error::ErrorKind::Unknown; match self { - Self::Federation(_, err) => err.error_kind().unwrap_or(&Unknown).clone(), + Self::Federation(_, error) => ruma_error_kind(error).clone(), Self::BadRequest(kind, _) => kind.clone(), _ => Unknown, } @@ -134,37 +136,35 @@ impl axum::response::IntoResponse for Error { impl From for UiaaResponse { fn from(error: Error) -> Self { - use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind::Unknown}; - if let Error::Uiaa(uiaainfo) = error { return Self::AuthResponse(uiaainfo); } let kind = match &error { - Error::Federation(_, ref error) => error.error_kind().unwrap_or(&Unknown), + Error::Federation(_, ref error) | Error::RumaError(ref error) => ruma_error_kind(error), Error::BadRequest(kind, _) => kind, - _ => &Unknown, + _ => &ruma::api::client::error::ErrorKind::Unknown, }; let status_code = match &error { - Error::Federation(_, ref error) => error.status_code, + Error::Federation(_, ref error) | Error::RumaError(ref error) => error.status_code, Error::BadRequest(ref kind, _) => bad_request_code(kind), Error::Conflict(_) => StatusCode::CONFLICT, _ => StatusCode::INTERNAL_SERVER_ERROR, }; - let message = if let Error::Federation(ref origin, ref error) = &error { - format!("Answer from {origin}: {error}") - } else { - format!("{error}") + let message = match &error { + Error::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), + Error::RumaError(ref error) => ruma_error_message(error), + _ => format!("{error}"), }; - let body = ErrorBody::Standard { + let body = ruma::api::client::error::ErrorBody::Standard { kind: kind.clone(), message, }; - Self::MatrixError(RumaError { + Self::MatrixError(ruma::api::client::error::Error { status_code, body, }) @@ -204,6 +204,23 @@ fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { } } +fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { + if let ruma::api::client::error::ErrorBody::Standard { + message, + .. + } = &error.body + { + return message.to_string(); + } + + format!("{error}") +} + +fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { + e.error_kind() + .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) +} + #[inline] pub fn log(e: Error) { error!("{e}"); diff --git a/src/router/request.rs b/src/router/request.rs index 567de81f..9256fb9c 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -1,13 +1,9 @@ use std::sync::{atomic::Ordering, Arc}; use axum::{extract::State, response::IntoResponse}; -use conduit::{debug, debug_error, debug_warn, defer, error, trace, Result, Server}; -use conduit_api::RumaResponse; +use conduit::{debug, debug_error, debug_warn, defer, error, trace, Error, Result, Server}; use http::{Method, StatusCode, Uri}; -use ruma::api::client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, -}; +use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind}; #[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn spawn( @@ -66,15 +62,15 @@ fn handle_result( ) -> Result { handle_result_log(method, uri, &result); match result.status() { - StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result), + StatusCode::METHOD_NOT_ALLOWED => handle_result_405(method, uri, &result), _ => Ok(result), } } -fn handle_result_403( +fn handle_result_405( _method: &Method, _uri: &Uri, result: &axum::response::Response, ) -> Result { - let error = UiaaResponse::MatrixError(RumaError { + let error = Error::RumaError(RumaError { status_code: result.status(), body: ErrorBody::Standard { kind: ErrorKind::Unrecognized, @@ -82,7 +78,7 @@ fn handle_result_403( }, }); - Ok(RumaResponse(error).into_response()) + Ok(error.into_response()) } fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {