move non-generic code out of generic; reduce codegen
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
999d731a65
commit
86694f2d1d
2 changed files with 90 additions and 80 deletions
|
@ -66,6 +66,15 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> Deref for Args<T>
|
||||||
|
where
|
||||||
|
T: IncomingRequest + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type Target = T;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target { &self.body }
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T> FromRequest<State, Body> for Args<T>
|
impl<T> FromRequest<State, Body> for Args<T>
|
||||||
where
|
where
|
||||||
|
@ -78,7 +87,7 @@ where
|
||||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
|
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
|
||||||
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
|
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
body: make_body::<T>(services, &mut request, &mut json_body, &auth)?,
|
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
|
||||||
origin: auth.origin,
|
origin: auth.origin,
|
||||||
sender_user: auth.sender_user,
|
sender_user: auth.sender_user,
|
||||||
sender_device: auth.sender_device,
|
sender_device: auth.sender_device,
|
||||||
|
@ -88,20 +97,11 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Deref for Args<T>
|
|
||||||
where
|
|
||||||
T: IncomingRequest + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
type Target = T;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target { &self.body }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn make_body<T>(
|
fn make_body<T>(
|
||||||
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
|
services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth,
|
||||||
) -> Result<T>
|
) -> Result<T>
|
||||||
where
|
where
|
||||||
T: IncomingRequest + Send + Sync + 'static,
|
T: IncomingRequest,
|
||||||
{
|
{
|
||||||
let body = take_body(services, request, json_body, auth);
|
let body = take_body(services, request, json_body, auth);
|
||||||
let http_request = into_http_request(request, body);
|
let http_request = into_http_request(request, body);
|
||||||
|
@ -125,10 +125,14 @@ fn into_http_request(request: &Request, body: Bytes) -> hyper::Request<Bytes> {
|
||||||
http_request
|
http_request
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::needless_pass_by_value)]
|
||||||
fn take_body(
|
fn take_body(
|
||||||
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
|
services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth,
|
||||||
) -> Bytes {
|
) -> Bytes {
|
||||||
if let Some(CanonicalJsonValue::Object(json_body)) = json_body {
|
let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
|
||||||
|
return mem::take(&mut request.body);
|
||||||
|
};
|
||||||
|
|
||||||
let user_id = auth.sender_user.clone().unwrap_or_else(|| {
|
let user_id = auth.sender_user.clone().unwrap_or_else(|| {
|
||||||
let server_name = services.globals.server_name();
|
let server_name = services.globals.server_name();
|
||||||
UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
|
UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
|
||||||
|
@ -154,7 +158,4 @@ fn take_body(
|
||||||
let mut buf = BytesMut::new().writer();
|
let mut buf = BytesMut::new().writer();
|
||||||
serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
|
serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
|
||||||
buf.into_inner().freeze()
|
buf.into_inner().freeze()
|
||||||
} else {
|
|
||||||
mem::take(&mut request.body)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use std::{fmt::Debug, mem};
|
use std::mem;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
use conduit::{
|
use conduit::{
|
||||||
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error,
|
debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error,
|
||||||
Result,
|
Result,
|
||||||
|
@ -23,10 +24,10 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
impl super::Service {
|
impl super::Service {
|
||||||
#[tracing::instrument(skip(self, client, req), name = "send")]
|
#[tracing::instrument(skip(self, client, request), name = "send")]
|
||||||
pub async fn send<T>(&self, client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
|
pub async fn send<T>(&self, client: &Client, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
|
||||||
where
|
where
|
||||||
T: OutgoingRequest + Debug + Send,
|
T: OutgoingRequest + Send,
|
||||||
{
|
{
|
||||||
if !self.server.config.allow_federation {
|
if !self.server.config.allow_federation {
|
||||||
return Err!(Config("allow_federation", "Federation is disabled."));
|
return Err!(Config("allow_federation", "Federation is disabled."));
|
||||||
|
@ -42,7 +43,8 @@ impl super::Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
let actual = self.services.resolver.get_actual_dest(dest).await?;
|
let actual = self.services.resolver.get_actual_dest(dest).await?;
|
||||||
let request = self.prepare::<T>(dest, &actual, req).await?;
|
let request = into_http_request::<T>(&actual, request)?;
|
||||||
|
let request = self.prepare(dest, request)?;
|
||||||
self.execute::<T>(dest, &actual, request, client).await
|
self.execute::<T>(dest, &actual, request, client).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +52,7 @@ impl super::Service {
|
||||||
&self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client,
|
&self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client,
|
||||||
) -> Result<T::IncomingResponse>
|
) -> Result<T::IncomingResponse>
|
||||||
where
|
where
|
||||||
T: OutgoingRequest + Debug + Send,
|
T: OutgoingRequest + Send,
|
||||||
{
|
{
|
||||||
let url = request.url().clone();
|
let url = request.url().clone();
|
||||||
let method = request.method().clone();
|
let method = request.method().clone();
|
||||||
|
@ -58,25 +60,14 @@ impl super::Service {
|
||||||
debug!(?method, ?url, "Sending request");
|
debug!(?method, ?url, "Sending request");
|
||||||
match client.execute(request).await {
|
match client.execute(request).await {
|
||||||
Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await,
|
Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await,
|
||||||
Err(error) => handle_error::<T>(dest, actual, &method, &url, error),
|
Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn prepare<T>(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request>
|
fn prepare(&self, dest: &ServerName, mut request: http::Request<Vec<u8>>) -> Result<Request> {
|
||||||
where
|
self.sign_request(&mut request, dest);
|
||||||
T: OutgoingRequest + Debug + Send,
|
|
||||||
{
|
|
||||||
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11];
|
|
||||||
const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY);
|
|
||||||
|
|
||||||
trace!("Preparing request");
|
let request = Request::try_from(request)?;
|
||||||
let mut http_request = req
|
|
||||||
.try_into_http_request::<Vec<u8>>(actual.string().as_str(), SATIR, &VERSIONS)
|
|
||||||
.map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?;
|
|
||||||
|
|
||||||
self.sign_request(&mut http_request, dest);
|
|
||||||
|
|
||||||
let request = Request::try_from(http_request)?;
|
|
||||||
self.validate_url(request.url())?;
|
self.validate_url(request.url())?;
|
||||||
|
|
||||||
Ok(request)
|
Ok(request)
|
||||||
|
@ -96,11 +87,31 @@ impl super::Service {
|
||||||
|
|
||||||
async fn handle_response<T>(
|
async fn handle_response<T>(
|
||||||
resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url,
|
resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url,
|
||||||
mut response: Response,
|
response: Response,
|
||||||
) -> Result<T::IncomingResponse>
|
) -> Result<T::IncomingResponse>
|
||||||
where
|
where
|
||||||
T: OutgoingRequest + Debug + Send,
|
T: OutgoingRequest + Send,
|
||||||
{
|
{
|
||||||
|
let response = into_http_response(dest, actual, method, url, response).await?;
|
||||||
|
let result = T::IncomingResponse::try_from_http_response(response);
|
||||||
|
|
||||||
|
if result.is_ok() && !actual.cached {
|
||||||
|
resolver.set_cached_destination(
|
||||||
|
dest.to_owned(),
|
||||||
|
CachedDest {
|
||||||
|
dest: actual.dest.clone(),
|
||||||
|
host: actual.host.clone(),
|
||||||
|
expire: CachedDest::default_expire(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn into_http_response(
|
||||||
|
dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response,
|
||||||
|
) -> Result<http::Response<Bytes>> {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
trace!(
|
trace!(
|
||||||
?status, ?method,
|
?status, ?method,
|
||||||
|
@ -113,6 +124,7 @@ where
|
||||||
let mut http_response_builder = http::Response::builder()
|
let mut http_response_builder = http::Response::builder()
|
||||||
.status(status)
|
.status(status)
|
||||||
.version(response.version());
|
.version(response.version());
|
||||||
|
|
||||||
mem::swap(
|
mem::swap(
|
||||||
response.headers_mut(),
|
response.headers_mut(),
|
||||||
http_response_builder
|
http_response_builder
|
||||||
|
@ -137,27 +149,10 @@ where
|
||||||
return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response)));
|
return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = T::IncomingResponse::try_from_http_response(http_response);
|
Ok(http_response)
|
||||||
if response.is_ok() && !actual.cached {
|
|
||||||
resolver.set_cached_destination(
|
|
||||||
dest.to_owned(),
|
|
||||||
CachedDest {
|
|
||||||
dest: actual.dest.clone(),
|
|
||||||
host: actual.host.clone(),
|
|
||||||
expire: CachedDest::default_expire(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}")))
|
fn handle_error(actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error) -> Result {
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_error<T>(
|
|
||||||
_dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error,
|
|
||||||
) -> Result<T::IncomingResponse>
|
|
||||||
where
|
|
||||||
T: OutgoingRequest + Debug + Send,
|
|
||||||
{
|
|
||||||
if e.is_timeout() || e.is_connect() {
|
if e.is_timeout() || e.is_connect() {
|
||||||
e = e.without_url();
|
e = e.without_url();
|
||||||
debug_warn!("{e:?}");
|
debug_warn!("{e:?}");
|
||||||
|
@ -246,3 +241,17 @@ fn sign_request(&self, http_request: &mut http::Request<Vec<u8>>, dest: &ServerN
|
||||||
|
|
||||||
debug_assert!(authorization.is_none(), "Authorization header already present");
|
debug_assert!(authorization.is_none(), "Authorization header already present");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_http_request<T>(actual: &ActualDest, request: T) -> Result<http::Request<Vec<u8>>>
|
||||||
|
where
|
||||||
|
T: OutgoingRequest + Send,
|
||||||
|
{
|
||||||
|
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11];
|
||||||
|
const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY);
|
||||||
|
|
||||||
|
let http_request = request
|
||||||
|
.try_into_http_request::<Vec<u8>>(actual.string().as_str(), SATIR, &VERSIONS)
|
||||||
|
.map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?;
|
||||||
|
|
||||||
|
Ok(http_request)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue