diff --git a/src/api/mod.rs b/src/api/mod.rs index 7fe02cfe..956bcdf7 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,6 +1,6 @@ pub mod client_server; pub mod router; -pub(crate) mod ruma_wrapper; +mod ruma_wrapper; pub mod server_server; extern crate conduit_core as conduit; diff --git a/src/api/ruma_wrapper/auth.rs b/src/api/ruma_wrapper/auth.rs new file mode 100644 index 00000000..282248c1 --- /dev/null +++ b/src/api/ruma_wrapper/auth.rs @@ -0,0 +1,252 @@ +use std::collections::BTreeMap; + +use axum::RequestPartsExt; +use axum_extra::{headers::Authorization, typed_header::TypedHeaderRejectionReason, TypedHeader}; +use http::uri::PathAndQuery; +use ruma::{ + api::{client::error::ErrorKind, AuthScheme, IncomingRequest}, + CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, +}; +use tracing::warn; + +use super::{request::Request, xmatrix::XMatrix}; +use crate::{service::appservice::RegistrationInfo, services, Error, Result}; + +enum Token { + Appservice(Box), + User((OwnedUserId, OwnedDeviceId)), + Invalid, + None, +} + +pub(super) struct Auth { + pub(super) sender_user: Option, + pub(super) sender_device: Option, + pub(super) origin: Option, + pub(super) appservice_info: Option, +} + +pub(super) async fn auth(request: &mut Request) -> Result +where + T: IncomingRequest, +{ + let metadata = T::METADATA; + let token = match &request.auth { + Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), + None => request.query.access_token.as_deref(), + }; + + let token = if let Some(token) = token { + if let Some(reg_info) = services().appservice.find_from_token(token).await { + Token::Appservice(Box::new(reg_info)) + } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { + Token::User((user_id, OwnedDeviceId::from(device_id))) + } else { + Token::Invalid + } + } else { + Token::None + }; + + if metadata.authentication == AuthScheme::None { + match request.parts.uri.path() { + // TODO: can we check this better? + "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { + if !services() + .globals + .config + .allow_public_room_directory_without_auth + { + match token { + Token::Appservice(_) | Token::User(_) => { + // we should have validated the token above + // already + }, + Token::None | Token::Invalid => { + return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing or invalid access token.")); + }, + } + } + }, + _ => {}, + }; + } + + match (metadata.authentication, token) { + (_, Token::Invalid) => Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )), + (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(request, info)?), + (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { + Ok(Auth { + sender_user: None, + sender_device: None, + origin: None, + appservice_info: Some(*info), + }) + }, + (AuthScheme::AccessToken, Token::None) => { + Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")) + }, + ( + AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, + Token::User((user_id, device_id)), + ) => Ok(Auth { + sender_user: Some(user_id), + sender_device: Some(device_id), + origin: None, + appservice_info: None, + }), + (AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(request).await?), + (AuthScheme::None | AuthScheme::AppserviceToken | AuthScheme::AccessTokenOptional, Token::None) => Ok(Auth { + sender_user: None, + sender_device: None, + origin: None, + appservice_info: None, + }), + (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => Err(Error::BadRequest( + ErrorKind::Unauthorized, + "Only server signatures should be used on this endpoint.", + )), + (AuthScheme::AppserviceToken, Token::User(_)) => Err(Error::BadRequest( + ErrorKind::Unauthorized, + "Only appservice access tokens should be used on this endpoint.", + )), + } +} + +fn auth_appservice(request: &mut Request, info: Box) -> Result { + let user_id = request + .query + .user_id + .clone() + .map_or_else( + || { + UserId::parse_with_server_name( + info.registration.sender_localpart.as_str(), + services().globals.server_name(), + ) + }, + UserId::parse, + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; + + if !info.is_user_match(&user_id) { + return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); + } + + if !services().users.exists(&user_id)? { + return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); + } + + Ok(Auth { + sender_user: Some(user_id), + sender_device: None, + origin: None, + appservice_info: Some(*info), + }) +} + +async fn auth_server(request: &mut Request) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let TypedHeader(Authorization(x_matrix)) = request + .parts + .extract::>>() + .await + .map_err(|e| { + warn!("Missing or invalid Authorization header: {e}"); + + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => "Missing Authorization header.", + TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", + _ => "Unknown header-related error", + }; + + Error::BadRequest(ErrorKind::forbidden(), msg) + })?; + + let origin_signatures = BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]); + + let signatures = BTreeMap::from_iter([( + x_matrix.origin.as_str().to_owned(), + CanonicalJsonValue::Object(origin_signatures), + )]); + + let server_destination = services().globals.server_name().as_str().to_owned(); + + if let Some(destination) = x_matrix.destination.as_ref() { + if destination != &server_destination { + return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); + } + } + + let signature_uri = CanonicalJsonValue::String( + request + .parts + .uri + .path_and_query() + .unwrap_or(&PathAndQuery::from_static("/")) + .to_string(), + ); + + let mut request_map = BTreeMap::from_iter([ + ( + "method".to_owned(), + CanonicalJsonValue::String(request.parts.method.to_string()), + ), + ("uri".to_owned(), signature_uri), + ( + "origin".to_owned(), + CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), + ), + ("destination".to_owned(), CanonicalJsonValue::String(server_destination)), + ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)), + ]); + + if let Some(json_body) = &request.json { + request_map.insert("content".to_owned(), json_body.clone()); + }; + + let keys_result = services() + .rooms + .event_handler + .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()]) + .await; + + let keys = keys_result.map_err(|e| { + warn!("Failed to fetch signing keys: {e}"); + Error::BadRequest(ErrorKind::forbidden(), "Failed to fetch signing keys.") + })?; + + let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); + + match ruma::signatures::verify_json(&pub_key_map, &request_map) { + Ok(()) => Ok(Auth { + sender_user: None, + sender_device: None, + origin: Some(x_matrix.origin), + appservice_info: None, + }), + Err(e) => { + warn!("Failed to verify json request from {}: {e}\n{request_map:?}", x_matrix.origin); + + if request.parts.uri.to_string().contains('@') { + warn!( + "Request uri contained '@' character. Make sure your reverse proxy gives Conduit the raw uri \ + (apache: use nocanon)" + ); + } + + Err(Error::BadRequest( + ErrorKind::forbidden(), + "Failed to verify X-Matrix signatures.", + )) + }, + } +} diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs deleted file mode 100644 index 9bf99f63..00000000 --- a/src/api/ruma_wrapper/axum.rs +++ /dev/null @@ -1,330 +0,0 @@ -use std::{collections::BTreeMap, str}; - -use axum::{ - async_trait, - extract::{FromRequest, Path}, - RequestExt, RequestPartsExt, -}; -use axum_extra::{ - headers::{authorization::Bearer, Authorization}, - typed_header::TypedHeaderRejectionReason, - TypedHeader, -}; -use bytes::{BufMut, BytesMut}; -use conduit::debug_warn; -use http::uri::PathAndQuery; -use hyper::Request; -use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest}, - CanonicalJsonValue, OwnedDeviceId, OwnedUserId, UserId, -}; -use serde::Deserialize; -use tracing::{debug, error, trace, warn}; - -use super::{xmatrix::XMatrix, Ruma}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; - -enum Token { - Appservice(Box), - User((OwnedUserId, OwnedDeviceId)), - Invalid, - None, -} - -#[derive(Deserialize)] -struct QueryParams { - access_token: Option, - user_id: Option, -} - -#[async_trait] -impl FromRequest for Ruma -where - T: IncomingRequest, -{ - type Rejection = Error; - - #[allow(unused_qualifications)] // async traits - async fn from_request(req: Request, _state: &S) -> Result { - let limited = req.with_limited_body(); - let (mut parts, body) = limited.into_parts(); - let mut body = axum::body::to_bytes( - body, - services() - .globals - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - ) - .await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; - - let metadata = T::METADATA; - let auth_header: Option>> = parts.extract().await?; - let path_params: Path> = parts.extract().await?; - - let query = parts.uri.query().unwrap_or_default(); - let query_params: QueryParams = match serde_html_form::from_str(query) { - Ok(params) => params, - Err(e) => { - error!(%query, "Failed to deserialize query parameters: {e}"); - return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters")); - }, - }; - - let token = match &auth_header { - Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), - None => query_params.access_token.as_deref(), - }; - - let token = if let Some(token) = token { - if let Some(reg_info) = services().appservice.find_from_token(token).await { - Token::Appservice(Box::new(reg_info)) - } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { - Token::User((user_id, OwnedDeviceId::from(device_id))) - } else { - Token::Invalid - } - } else { - Token::None - }; - - if metadata.authentication == AuthScheme::None { - match parts.uri.path() { - // TODO: can we check this better? - "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { - if !services() - .globals - .config - .allow_public_room_directory_without_auth - { - match token { - Token::Appservice(_) | Token::User(_) => { - // we should have validated the token above - // already - }, - Token::None | Token::Invalid => { - return Err(Error::BadRequest( - ErrorKind::MissingToken, - "Missing or invalid access token.", - )); - }, - } - } - }, - _ => {}, - }; - } - - let mut json_body = serde_json::from_slice::(&body).ok(); - - let (sender_user, sender_device, origin, appservice_info) = match (metadata.authentication, token) { - (_, Token::Invalid) => { - return Err(Error::BadRequest( - ErrorKind::UnknownToken { - soft_logout: false, - }, - "Unknown access token.", - )) - }, - (AuthScheme::AccessToken, Token::Appservice(info)) => { - let user_id = query_params - .user_id - .map_or_else( - || { - UserId::parse_with_server_name( - info.registration.sender_localpart.as_str(), - services().globals.server_name(), - ) - }, - UserId::parse, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - - if !info.is_user_match(&user_id) { - return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); - } - - if !services().users.exists(&user_id)? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); - } - - (Some(user_id), None, None, Some(*info)) - }, - ( - AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, - Token::Appservice(info), - ) => (None, None, None, Some(*info)), - (AuthScheme::AccessToken, Token::None) => { - return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")); - }, - ( - AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, - Token::User((user_id, device_id)), - ) => (Some(user_id), Some(device_id), None, None), - (AuthScheme::ServerSignatures, Token::None) => { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let TypedHeader(Authorization(x_matrix)) = parts - .extract::>>() - .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {e}"); - - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => "Missing Authorization header.", - TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", - _ => "Unknown header-related error", - }; - - Error::BadRequest(ErrorKind::forbidden(), msg) - })?; - - let origin_signatures = - BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]); - - let signatures = BTreeMap::from_iter([( - x_matrix.origin.as_str().to_owned(), - CanonicalJsonValue::Object(origin_signatures), - )]); - - let server_destination = services().globals.server_name().as_str().to_owned(); - - if let Some(destination) = x_matrix.destination.as_ref() { - if destination != &server_destination { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); - } - } - - let signature_uri = CanonicalJsonValue::String( - parts - .uri - .path_and_query() - .unwrap_or(&PathAndQuery::from_static("/")) - .to_string(), - ); - - let mut request_map = BTreeMap::from_iter([ - ("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())), - ("uri".to_owned(), signature_uri), - ( - "origin".to_owned(), - CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), - ), - ("destination".to_owned(), CanonicalJsonValue::String(server_destination)), - ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)), - ]); - - if let Some(json_body) = &json_body { - request_map.insert("content".to_owned(), json_body.clone()); - }; - - let keys_result = services() - .rooms - .event_handler - .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()]) - .await; - - let keys = keys_result.map_err(|e| { - warn!("Failed to fetch signing keys: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Failed to fetch signing keys.") - })?; - - let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); - - match ruma::signatures::verify_json(&pub_key_map, &request_map) { - Ok(()) => (None, None, Some(x_matrix.origin), None), - Err(e) => { - warn!("Failed to verify json request from {}: {e}\n{request_map:?}", x_matrix.origin); - - if parts.uri.to_string().contains('@') { - warn!( - "Request uri contained '@' character. Make sure your reverse proxy gives Conduit the \ - raw uri (apache: use nocanon)" - ); - } - - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Failed to verify X-Matrix signatures.", - )); - }, - } - }, - (AuthScheme::None | AuthScheme::AppserviceToken | AuthScheme::AccessTokenOptional, Token::None) => { - (None, None, None, None) - }, - (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { - return Err(Error::BadRequest( - ErrorKind::Unauthorized, - "Only server signatures should be used on this endpoint.", - )); - }, - (AuthScheme::AppserviceToken, Token::User(_)) => { - return Err(Error::BadRequest( - ErrorKind::Unauthorized, - "Only appservice access tokens should be used on this endpoint.", - )); - }, - }; - - let mut http_request = Request::builder().uri(parts.uri).method(parts.method); - *http_request.headers_mut().unwrap() = parts.headers; - - if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { - let user_id = sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") - }); - - let uiaa_request = json_body - .get("auth") - .and_then(|auth| auth.as_object()) - .and_then(|auth| auth.get("session")) - .and_then(|session| session.as_str()) - .and_then(|session| { - services().uiaa.get_uiaa_request( - &user_id, - &sender_device.clone().unwrap_or_else(|| "".into()), - session, - ) - }); - - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { - for (key, value) in initial_request { - json_body.entry(key).or_insert(value); - } - } - - let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail"); - body = buf.into_inner().freeze(); - } - - let http_request = http_request.body(&*body).unwrap(); - debug!( - "{:?} {:?} {:?}", - http_request.method(), - http_request.uri(), - http_request.headers() - ); - - trace!("{:?} {:?} {:?}", http_request.method(), http_request.uri(), json_body); - let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { - warn!("try_from_http_request failed: {e:?}",); - debug_warn!("JSON body: {:?}", json_body); - Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") - })?; - - Ok(Ruma { - body, - sender_user, - sender_device, - origin, - json_body, - appservice_info, - }) - } -} diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index a2a3fe86..1e12995f 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -1,4 +1,5 @@ -pub(crate) mod axum; +mod auth; +mod request; mod xmatrix; use std::ops::Deref; diff --git a/src/api/ruma_wrapper/request.rs b/src/api/ruma_wrapper/request.rs new file mode 100644 index 00000000..4a60bbd2 --- /dev/null +++ b/src/api/ruma_wrapper/request.rs @@ -0,0 +1,149 @@ +use std::{mem, str}; + +use axum::{ + async_trait, + extract::{FromRequest, Path}, + RequestExt, RequestPartsExt, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + TypedHeader, +}; +use bytes::{BufMut, Bytes, BytesMut}; +use conduit::debug_warn; +use http::request::Parts; +use ruma::{ + api::{client::error::ErrorKind, IncomingRequest}, + CanonicalJsonValue, UserId, +}; +use serde::Deserialize; +use tracing::{debug, trace, warn}; + +use super::{auth, auth::Auth, Ruma}; +use crate::{services, Error, Result}; + +#[derive(Deserialize)] +pub(super) struct QueryParams { + pub(super) access_token: Option, + pub(super) user_id: Option, +} + +pub(super) struct Request { + pub(super) auth: Option>>, + pub(super) path: Path>, + pub(super) query: QueryParams, + pub(super) json: Option, + pub(super) body: Bytes, + pub(super) parts: Parts, +} + +#[async_trait] +impl FromRequest for Ruma +where + T: IncomingRequest, +{ + type Rejection = Error; + + async fn from_request(request: hyper::Request, _state: &S) -> Result { + let mut request: Request = extract(request).await?; + let auth: Auth = auth::auth::(&mut request).await?; + let body = make_body::(&mut request, &auth)?; + Ok(Ruma { + body, + sender_user: auth.sender_user, + sender_device: auth.sender_device, + origin: auth.origin, + json_body: request.json, + appservice_info: auth.appservice_info, + }) + } +} + +fn make_body(request: &mut Request, auth: &Auth) -> Result +where + T: IncomingRequest, +{ + let body = if let Some(CanonicalJsonValue::Object(json_body)) = &mut request.json { + let user_id = auth.sender_user.clone().unwrap_or_else(|| { + UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") + }); + + let uiaa_request = json_body + .get("auth") + .and_then(|auth| auth.as_object()) + .and_then(|auth| auth.get("session")) + .and_then(|session| session.as_str()) + .and_then(|session| { + services().uiaa.get_uiaa_request( + &user_id, + &auth.sender_device.clone().unwrap_or_else(|| "".into()), + session, + ) + }); + + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); + } + } + + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &request.json).expect("value serialization can't fail"); + buf.into_inner().freeze() + } else { + mem::take(&mut request.body) + }; + + let mut http_request = hyper::Request::builder() + .uri(request.parts.uri.clone()) + .method(request.parts.method.clone()); + *http_request.headers_mut().unwrap() = request.parts.headers.clone(); + let http_request = http_request.body(body).unwrap(); + debug!( + "{:?} {:?} {:?}", + http_request.method(), + http_request.uri(), + http_request.headers() + ); + + trace!("{:?} {:?} {:?}", http_request.method(), http_request.uri(), request.json); + let body = T::try_from_http_request(http_request, &request.path).map_err(|e| { + warn!("try_from_http_request failed: {e:?}",); + debug_warn!("JSON body: {:?}", request.json); + Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") + })?; + + Ok(body) +} + +async fn extract(request: hyper::Request) -> Result { + let limited = request.with_limited_body(); + let (mut parts, body) = limited.into_parts(); + + let auth = parts.extract().await?; + let path = parts.extract().await?; + let query = serde_html_form::from_str(parts.uri.query().unwrap_or_default()) + .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"))?; + + let max_body_size = services() + .globals + .config + .max_request_size + .try_into() + .expect("failed to convert max request size"); + + let body = axum::body::to_bytes(body, max_body_size) + .await + .map_err(|_| Error::BadRequest(ErrorKind::TooLarge, "Request body too large"))?; + + let json = serde_json::from_slice::(&body).ok(); + + Ok(Request { + auth, + path, + query, + json, + body, + parts, + }) +}