From 3af303e52bc2908d3c69fecf62e98159745e3610 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Sun, 17 Mar 2024 02:20:23 -0400
Subject: [PATCH] complete federation destination caching preempting
 getaddrinfo(3).

fixed some clippy lints and spacing adjusted

Signed-off-by: Jason Volk <jason@zemos.net>
Signed-off-by: strawberry <strawberry@puppygock.gay>
---
 Cargo.lock                 |  1 +
 Cargo.toml                 |  1 +
 src/api/server_server.rs   | 78 ++++++++++++++++++++++----------------
 src/service/globals/mod.rs |  2 +
 4 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 68f8d9af..ddc84c1f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2062,6 +2062,7 @@ dependencies = [
  "tokio-rustls",
  "tokio-socks",
  "tower-service",
+ "trust-dns-resolver",
  "url",
  "wasm-bindgen",
  "wasm-bindgen-futures",
diff --git a/Cargo.toml b/Cargo.toml
index c241b1fc..0c84c897 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -106,6 +106,7 @@ default-features = false
 features = [
   "rustls-tls-native-roots",
   "socks",
+  "trust-dns",
 ]
 
 # all the serde stuff
diff --git a/src/api/server_server.rs b/src/api/server_server.rs
index 4edfd09d..07be27f4 100644
--- a/src/api/server_server.rs
+++ b/src/api/server_server.rs
@@ -365,7 +365,10 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 		None => {
 			if let Some(pos) = destination_str.find(':') {
 				debug!("2: Hostname with included port");
+
 				let (host, port) = destination_str.split_at(pos);
+				query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448)).await;
+
 				FedDest::Named(host.to_owned(), port.to_owned())
 			} else {
 				debug!("Requesting well known for {destination}");
@@ -378,30 +381,23 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 							None => {
 								if let Some(pos) = delegated_hostname.find(':') {
 									debug!("3.2: Hostname with port in .well-known file");
+
 									let (host, port) = delegated_hostname.split_at(pos);
+									query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448)).await;
+
 									FedDest::Named(host.to_owned(), port.to_owned())
 								} else {
 									debug!("Delegated hostname has no port in this branch");
 									if let Some(hostname_override) = query_srv_record(&delegated_hostname).await {
 										debug!("3.3: SRV lookup successful");
-										let force_port = hostname_override.port();
 
-										if let Ok(override_ip) = services()
-											.globals
-											.dns_resolver()
-											.lookup_ip(hostname_override.hostname())
-											.await
-										{
-											services().globals.tls_name_override.write().unwrap().insert(
-												delegated_hostname.clone(),
-												(override_ip.iter().collect(), force_port.unwrap_or(8448)),
-											);
-										} else {
-											debug!(
-												"Using SRV record {}, but could not resolve to IP",
-												hostname_override.hostname()
-											);
-										}
+										let force_port = hostname_override.port();
+										query_and_cache_override(
+											&delegated_hostname,
+											&hostname_override.hostname(),
+											force_port.unwrap_or(8448),
+										)
+										.await;
 
 										if let Some(port) = force_port {
 											FedDest::Named(delegated_hostname, format!(":{port}"))
@@ -410,6 +406,7 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 										}
 									} else {
 										debug!("3.4: No SRV records, just use the hostname from .well-known");
+										query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await;
 										add_port_to_hostname(&delegated_hostname)
 									}
 								}
@@ -421,21 +418,14 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 						match query_srv_record(&destination_str).await {
 							Some(hostname_override) => {
 								debug!("4: SRV record found");
-								let force_port = hostname_override.port();
 
-								if let Ok(override_ip) =
-									services().globals.dns_resolver().lookup_ip(hostname_override.hostname()).await
-								{
-									services().globals.tls_name_override.write().unwrap().insert(
-										hostname.clone(),
-										(override_ip.iter().collect(), force_port.unwrap_or(8448)),
-									);
-								} else {
-									debug!(
-										"Using SRV record {}, but could not resolve to IP",
-										hostname_override.hostname()
-									);
-								}
+								let force_port = hostname_override.port();
+								query_and_cache_override(
+									&hostname,
+									&hostname_override.hostname(),
+									force_port.unwrap_or(8448),
+								)
+								.await;
 
 								if let Some(port) = force_port {
 									FedDest::Named(hostname.clone(), format!(":{port}"))
@@ -445,6 +435,7 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 							},
 							None => {
 								debug!("5: No SRV record found");
+								query_and_cache_override(&destination_str, &destination_str, 8448).await;
 								add_port_to_hostname(&destination_str)
 							},
 						}
@@ -453,7 +444,6 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 			}
 		},
 	};
-	debug!("Actual destination: {actual_destination:?}");
 
 	// Can't use get_ip_with_port here because we don't want to add a port
 	// to an IP address if it wasn't specified
@@ -467,9 +457,29 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
 	} else {
 		FedDest::Named(hostname, ":8448".to_owned())
 	};
+
+	debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}");
 	(actual_destination, hostname)
 }
 
+async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) {
+	match services().globals.dns_resolver().lookup_ip(hostname.to_owned()).await {
+		Ok(override_ip) => {
+			debug!("Caching result of {:?} overriding {:?}", hostname, overname);
+
+			services()
+				.globals
+				.tls_name_override
+				.write()
+				.unwrap()
+				.insert(overname.to_owned(), (override_ip.iter().collect(), port));
+		},
+		Err(e) => {
+			debug!("Got {:?} for {:?} to override {:?}", e.kind(), hostname, overname);
+		},
+	}
+}
+
 async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
 	fn handle_successful_srv(srv: &SrvLookup) -> Option<FedDest> {
 		srv.iter().next().map(|result| {
@@ -501,6 +511,10 @@ async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
 }
 
 async fn request_well_known(destination: &str) -> Option<String> {
+	if !services().globals.tls_name_override.read().unwrap().contains_key(destination) {
+		query_and_cache_override(destination, destination, 8448).await;
+	}
+
 	let response = services()
 		.globals
 		.default_client()
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index 8313dc0e..f54f0686 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -495,6 +495,7 @@ fn reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
 	});
 
 	let mut reqwest_client_builder = reqwest::Client::builder()
+		.trust_dns(true)
 		.pool_max_idle_per_host(0)
 		.connect_timeout(Duration::from_secs(60))
 		.timeout(Duration::from_secs(60 * 5))
@@ -522,6 +523,7 @@ fn url_preview_reqwest_client_builder(config: &Config) -> Result<reqwest::Client
 	});
 
 	let mut reqwest_client_builder = reqwest::Client::builder()
+		.trust_dns(true)
 		.pool_max_idle_per_host(0)
 		.connect_timeout(Duration::from_secs(60))
 		.timeout(Duration::from_secs(60 * 5))