use std::sync::{atomic::Ordering, Arc}; use axum::{extract::State, response::IntoResponse}; use conduit::{debug_error, debug_warn, defer, Result, RumaResponse, Server}; use http::{Method, StatusCode, Uri}; use ruma::api::client::{ error::{Error as RumaError, ErrorBody, ErrorKind}, uiaa::UiaaResponse, }; use tracing::{debug, error, trace}; #[tracing::instrument(skip_all)] pub(crate) async fn spawn( State(server): State>, req: http::Request, next: axum::middleware::Next, ) -> Result { if server.stopping.load(Ordering::Relaxed) { debug_warn!("unavailable pending shutdown"); return Err(StatusCode::SERVICE_UNAVAILABLE); } let active = server.requests_spawn_active.fetch_add(1, Ordering::Relaxed); trace!(active, "enter"); defer! {{ let active = server.requests_spawn_active.fetch_sub(1, Ordering::Relaxed); let finished = server.requests_spawn_finished.fetch_add(1, Ordering::Relaxed); trace!(active, finished, "leave"); }}; let fut = next.run(req); let task = server.runtime().spawn(fut); task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } #[tracing::instrument(skip_all, name = "handle")] pub(crate) async fn handle( State(server): State>, req: http::Request, next: axum::middleware::Next, ) -> Result { if server.stopping.load(Ordering::Relaxed) { debug_warn!( method = %req.method(), uri = %req.uri(), "unavailable pending shutdown" ); return Err(StatusCode::SERVICE_UNAVAILABLE); } let active = server .requests_handle_active .fetch_add(1, Ordering::Relaxed); trace!(active, "enter"); defer! {{ let active = server.requests_handle_active.fetch_sub(1, Ordering::Relaxed); let finished = server.requests_handle_finished.fetch_add(1, Ordering::Relaxed); trace!(active, finished, "leave"); }}; let method = req.method().clone(); let uri = req.uri().clone(); let result = next.run(req).await; handle_result(&method, &uri, result) } fn handle_result( method: &Method, uri: &Uri, result: axum::response::Response, ) -> Result { handle_result_log(method, uri, &result); match result.status() { StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result), _ => Ok(result), } } #[allow(clippy::unnecessary_wraps)] fn handle_result_403( _method: &Method, _uri: &Uri, result: &axum::response::Response, ) -> Result { let error = UiaaResponse::MatrixError(RumaError { status_code: result.status(), body: ErrorBody::Standard { kind: ErrorKind::Unrecognized, message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), }, }); Ok(RumaResponse(error).into_response()) } fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { let status = result.status(); let reason = status.canonical_reason().unwrap_or("Unknown Reason"); let code = status.as_u16(); if status.is_server_error() { error!(method = ?method, uri = ?uri, "{code} {reason}"); } else if status.is_client_error() { debug_error!(method = ?method, uri = ?uri, "{code} {reason}"); } else if status.is_redirection() { debug!(method = ?method, uri = ?uri, "{code} {reason}"); } else { trace!(method = ?method, uri = ?uri, "{code} {reason}"); } }