From eb9a6fe426242a8860593a4a328a8178560142fb Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Thu, 18 Apr 2024 00:52:29 -0700
Subject: [PATCH] refactor sending send/resolver/well-known error propagation

Signed-off-by: Jason Volk <jason@zemos.net>
---
 Cargo.lock                  |   1 +
 Cargo.toml                  |   4 +
 src/service/sending/send.rs | 205 +++++++++++++++++++-----------------
 3 files changed, 116 insertions(+), 94 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index da06c5ad..e0841e44 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -536,6 +536,7 @@ dependencies = [
  "reqwest",
  "ring",
  "ruma",
+ "ruma-identifiers-validation",
  "rusqlite",
  "rust-rocksdb",
  "sd-notify",
diff --git a/Cargo.toml b/Cargo.toml
index f818d56f..84976146 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -275,6 +275,10 @@ features = [
     "unstable-extensible-events",
 ]
 
+[dependencies.ruma-identifiers-validation]
+git = "https://github.com/girlbossceo/ruma"
+branch = "conduwuit-changes"
+
 [dependencies.hickory-resolver]
 version = "0.24.1"
 default-features = false
diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs
index f5cd0649..726ceffe 100644
--- a/src/service/sending/send.rs
+++ b/src/service/sending/send.rs
@@ -4,7 +4,6 @@ use std::{
 	net::{IpAddr, SocketAddr},
 };
 
-use futures_util::TryFutureExt;
 use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
 use http::{header::AUTHORIZATION, HeaderValue};
 use ipaddress::IPAddress;
@@ -15,9 +14,9 @@ use ruma::{
 	},
 	OwnedServerName, ServerName,
 };
-use tracing::{debug, trace, warn};
+use tracing::{debug, error, trace, warn};
 
-use crate::{services, Error, Result};
+use crate::{debug_error, debug_warn, services, Error, Result};
 
 /// Wraps either an literal IP address plus port, or a hostname plus complement
 /// (colon-plus-port if it was specified).
@@ -63,11 +62,11 @@ where
 
 	trace!("Preparing to send request");
 	validate_destination(destination)?;
-	let actual = get_actual_destination(destination).await;
+	let actual = get_actual_destination(destination).await?;
 	let mut http_request = req
 		.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5])
 		.map_err(|e| {
-			warn!("Failed to find destination {}: {}", actual.string, e);
+			debug_warn!("Failed to find destination {}: {}", actual.string, e);
 			Error::BadServerResponse("Invalid destination")
 		})?;
 
@@ -96,8 +95,6 @@ where
 	T: OutgoingRequest + Debug,
 {
 	trace!("Received response from {} for {} with {}", actual.string, url, response.url());
-	validate_response(&response)?;
-
 	let status = response.status();
 	let mut http_response_builder = http::Response::builder()
 		.status(status)
@@ -111,7 +108,7 @@ where
 
 	trace!("Waiting for response body");
 	let body = response.bytes().await.unwrap_or_else(|e| {
-		debug!("server error {}", e);
+		debug_error!("server error {}", e);
 		Vec::new().into()
 	}); // TODO: handle timeout
 
@@ -153,27 +150,27 @@ where
 	// we do not need to log that servers in a room are dead, this is normal in
 	// public rooms and just spams the logs.
 	if e.is_timeout() {
-		debug!("Timed out sending request to {}: {}", actual.string, e,);
+		debug_error!("timeout {}: {}", actual.host, e);
 	} else if e.is_connect() {
-		debug!("Failed to connect to {}: {}", actual.string, e);
+		debug_error!("connect {}: {}", actual.host, e);
 	} else if e.is_redirect() {
-		debug!(
+		debug_error!(
 			method = ?method,
 			url = ?url,
 			final_url = ?e.url(),
-			"Redirect loop sending request to {}: {}",
-			actual.string,
+			"Redirect loop {}: {}",
+			actual.host,
 			e,
 		);
 	} else {
-		debug!("Could not send request to {}: {}", actual.string, e);
+		debug_error!("{}: {}", actual.host, e);
 	}
 
 	Err(e.into())
 }
 
 #[tracing::instrument(skip_all, name = "resolve")]
-async fn get_actual_destination(server_name: &ServerName) -> ActualDestination {
+async fn get_actual_destination(server_name: &ServerName) -> Result<ActualDestination> {
 	let cached;
 	let cached_result = services()
 		.globals
@@ -188,23 +185,23 @@ async fn get_actual_destination(server_name: &ServerName) -> ActualDestination {
 		result
 	} else {
 		cached = false;
-		resolve_actual_destination(server_name).await
+		resolve_actual_destination(server_name).await?
 	};
 
 	let string = destination.clone().into_https_string();
-	ActualDestination {
+	Ok(ActualDestination {
 		destination,
 		host,
 		string,
 		cached,
-	}
+	})
 }
 
 /// Returns: `actual_destination`, host header
 /// Implemented according to the specification at <https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names>
 /// Numbers in comments below refer to bullet points in linked section of
 /// specification
-async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, String) {
+async fn resolve_actual_destination(destination: &'_ ServerName) -> Result<(FedDest, String)> {
 	trace!("Finding actual destination for {destination}");
 	let destination_str = destination.as_str().to_owned();
 	let mut hostname = destination_str.clone();
@@ -218,12 +215,12 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 				debug!("2: Hostname with included port");
 
 				let (host, port) = destination_str.split_at(pos);
-				query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448)).await;
+				query_and_cache_override(host, host, port.trim_start_matches(':').parse::<u16>().unwrap_or(8448)).await?;
 
 				FedDest::Named(host.to_owned(), port.to_owned())
 			} else {
 				trace!("Requesting well known for {destination}");
-				if let Some(delegated_hostname) = request_well_known(destination.as_str()).await {
+				if let Some(delegated_hostname) = request_well_known(destination.as_str()).await? {
 					debug!("3: A .well-known file is available");
 					hostname = add_port_to_hostname(&delegated_hostname).into_uri_string();
 					match get_ip_with_port(&delegated_hostname) {
@@ -233,12 +230,12 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 								debug!("3.2: Hostname with port in .well-known file");
 
 								let (host, port) = delegated_hostname.split_at(pos);
-								query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448)).await;
+								query_and_cache_override(host, host, port.trim_start_matches(':').parse::<u16>().unwrap_or(8448)).await?;
 
 								FedDest::Named(host.to_owned(), port.to_owned())
 							} else {
 								trace!("Delegated hostname has no port in this branch");
-								if let Some(hostname_override) = query_srv_record(&delegated_hostname).await {
+								if let Some(hostname_override) = query_srv_record(&delegated_hostname).await? {
 									debug!("3.3: SRV lookup successful");
 
 									let force_port = hostname_override.port();
@@ -247,7 +244,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 										&hostname_override.hostname(),
 										force_port.unwrap_or(8448),
 									)
-									.await;
+									.await?;
 
 									if let Some(port) = force_port {
 										FedDest::Named(delegated_hostname, format!(":{port}"))
@@ -256,7 +253,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 									}
 								} else {
 									debug!("3.4: No SRV records, just use the hostname from .well-known");
-									query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await;
+									query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await?;
 									add_port_to_hostname(&delegated_hostname)
 								}
 							}
@@ -264,12 +261,12 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 					}
 				} else {
 					trace!("4: No .well-known or an error occured");
-					if let Some(hostname_override) = query_srv_record(&destination_str).await {
+					if let Some(hostname_override) = query_srv_record(&destination_str).await? {
 						debug!("4: No .well-known; SRV record found");
 
 						let force_port = hostname_override.port();
 						query_and_cache_override(&hostname, &hostname_override.hostname(), force_port.unwrap_or(8448))
-							.await;
+							.await?;
 
 						if let Some(port) = force_port {
 							FedDest::Named(hostname.clone(), format!(":{port}"))
@@ -278,7 +275,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 						}
 					} else {
 						debug!("4: No .well-known; 5: No SRV record found");
-						query_and_cache_override(&destination_str, &destination_str, 8448).await;
+						query_and_cache_override(&destination_str, &destination_str, 8448).await?;
 						add_port_to_hostname(&destination_str)
 					}
 				}
@@ -300,10 +297,65 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St
 	};
 
 	debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}");
-	(actual_destination, hostname.into_uri_string())
+	Ok((actual_destination, hostname.into_uri_string()))
 }
 
-async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) {
+/// Fetches `https://{destination}/.well-known/matrix/server` and returns its
+/// `m.server` value when valid; request failures and unusable bodies give `Ok(None)`.
+async fn request_well_known(destination: &str) -> Result<Option<String>> {
+	let already_cached = services()
+		.globals
+		.resolver
+		.overrides
+		.read()
+		.unwrap()
+		.contains_key(destination);
+
+	if !already_cached {
+		query_and_cache_override(destination, destination, 8448).await?;
+	}
+
+	let url = format!("https://{destination}/.well-known/matrix/server");
+	let outcome = services().globals.client.well_known.get(&url).send().await;
+	trace!("Well known response: {:?}", outcome);
+
+	// A failed request and a non-2XX answer are both treated as "no delegation".
+	let reply = match outcome {
+		Ok(reply) if reply.status().is_success() => reply,
+		Ok(_) => {
+			debug!("Well known response not 2XX");
+			return Ok(None);
+		},
+		Err(e) => {
+			debug!("Well known error: {e:?}");
+			return Ok(None);
+		},
+	};
+
+	let text = reply.text().await?;
+	trace!("Well known response text: {:?}", text);
+	if text.len() >= 12288 {
+		debug!("Well known response contains junk");
+		return Ok(None);
+	}
+
+	// Unparseable JSON or a missing/non-string `m.server` collapses to "".
+	let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
+	let m_server = body
+		.get("m.server")
+		.and_then(serde_json::Value::as_str)
+		.unwrap_or_default();
+
+	match ruma_identifiers_validation::server_name::validate(m_server) {
+		Ok(()) => Ok(Some(m_server.to_owned())),
+		Err(_) => {
+			debug!("Well known response content missing or invalid");
+			Ok(None)
+		},
+	}
+}
+
+async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
 	match services()
 		.globals
 		.dns_resolver()
@@ -319,14 +371,17 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
 				.write()
 				.unwrap()
 				.insert(overname.to_owned(), (override_ip.iter().collect(), port));
+
+			Ok(())
 		},
 		Err(e) => {
 			debug!("Got {:?} for {:?} to override {:?}", e.kind(), hostname, overname);
+			handle_resolve_error(&e)
 		},
 	}
 }
 
