move non-generic code out of generic; reduce codegen

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-12 08:01:23 +00:00
parent 999d731a65
commit 86694f2d1d
2 changed files with 90 additions and 80 deletions

View file

@ -66,6 +66,15 @@ where
}
}
impl<T> Deref for Args<T>
where
T: IncomingRequest + Send + Sync + 'static,
{
type Target = T;
fn deref(&self) -> &Self::Target { &self.body }
}
#[async_trait]
impl<T> FromRequest<State, Body> for Args<T>
where
@ -78,7 +87,7 @@ where
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
Ok(Self {
body: make_body::<T>(services, &mut request, &mut json_body, &auth)?,
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
origin: auth.origin,
sender_user: auth.sender_user,
sender_device: auth.sender_device,
@ -88,20 +97,11 @@ where
}
}
impl<T> Deref for Args<T>
where
T: IncomingRequest + Send + Sync + 'static,
{
type Target = T;
fn deref(&self) -> &Self::Target { &self.body }
}
fn make_body<T>(
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth,
) -> Result<T>
where
T: IncomingRequest + Send + Sync + 'static,
T: IncomingRequest,
{
let body = take_body(services, request, json_body, auth);
let http_request = into_http_request(request, body);
@ -125,36 +125,37 @@ fn into_http_request(request: &Request, body: Bytes) -> hyper::Request<Bytes> {
http_request
}
#[allow(clippy::needless_pass_by_value)]
fn take_body(
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth,
) -> Bytes {
if let Some(CanonicalJsonValue::Object(json_body)) = json_body {
let user_id = auth.sender_user.clone().unwrap_or_else(|| {
let server_name = services.globals.server_name();
UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
let Some(CanonicalJsonValue::Object(json_body)) = json_body else {
return mem::take(&mut request.body);
};
let user_id = auth.sender_user.clone().unwrap_or_else(|| {
let server_name = services.globals.server_name();
UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
});
let uiaa_request = json_body
.get("auth")
.and_then(CanonicalJsonValue::as_object)
.and_then(|auth| auth.get("session"))
.and_then(CanonicalJsonValue::as_str)
.and_then(|session| {
services
.uiaa
.get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
});
let uiaa_request = json_body
.get("auth")
.and_then(CanonicalJsonValue::as_object)
.and_then(|auth| auth.get("session"))
.and_then(CanonicalJsonValue::as_str)
.and_then(|session| {
services
.uiaa
.get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
});
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
for (key, value) in initial_request {
json_body.entry(key).or_insert(value);
}
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {
for (key, value) in initial_request {
json_body.entry(key).or_insert(value);
}
let mut buf = BytesMut::new().writer();
serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
buf.into_inner().freeze()
} else {
mem::take(&mut request.body)
}
let mut buf = BytesMut::new().writer();
serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail");
buf.into_inner().freeze()
}