From 57e6af6e21ce0d380fcd75abca30aa8b862881a7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 23 Apr 2024 15:31:40 -0700 Subject: [PATCH] split sending/send base functions Signed-off-by: Jason Volk --- src/service/sending/mod.rs | 4 +-- src/service/sending/send.rs | 50 +++++++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 7cde0941..b8622812 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -234,7 +234,7 @@ impl Service { let permit = self.maximum_requests.acquire().await; let timeout = Duration::from_secs(self.timeout); let client = &services().globals.client.federation; - let response = tokio::time::timeout(timeout, send::send_request(client, dest, request)) + let response = tokio::time::timeout(timeout, send::send(client, dest, request)) .await .map_err(|_| { warn!("Timeout after 300 seconds waiting for server response of {dest}"); @@ -795,7 +795,7 @@ async fn send_events_dest_normal( let permit = services().sending.maximum_requests.acquire().await; let client = &services().globals.client.sender; - let response = send::send_request( + let response = send::send( client, server_name, send_transaction_message::v1::Request { diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index d9c52af2..4d010b84 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -51,7 +51,7 @@ struct ActualDest { } #[tracing::instrument(skip_all, name = "send")] -pub(crate) async fn send_request(client: &Client, dest: &ServerName, req: T) -> Result +pub(crate) async fn send(client: &Client, dest: &ServerName, req: T) -> Result where T: OutgoingRequest + Debug, { @@ -59,22 +59,19 @@ where return Err(Error::bad_config("Federation is disabled.")); } - trace!("Preparing to send request"); - validate_dest(dest)?; let actual = get_actual_dest(dest).await?; - let mut http_request = req - .try_into_http_request::>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5]) - .map_err(|e| { - debug_warn!("Failed to find destination {}: {}", actual.string, e); - Error::BadServerResponse("Invalid destination") - })?; + let request = prepare::(dest, &actual, req).await?; + execute::(client, dest, &actual, request).await +} - sign_request::(dest, &mut http_request); - let request = Request::try_from(http_request)?; +async fn execute( + client: &Client, dest: &ServerName, actual: &ActualDest, request: Request, +) -> Result +where + T: OutgoingRequest + Debug, +{ let method = request.method().clone(); let url = request.url().clone(); - validate_url(&url)?; - debug!( method = ?method, url = ?url, @@ -82,12 +79,32 @@ where ); match client.execute(request).await { Ok(response) => handle_response::(dest, actual, &method, &url, response).await, - Err(e) => handle_error::(dest, &actual, &method, &url, e), + Err(e) => handle_error::(dest, actual, &method, &url, e), } } +async fn prepare(dest: &ServerName, actual: &ActualDest, req: T) -> Result +where + T: OutgoingRequest + Debug, +{ + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; + + trace!("Preparing request"); + + let mut http_request = req + .try_into_http_request::>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS) + .map_err(|_e| Error::BadServerResponse("Invalid destination"))?; + + sign_request::(dest, &mut http_request); + + let request = Request::try_from(http_request)?; + validate_url(request.url())?; + + Ok(request) +} + async fn handle_response( - dest: &ServerName, actual: ActualDest, method: &Method, url: &Url, mut response: Response, + dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, ) -> Result where T: OutgoingRequest + Debug, @@ -126,7 +143,7 @@ where .actual_destinations() .write() .await - .insert(OwnedServerName::from(dest), (actual.dest, actual.host)); + .insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone())); } match response { @@ -176,6 +193,7 @@ async fn get_actual_dest(server_name: &ServerName) -> Result { result } else { cached = false; + validate_dest(server_name)?; resolve_actual_dest(server_name).await? };