cleanup/split/dedup sending/send callstack

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-04-16 20:54:16 -07:00 committed by June
parent 9361acadcb
commit 68aa368450

View file

@ -43,9 +43,16 @@ pub enum FedDest {
Named(String, String), Named(String, String),
} }
struct ActualDestination {
destination: FedDest,
host: String,
string: String,
cached: bool,
}
#[tracing::instrument(skip_all, name = "send")] #[tracing::instrument(skip_all, name = "send")]
pub(crate) async fn send_request<T>( pub(crate) async fn send_request<T>(
client: &reqwest::Client, destination: &ServerName, request: T, client: &reqwest::Client, destination: &ServerName, req: T,
) -> Result<T::IncomingResponse> ) -> Result<T::IncomingResponse>
where where
T: OutgoingRequest + Debug, T: OutgoingRequest + Debug,
@ -54,180 +61,42 @@ where
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
} }
if destination == services().globals.server_name() { trace!("Preparing to send request");
return Err(Error::bad_config("Won't send federation request to ourselves")); validate_destination(destination)?;
} let actual = get_actual_destination(destination).await;
let mut http_request = req
if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { .try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5])
debug!(
"Destination {} is an IP literal, checking against IP range denylist.",
destination
);
let ip = IPAddress::parse(destination.host()).map_err(|e| {
warn!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
debug!("List of pushed CIDR ranges: {:?}", cidr_ranges);
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
debug!("IP literal {} is allowed.", destination);
}
trace!("Preparing to send request to {destination}");
let mut write_destination_to_cache = false;
let cached_result = services()
.globals
.actual_destinations()
.read()
.await
.get(destination)
.cloned();
let (actual_destination, host) = if let Some(result) = cached_result {
result
} else {
write_destination_to_cache = true;
let result = resolve_actual_destination(destination).await;
(result.0, result.1.into_uri_string())
};
let actual_destination_str = actual_destination.clone().into_https_string();
let mut http_request = request
.try_into_http_request::<Vec<u8>>(
&actual_destination_str,
SendAccessToken::IfRequired(""),
&[MatrixVersion::V1_5],
)
.map_err(|e| { .map_err(|e| {
warn!("Failed to find destination {}: {}", actual_destination_str, e); warn!("Failed to find destination {}: {}", actual.string, e);
Error::BadServerResponse("Invalid destination") Error::BadServerResponse("Invalid destination")
})?; })?;
let mut request_map = serde_json::Map::new(); sign_request::<T>(destination, &mut http_request);
let request = reqwest::Request::try_from(http_request)?;
let method = request.method().clone();
let url = request.url().clone();
validate_url(&url)?;
if !http_request.body().is_empty() { debug!(
request_map.insert( method = ?method,
"content".to_owned(), url = ?url,
serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), "Sending request",
);
};
request_map.insert("method".to_owned(), T::METADATA.method.to_string().into());
request_map.insert(
"uri".to_owned(),
http_request
.uri()
.path_and_query()
.expect("all requests have a path")
.to_string()
.into(),
);
request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into());
request_map.insert("destination".to_owned(), destination.as_str().into());
let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut request_json,
)
.expect("our request json is what ruma expects");
let request_json: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap();
let signatures = request_json["signatures"]
.as_object()
.unwrap()
.values()
.map(|v| {
v.as_object()
.unwrap()
.iter()
.map(|(k, v)| (k, v.as_str().unwrap()))
});
for signature_server in signatures {
for s in signature_server {
http_request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"X-Matrix origin={},key=\"{}\",sig=\"{}\"",
services().globals.server_name(),
s.0,
s.1
))
.unwrap(),
); );
match client.execute(request).await {
Ok(response) => handle_response::<T>(destination, actual, &method, &url, response).await,
Err(e) => handle_error::<T>(destination, &actual, &method, &url, e),
} }
} }
let reqwest_request = reqwest::Request::try_from(http_request)?; async fn handle_response<T>(
let method = reqwest_request.method().clone(); destination: &ServerName, actual: ActualDestination, method: &reqwest::Method, url: &reqwest::Url,
let url = reqwest_request.url().clone(); mut response: reqwest::Response,
) -> Result<T::IncomingResponse>
if let Some(url_host) = url.host_str() { where
trace!("Checking request URL for IP"); T: OutgoingRequest + Debug,
if let Ok(ip) = IPAddress::parse(url_host) { {
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); trace!("Received response from {} for {} with {}", actual.string, url, response.url());
let mut cidr_ranges: Vec<IPAddress> = Vec::new(); validate_response(&response)?;
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
}
}
debug!("Sending request {} {}", method, url);
let response = client.execute(reqwest_request).await;
trace!("Received resonse {} {}", method, url);
match response {
Ok(mut response) => {
// reqwest::Response -> http::Response conversion
trace!("Checking response destination's IP");
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
}
}
let status = response.status(); let status = response.status();
let mut http_response_builder = http::Response::builder() let mut http_response_builder = http::Response::builder()
@ -246,94 +115,96 @@ where
Vec::new().into() Vec::new().into()
}); // TODO: handle timeout }); // TODO: handle timeout
if !status.is_success() {
debug!(
"Got {status:?} for {method} {url}: {}",
String::from_utf8_lossy(&body)
.lines()
.collect::<Vec<_>>()
.join(" ")
);
}
let http_response = http_response_builder let http_response = http_response_builder
.body(body) .body(body)
.expect("reqwest body is valid http body"); .expect("reqwest body is valid http body");
if status.is_success() {
debug!("Got {status:?} for {method} {url}"); debug!("Got {status:?} for {method} {url}");
if !status.is_success() {
return Err(Error::FederationError(
destination.to_owned(),
RumaError::from_http_response(http_response),
));
}
let response = T::IncomingResponse::try_from_http_response(http_response); let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && write_destination_to_cache { if response.is_ok() && !actual.cached {
services() services()
.globals .globals
.actual_destinations() .actual_destinations()
.write() .write()
.await .await
.insert(OwnedServerName::from(destination), (actual_destination, host)); .insert(OwnedServerName::from(destination), (actual.destination, actual.host));
} }
response.map_err(|e| { match response {
debug!("Invalid 200 response for {} {}", url, e); Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")),
Error::BadServerResponse("Server returned bad 200 response.") Ok(response) => Ok(response),
})
} else {
Err(Error::FederationError(
destination.to_owned(),
RumaError::from_http_response(http_response),
))
} }
}, }
Err(e) => {
fn handle_error<T>(
_destination: &ServerName, actual: &ActualDestination, method: &reqwest::Method, url: &reqwest::Url,
e: reqwest::Error,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
// we do not need to log that servers in a room are dead, this is normal in // we do not need to log that servers in a room are dead, this is normal in
// public rooms and just spams the logs. // public rooms and just spams the logs.
if e.is_timeout() { if e.is_timeout() {
debug!( debug!("Timed out sending request to {}: {}", actual.string, e,);
"Timed out sending request to {} at {}: {}",
destination, actual_destination_str, e
);
} else if e.is_connect() { } else if e.is_connect() {
debug!("Failed to connect to {} at {}: {}", destination, actual_destination_str, e); debug!("Failed to connect to {}: {}", actual.string, e);
} else if e.is_redirect() { } else if e.is_redirect() {
debug!( debug!(
"Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}", method = ?method,
destination, url = ?url,
actual_destination_str, final_url = ?e.url(),
"Redirect loop sending request to {}: {}",
actual.string,
e, e,
e.url()
); );
} else { } else {
debug!("Could not send request to {} at {}: {}", destination, actual_destination_str, e); debug!("Could not send request to {}: {}", actual.string, e);
} }
Err(e.into()) Err(e.into())
},
}
} }
fn get_ip_with_port(destination_str: &str) -> Option<FedDest> { #[tracing::instrument(skip_all, name = "resolve")]
if let Ok(destination) = destination_str.parse::<SocketAddr>() { async fn get_actual_destination(server_name: &ServerName) -> ActualDestination {
Some(FedDest::Literal(destination)) let cached;
} else if let Ok(ip_addr) = destination_str.parse::<IpAddr>() { let cached_result = services()
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) .globals
.actual_destinations()
.read()
.await
.get(server_name)
.cloned();
let (destination, host) = if let Some(result) = cached_result {
cached = true;
result
} else { } else {
None cached = false;
} resolve_actual_destination(server_name).await
}
fn add_port_to_hostname(destination_str: &str) -> FedDest {
let (host, port) = match destination_str.find(':') {
None => (destination_str, ":8448"),
Some(pos) => destination_str.split_at(pos),
}; };
FedDest::Named(host.to_owned(), port.to_owned())
let string = destination.clone().into_https_string();
ActualDestination {
destination,
host,
string,
cached,
}
} }
/// Returns: `actual_destination`, host header /// Returns: `actual_destination`, host header
/// Implemented according to the specification at <https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names> /// Implemented according to the specification at <https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names>
/// Numbers in comments below refer to bullet points in linked section of /// Numbers in comments below refer to bullet points in linked section of
/// specification /// specification
#[tracing::instrument(skip_all, name = "resolve")] async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, String) {
async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) {
trace!("Finding actual destination for {destination}"); trace!("Finding actual destination for {destination}");
let destination_str = destination.as_str().to_owned(); let destination_str = destination.as_str().to_owned();
let mut hostname = destination_str.clone(); let mut hostname = destination_str.clone();
@ -429,7 +300,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, Fe
}; };
debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}"); debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}");
(actual_destination, hostname) (actual_destination, hostname.into_uri_string())
} }
async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) {
@ -441,7 +312,6 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
{ {
Ok(override_ip) => { Ok(override_ip) => {
trace!("Caching result of {:?} overriding {:?}", hostname, overname); trace!("Caching result of {:?} overriding {:?}", hostname, overname);
services() services()
.globals .globals
.resolver .resolver
@ -533,6 +403,156 @@ async fn request_well_known(destination: &str) -> Option<String> {
Some(body.get("m.server")?.as_str()?.to_owned()) Some(body.get("m.server")?.as_str()?.to_owned())
} }
fn sign_request<T>(destination: &ServerName, http_request: &mut http::Request<Vec<u8>>)
where
T: OutgoingRequest + Debug,
{
let mut req_map = serde_json::Map::new();
if !http_request.body().is_empty() {
req_map.insert(
"content".to_owned(),
serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"),
);
};
req_map.insert("method".to_owned(), T::METADATA.method.to_string().into());
req_map.insert(
"uri".to_owned(),
http_request
.uri()
.path_and_query()
.expect("all requests have a path")
.to_string()
.into(),
);
req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into());
req_map.insert("destination".to_owned(), destination.as_str().into());
let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut req_json,
)
.expect("our request json is what ruma expects");
let req_json: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap();
let signatures = req_json["signatures"]
.as_object()
.expect("signatures object")
.values()
.map(|v| {
v.as_object()
.expect("server signatures object")
.iter()
.map(|(k, v)| (k, v.as_str().expect("server signature string")))
});
for signature_server in signatures {
for s in signature_server {
http_request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"X-Matrix origin={},key=\"{}\",sig=\"{}\"",
services().globals.server_name(),
s.0,
s.1
))
.expect("formatted X-Matrix header"),
);
}
}
}
fn validate_response(response: &reqwest::Response) -> Result<()> {
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
trace!("Checking response destination's IP");
validate_ip(&ip)?;
}
}
Ok(())
}
fn validate_url(url: &reqwest::Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
validate_ip(&ip)?;
}
}
Ok(())
}
fn validate_destination(destination: &ServerName) -> Result<()> {
if destination == services().globals.server_name() {
return Err(Error::bad_config("Won't send federation request to ourselves"));
}
if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) {
validate_destination_ip_literal(destination)?;
}
trace!("Destination ServerName is valid");
Ok(())
}
fn validate_destination_ip_literal(destination: &ServerName) -> Result<()> {
debug_assert!(
destination.is_ip_literal() || !IPAddress::is_valid(destination.host()),
"Destination is not an IP literal."
);
debug!("Destination is an IP literal, checking against IP range denylist.",);
let ip = IPAddress::parse(destination.host()).map_err(|e| {
warn!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
validate_ip(&ip)?;
Ok(())
}
fn validate_ip(ip: &IPAddress) -> Result<()> {
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
trace!("List of pushed CIDR ranges: {:?}", cidr_ranges);
for cidr in cidr_ranges {
if cidr.includes(ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
Ok(())
}
fn get_ip_with_port(destination_str: &str) -> Option<FedDest> {
if let Ok(destination) = destination_str.parse::<SocketAddr>() {
Some(FedDest::Literal(destination))
} else if let Ok(ip_addr) = destination_str.parse::<IpAddr>() {
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
} else {
None
}
}
fn add_port_to_hostname(destination_str: &str) -> FedDest {
let (host, port) = match destination_str.find(':') {
None => (destination_str, ":8448"),
Some(pos) => destination_str.split_at(pos),
};
FedDest::Named(host.to_owned(), port.to_owned())
}
impl FedDest { impl FedDest {
fn into_https_string(self) -> String { fn into_https_string(self) -> String {
match self { match self {