From 68aa3684502e67c489892f7fbba1132b32d48456 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 16 Apr 2024 20:54:16 -0700 Subject: [PATCH] cleanup/split/dedup sending/send callstack Signed-off-by: Jason Volk --- src/service/sending/send.rs | 530 +++++++++++++++++++----------------- 1 file changed, 275 insertions(+), 255 deletions(-) diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 5090db99..f5cd0649 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -43,9 +43,16 @@ pub enum FedDest { Named(String, String), } +struct ActualDestination { + destination: FedDest, + host: String, + string: String, + cached: bool, +} + #[tracing::instrument(skip_all, name = "send")] pub(crate) async fn send_request( - client: &reqwest::Client, destination: &ServerName, request: T, + client: &reqwest::Client, destination: &ServerName, req: T, ) -> Result where T: OutgoingRequest + Debug, @@ -54,286 +61,150 @@ where return Err(Error::bad_config("Federation is disabled.")); } - if destination == services().globals.server_name() { - return Err(Error::bad_config("Won't send federation request to ourselves")); - } - - if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { - debug!( - "Destination {} is an IP literal, checking against IP range denylist.", - destination - ); - let ip = IPAddress::parse(destination.host()).map_err(|e| { - warn!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") + trace!("Preparing to send request"); + validate_destination(destination)?; + let actual = get_actual_destination(destination).await; + let mut http_request = req + .try_into_http_request::>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5]) + .map_err(|e| { + warn!("Failed to find destination {}: {}", actual.string, e); + Error::BadServerResponse("Invalid destination") })?; - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); + sign_request::(destination, &mut http_request); + let request = reqwest::Request::try_from(http_request)?; + let method = request.method().clone(); + let url = request.url().clone(); + validate_url(&url)?; - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } + debug!( + method = ?method, + url = ?url, + "Sending request", + ); + match client.execute(request).await { + Ok(response) => handle_response::(destination, actual, &method, &url, response).await, + Err(e) => handle_error::(destination, &actual, &method, &url, e), + } +} - debug!("List of pushed CIDR ranges: {:?}", cidr_ranges); +async fn handle_response( + destination: &ServerName, actual: ActualDestination, method: &reqwest::Method, url: &reqwest::Url, + mut response: reqwest::Response, +) -> Result +where + T: OutgoingRequest + Debug, +{ + trace!("Received response from {} for {} with {}", actual.string, url, response.url()); + validate_response(&response)?; - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - } + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); - debug!("IP literal {} is allowed.", destination); + trace!("Waiting for response body"); + let body = response.bytes().await.unwrap_or_else(|e| { + debug!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + let http_response = http_response_builder + .body(body) + .expect("reqwest body is valid http body"); + + debug!("Got {status:?} for {method} {url}"); + if !status.is_success() { + return Err(Error::FederationError( + destination.to_owned(), + RumaError::from_http_response(http_response), + )); } - trace!("Preparing to send request to {destination}"); + let response = T::IncomingResponse::try_from_http_response(http_response); + if response.is_ok() && !actual.cached { + services() + .globals + .actual_destinations() + .write() + .await + .insert(OwnedServerName::from(destination), (actual.destination, actual.host)); + } - let mut write_destination_to_cache = false; + match response { + Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")), + Ok(response) => Ok(response), + } +} +fn handle_error( + _destination: &ServerName, actual: &ActualDestination, method: &reqwest::Method, url: &reqwest::Url, + e: reqwest::Error, +) -> Result +where + T: OutgoingRequest + Debug, +{ + // we do not need to log that servers in a room are dead, this is normal in + // public rooms and just spams the logs. + if e.is_timeout() { + debug!("Timed out sending request to {}: {}", actual.string, e,); + } else if e.is_connect() { + debug!("Failed to connect to {}: {}", actual.string, e); + } else if e.is_redirect() { + debug!( + method = ?method, + url = ?url, + final_url = ?e.url(), + "Redirect loop sending request to {}: {}", + actual.string, + e, + ); + } else { + debug!("Could not send request to {}: {}", actual.string, e); + } + + Err(e.into()) +} + +#[tracing::instrument(skip_all, name = "resolve")] +async fn get_actual_destination(server_name: &ServerName) -> ActualDestination { + let cached; let cached_result = services() .globals .actual_destinations() .read() .await - .get(destination) + .get(server_name) .cloned(); - let (actual_destination, host) = if let Some(result) = cached_result { + let (destination, host) = if let Some(result) = cached_result { + cached = true; result } else { - write_destination_to_cache = true; - - let result = resolve_actual_destination(destination).await; - - (result.0, result.1.into_uri_string()) + cached = false; + resolve_actual_destination(server_name).await }; - let actual_destination_str = actual_destination.clone().into_https_string(); - - let mut http_request = request - .try_into_http_request::>( - &actual_destination_str, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_5], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", actual_destination_str, e); - Error::BadServerResponse("Invalid destination") - })?; - - let mut request_map = serde_json::Map::new(); - - if !http_request.body().is_empty() { - request_map.insert( - "content".to_owned(), - serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), - ); - }; - - request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); - request_map.insert( - "uri".to_owned(), - http_request - .uri() - .path_and_query() - .expect("all requests have a path") - .to_string() - .into(), - ); - request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); - request_map.insert("destination".to_owned(), destination.as_str().into()); - - let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); - - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut request_json, - ) - .expect("our request json is what ruma expects"); - - let request_json: serde_json::Map = - serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); - - let signatures = request_json["signatures"] - .as_object() - .unwrap() - .values() - .map(|v| { - v.as_object() - .unwrap() - .iter() - .map(|(k, v)| (k, v.as_str().unwrap())) - }); - - for signature_server in signatures { - for s in signature_server { - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from_str(&format!( - "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - services().globals.server_name(), - s.0, - s.1 - )) - .unwrap(), - ); - } + let string = destination.clone().into_https_string(); + ActualDestination { + destination, + host, + string, + cached, } - - let reqwest_request = reqwest::Request::try_from(http_request)?; - let method = reqwest_request.method().clone(); - let url = reqwest_request.url().clone(); - - if let Some(url_host) = url.host_str() { - trace!("Checking request URL for IP"); - if let Ok(ip) = IPAddress::parse(url_host) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - } - } - } - - debug!("Sending request {} {}", method, url); - let response = client.execute(reqwest_request).await; - trace!("Received resonse {} {}", method, url); - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - - trace!("Checking response destination's IP"); - if let Some(remote_addr) = response.remote_addr() { - if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - } - } - } - - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - trace!("Waiting for response body"); - let body = response.bytes().await.unwrap_or_else(|e| { - debug!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if !status.is_success() { - debug!( - "Got {status:?} for {method} {url}: {}", - String::from_utf8_lossy(&body) - .lines() - .collect::>() - .join(" ") - ); - } - - let http_response = http_response_builder - .body(body) - .expect("reqwest body is valid http body"); - - if status.is_success() { - debug!("Got {status:?} for {method} {url}"); - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && write_destination_to_cache { - services() - .globals - .actual_destinations() - .write() - .await - .insert(OwnedServerName::from(destination), (actual_destination, host)); - } - - response.map_err(|e| { - debug!("Invalid 200 response for {} {}", url, e); - Error::BadServerResponse("Server returned bad 200 response.") - }) - } else { - Err(Error::FederationError( - destination.to_owned(), - RumaError::from_http_response(http_response), - )) - } - }, - Err(e) => { - // we do not need to log that servers in a room are dead, this is normal in - // public rooms and just spams the logs. - if e.is_timeout() { - debug!( - "Timed out sending request to {} at {}: {}", - destination, actual_destination_str, e - ); - } else if e.is_connect() { - debug!("Failed to connect to {} at {}: {}", destination, actual_destination_str, e); - } else if e.is_redirect() { - debug!( - "Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}", - destination, - actual_destination_str, - e, - e.url() - ); - } else { - debug!("Could not send request to {} at {}: {}", destination, actual_destination_str, e); - } - - Err(e.into()) - }, - } -} - -fn get_ip_with_port(destination_str: &str) -> Option { - if let Ok(destination) = destination_str.parse::() { - Some(FedDest::Literal(destination)) - } else if let Ok(ip_addr) = destination_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(destination_str: &str) -> FedDest { - let (host, port) = match destination_str.find(':') { - None => (destination_str, ":8448"), - Some(pos) => destination_str.split_at(pos), - }; - FedDest::Named(host.to_owned(), port.to_owned()) } /// Returns: `actual_destination`, host header /// Implemented according to the specification at /// Numbers in comments below refer to bullet points in linked section of /// specification -#[tracing::instrument(skip_all, name = "resolve")] -async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { +async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, String) { trace!("Finding actual destination for {destination}"); let destination_str = destination.as_str().to_owned(); let mut hostname = destination_str.clone(); @@ -429,7 +300,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, Fe }; debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}"); - (actual_destination, hostname) + (actual_destination, hostname.into_uri_string()) } async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { @@ -441,7 +312,6 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1 { Ok(override_ip) => { trace!("Caching result of {:?} overriding {:?}", hostname, overname); - services() .globals .resolver @@ -533,6 +403,156 @@ async fn request_well_known(destination: &str) -> Option { Some(body.get("m.server")?.as_str()?.to_owned()) } +fn sign_request(destination: &ServerName, http_request: &mut http::Request>) +where + T: OutgoingRequest + Debug, +{ + let mut req_map = serde_json::Map::new(); + if !http_request.body().is_empty() { + req_map.insert( + "content".to_owned(), + serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), + ); + }; + + req_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); + req_map.insert( + "uri".to_owned(), + http_request + .uri() + .path_and_query() + .expect("all requests have a path") + .to_string() + .into(), + ); + req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); + req_map.insert("destination".to_owned(), destination.as_str().into()); + + let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut req_json, + ) + .expect("our request json is what ruma expects"); + + let req_json: serde_json::Map = + serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap(); + + let signatures = req_json["signatures"] + .as_object() + .expect("signatures object") + .values() + .map(|v| { + v.as_object() + .expect("server signatures object") + .iter() + .map(|(k, v)| (k, v.as_str().expect("server signature string"))) + }); + + for signature_server in signatures { + for s in signature_server { + http_request.headers_mut().insert( + AUTHORIZATION, + HeaderValue::from_str(&format!( + "X-Matrix origin={},key=\"{}\",sig=\"{}\"", + services().globals.server_name(), + s.0, + s.1 + )) + .expect("formatted X-Matrix header"), + ); + } + } +} + +fn validate_response(response: &reqwest::Response) -> Result<()> { + if let Some(remote_addr) = response.remote_addr() { + if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { + trace!("Checking response destination's IP"); + validate_ip(&ip)?; + } + } + + Ok(()) +} + +fn validate_url(url: &reqwest::Url) -> Result<()> { + if let Some(url_host) = url.host_str() { + if let Ok(ip) = IPAddress::parse(url_host) { + trace!("Checking request URL IP {ip:?}"); + validate_ip(&ip)?; + } + } + + Ok(()) +} + +fn validate_destination(destination: &ServerName) -> Result<()> { + if destination == services().globals.server_name() { + return Err(Error::bad_config("Won't send federation request to ourselves")); + } + + if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { + validate_destination_ip_literal(destination)?; + } + + trace!("Destination ServerName is valid"); + Ok(()) +} + +fn validate_destination_ip_literal(destination: &ServerName) -> Result<()> { + debug_assert!( + destination.is_ip_literal() || !IPAddress::is_valid(destination.host()), + "Destination is not an IP literal." + ); + debug!("Destination is an IP literal, checking against IP range denylist.",); + + let ip = IPAddress::parse(destination.host()).map_err(|e| { + warn!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; + + validate_ip(&ip)?; + + Ok(()) +} + +fn validate_ip(ip: &IPAddress) -> Result<()> { + let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); + let mut cidr_ranges: Vec = Vec::new(); + for cidr in cidr_ranges_s { + cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); + } + + trace!("List of pushed CIDR ranges: {:?}", cidr_ranges); + for cidr in cidr_ranges { + if cidr.includes(ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + } + + Ok(()) +} + +fn get_ip_with_port(destination_str: &str) -> Option { + if let Ok(destination) = destination_str.parse::() { + Some(FedDest::Literal(destination)) + } else if let Ok(ip_addr) = destination_str.parse::() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +fn add_port_to_hostname(destination_str: &str) -> FedDest { + let (host, port) = match destination_str.find(':') { + None => (destination_str, ":8448"), + Some(pos) => destination_str.split_at(pos), + }; + FedDest::Named(host.to_owned(), port.to_owned()) +} + impl FedDest { fn into_https_string(self) -> String { match self {