-async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
+async fn query_srv_record(hostname: &'_ str) -> Result<Option<FedDest>> {
 	fn handle_successful_srv(srv: &SrvLookup) -> Option<FedDest> {
 		srv.iter().next().map(|result| {
 			FedDest::Named(
@@ -346,61 +401,34 @@ async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
 			.await
 	}
 
-	let first_hostname = format!("_matrix-fed._tcp.{hostname}.");
-	let second_hostname = format!("_matrix._tcp.{hostname}.");
+	let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
 
-	lookup_srv(&first_hostname)
-		.or_else(|_| {
-			trace!("Querying deprecated _matrix SRV record for host {:?}", hostname);
-			lookup_srv(&second_hostname)
-		})
-		.and_then(|srv_lookup| async move { Ok(handle_successful_srv(&srv_lookup)) })
-		.await
-		.ok()
-		.flatten()
+	for hostname in hostnames {
+		match lookup_srv(&hostname).await {
+			Ok(result) => return Ok(handle_successful_srv(&result)),
+			Err(e) => handle_resolve_error(&e)?,
+		}
+	}
+
+	Ok(None)
 }
 
-async fn request_well_known(destination: &str) -> Option<String> {
-	if !services()
-		.globals
-		.resolver
-		.overrides
-		.read()
-		.unwrap()
-		.contains_key(destination)
-	{
-		query_and_cache_override(destination, destination, 8448).await;
+#[allow(clippy::single_match_else)]
+fn handle_resolve_error(e: &ResolveError) -> Result<()> {
+	use hickory_resolver::error::ResolveErrorKind;
+
+	match *e.kind() {
+		ResolveErrorKind::Io {
+			..
+		} => {
+			debug_error!("DNS IO error: {e}");
+			Err(Error::Error(e.to_string()))
+		},
+		_ => {
+			debug!("DNS protocol error: {e}");
+			Ok(())
+		},
 	}
-
-	let response = services()
-		.globals
-		.client
-		.well_known
-		.get(&format!("https://{destination}/.well-known/matrix/server"))
-		.send()
-		.await;
-
-	trace!("Well known response: {:?}", response);
-	if let Err(e) = &response {
-		debug!("Well known error: {e:?}");
-		return None;
-	}
-
-	let text = response.ok()?.text().await;
-	trace!("Well known response text: {:?}", text);
-
-	if text.as_ref().ok()?.len() > 10000 {
-		debug!(
-			"Well known response for destination '{destination}' exceeded past 10000 characters, assuming no \
-			 well-known."
-		);
-		return None;
-	}
-
-	let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?;
-	trace!("serde_json body of well known text: {}", body);
-
-	Some(body.get("m.server")?.as_str()?.to_owned())
 }
 
 fn sign_request<T>(destination: &ServerName, http_request: &mut http::Request<Vec<u8>>)
@@ -466,17 +494,6 @@ where
 	}
 }
 
-fn validate_response(response: &reqwest::Response) -> Result<()> {
-	if let Some(remote_addr) = response.remote_addr() {
-		if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
-			trace!("Checking response destination's IP");
-			validate_ip(&ip)?;
-		}
-	}
-
-	Ok(())
-}
-
 fn validate_url(url: &reqwest::Url) -> Result<()> {
 	if let Some(url_host) = url.host_str() {
 		if let Ok(ip) = IPAddress::parse(url_host) {
@@ -509,7 +526,7 @@ fn validate_destination_ip_literal(destination: &ServerName) -> Result<()> {
 	debug!("Destination is an IP literal, checking against IP range denylist.",);
 
 	let ip = IPAddress::parse(destination.host()).map_err(|e| {
-		warn!("Failed to parse IP literal from string: {}", e);
+		debug_warn!("Failed to parse IP literal from string: {}", e);
 		Error::BadServerResponse("Invalid IP address")
 	})?;