Upgrade axum to 0.6

This commit is contained in:
Jonas Platte 2023-06-29 11:20:52 +02:00
parent 6a6f8e80f1
commit 0ded637b4a
No known key found for this signature in database
GPG key ID: CC154DE0E30B7C67
4 changed files with 130 additions and 72 deletions

View file

@ -3,18 +3,16 @@ use std::{collections::BTreeMap, iter::FromIterator, str};
use axum::{
async_trait,
body::{Full, HttpBody},
extract::{
rejection::TypedHeaderRejectionReason, FromRequest, Path, RequestParts, TypedHeader,
},
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
headers::{
authorization::{Bearer, Credentials},
Authorization,
},
response::{IntoResponse, Response},
BoxError,
BoxError, RequestExt, RequestPartsExt,
};
use bytes::{BufMut, Bytes, BytesMut};
use http::StatusCode;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use http::{Request, StatusCode};
use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
@ -26,27 +24,44 @@ use super::{Ruma, RumaResponse};
use crate::{services, Error, Result};
#[async_trait]
impl<T, B> FromRequest<B> for Ruma<T>
impl<T, S, B> FromRequest<S, B> for Ruma<T>
where
T: IncomingRequest,
B: HttpBody + Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Rejection = Error;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
#[derive(Deserialize)]
struct QueryParams {
access_token: Option<String>,
user_id: Option<String>,
}
let metadata = T::METADATA;
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
let path_params = Path::<Vec<String>>::from_request(req).await?;
let (mut parts, mut body) = match req.with_limited_body() {
Ok(limited_req) => {
let (parts, body) = limited_req.into_parts();
let body = to_bytes(body)
.await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
(parts, body)
}
Err(original_req) => {
let (parts, body) = original_req.into_parts();
let body = to_bytes(body)
.await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
(parts, body)
}
};
let query = req.uri().query().unwrap_or_default();
let metadata = T::METADATA;
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
let path_params: Path<Vec<String>> = parts.extract().await?;
let query = parts.uri.query().unwrap_or_default();
let query_params: QueryParams = match serde_html_form::from_str(query) {
Ok(params) => params,
Err(e) => {
@ -63,10 +78,6 @@ where
None => query_params.access_token.as_deref(),
};
let mut body = Bytes::from_request(req)
.await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let appservices = services().appservice.all().unwrap();
@ -138,24 +149,24 @@ where
}
}
AuthScheme::ServerSignatures => {
let TypedHeader(Authorization(x_matrix)) =
TypedHeader::<Authorization<XMatrix>>::from_request(req)
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
let TypedHeader(Authorization(x_matrix)) = parts
.extract::<TypedHeader<Authorization<XMatrix>>>()
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => {
"Missing Authorization header."
}
TypedHeaderRejectionReason::Error(_) => {
"Invalid X-Matrix signatures."
}
_ => "Unknown header-related error",
};
let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => {
"Missing Authorization header."
}
TypedHeaderRejectionReason::Error(_) => {
"Invalid X-Matrix signatures."
}
_ => "Unknown header-related error",
};
Error::BadRequest(ErrorKind::Forbidden, msg)
})?;
Error::BadRequest(ErrorKind::Forbidden, msg)
})?;
let origin_signatures = BTreeMap::from_iter([(
x_matrix.key.clone(),
@ -170,11 +181,11 @@ where
let mut request_map = BTreeMap::from_iter([
(
"method".to_owned(),
CanonicalJsonValue::String(req.method().to_string()),
CanonicalJsonValue::String(parts.method.to_string()),
),
(
"uri".to_owned(),
CanonicalJsonValue::String(req.uri().to_string()),
CanonicalJsonValue::String(parts.uri.to_string()),
),
(
"origin".to_owned(),
@ -224,7 +235,7 @@ where
x_matrix.origin, e, request_map
);
if req.uri().to_string().contains('@') {
if parts.uri.to_string().contains('@') {
warn!(
"Request uri contained '@' character. Make sure your \
reverse proxy gives Conduit the raw uri (apache: use \
@ -243,8 +254,8 @@ where
}
};
let mut http_request = http::Request::builder().uri(req.uri()).method(req.method());
*http_request.headers_mut().unwrap() = req.headers().clone();
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| {
@ -364,3 +375,55 @@ impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
}
}
}
// copied from hyper under the following license:
// Copyright (c) 2014-2021 Sean McArthur
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
where
T: HttpBody,
{
futures_util::pin_mut!(body);
// If there's only 1 chunk, we can just return Buf::to_bytes()
let mut first = if let Some(buf) = body.data().await {
buf?
} else {
return Ok(Bytes::new());
};
let second = if let Some(buf) = body.data().await {
buf?
} else {
return Ok(first.copy_to_bytes(first.remaining()));
};
// With more than 1 buf, we gotta flatten into a Vec first.
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
let mut vec = Vec::with_capacity(cap);
vec.put(first);
vec.put(second);
while let Some(buf) = body.data().await {
vec.put(buf?);
}
Ok(vec.into())
}