From b505f0d0d7a8ec2accc4b38dfe3391c9f780ba25 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Mon, 21 Oct 2024 22:00:39 +0000
Subject: [PATCH] add (back) query_trusted_key_servers_first w/ additional
 configuration detail

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/core/config/mod.rs             | 29 +++++++++++++++
 src/service/server_keys/acquire.rs | 59 +++++++++++++++++++++++-------
 src/service/server_keys/get.rs     | 47 ++++++++++++++++++++----
 src/service/server_keys/mod.rs     |  4 +-
 4 files changed, 116 insertions(+), 23 deletions(-)

diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs
index 02b277d0..52ce8a01 100644
--- a/src/core/config/mod.rs
+++ b/src/core/config/mod.rs
@@ -490,6 +490,35 @@ pub struct Config {
 	#[serde(default = "default_trusted_servers")]
 	pub trusted_servers: Vec<OwnedServerName>,
 
+	/// Whether to query the servers listed in trusted_servers first or query
+	/// the origin server first. For best security, querying the origin server
+	/// first is advised to minimize the exposure to a compromised trusted
+	/// server. For maximum performance this can be set to true, however other
+	/// options exist to query trusted servers first under specific high-load
+	/// circumstances and should be evaluated before setting this to true.
+	#[serde(default)]
+	pub query_trusted_key_servers_first: bool,
+
+	/// Whether to query the servers listed in trusted_servers first
+	/// specifically on room joins. This option limits the exposure to a
+	/// compromised trusted server to room joins only. The join operation
+	/// requires gathering keys from many origin servers which can cause
+	/// significant delays. Therefore this defaults to true to mitigate
+	/// unexpected delays out-of-the-box. The security-paranoid or those
+	/// willing to tolerate delays are advised to set this to false. Note that
+	/// setting query_trusted_key_servers_first to true causes this option to
+	/// be ignored.
+	#[serde(default = "true_fn")]
+	pub query_trusted_key_servers_first_on_join: bool,
+
+	/// Only query trusted servers for keys and never the origin server. This is
+	/// intended for clusters or custom deployments using their trusted_servers
+	/// as forwarding-agents to cache and deduplicate requests. Notary servers
+	/// do not act as forwarding-agents by default, therefore do not enable this
+	/// unless you know exactly what you are doing.
+	#[serde(default)]
+	pub only_query_trusted_key_servers: bool,
+
 	/// max log level for conduwuit. allows debug, info, warn, or error
 	/// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives
 	/// **Caveat**:
diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs
index 2b170040..25b676b8 100644
--- a/src/service/server_keys/acquire.rs
+++ b/src/service/server_keys/acquire.rs
@@ -47,35 +47,66 @@ where
 	S: Iterator<Item = (&'a ServerName, K)> + Send + Clone,
 	K: Iterator<Item = &'a ServerSigningKeyId> + Send + Clone,
 {
+	let notary_only = self.services.server.config.only_query_trusted_key_servers;
+	let notary_first_always = self.services.server.config.query_trusted_key_servers_first;
+	let notary_first_on_join = self
+		.services
+		.server
+		.config
+		.query_trusted_key_servers_first_on_join;
+
 	let requested_servers = batch.clone().count();
 	let requested_keys = batch.clone().flat_map(|(_, key_ids)| key_ids).count();
 
 	debug!("acquire {requested_keys} keys from {requested_servers}");
 
-	let missing = self.acquire_locals(batch).await;
-	let missing_keys = keys_count(&missing);
-	let missing_servers = missing.len();
+	let mut missing = self.acquire_locals(batch).await;
+	let mut missing_keys = keys_count(&missing);
+	let mut missing_servers = missing.len();
 	if missing_servers == 0 {
 		return;
 	}
 
 	debug!("missing {missing_keys} keys for {missing_servers} servers locally");
 
-	let missing = self.acquire_origins(missing.into_iter()).await;
-	let missing_keys = keys_count(&missing);
-	let missing_servers = missing.len();
-	if missing_servers == 0 {
-		return;
+	if notary_first_always || notary_first_on_join {
+		missing = self.acquire_notary(missing.into_iter()).await;
+		missing_keys = keys_count(&missing);
+		missing_servers = missing.len();
+		if missing_keys == 0 {
+			return;
+		}
+
+		debug_warn!("missing {missing_keys} keys for {missing_servers} servers from all notaries first");
 	}
 
-	debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable");
+	if !notary_only {
+		missing = self.acquire_origins(missing.into_iter()).await;
+		missing_keys = keys_count(&missing);
+		missing_servers = missing.len();
+		if missing_keys == 0 {
+			return;
+		}
+
+		debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable");
+	}
+
+	if !notary_first_always && !notary_first_on_join {
+		missing = self.acquire_notary(missing.into_iter()).await;
+		missing_keys = keys_count(&missing);
+		missing_servers = missing.len();
+		if missing_keys == 0 {
+			return;
+		}
+
+		debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries.");
+	}
 
-	let missing = self.acquire_notary(missing.into_iter()).await;
-	let missing_keys = keys_count(&missing);
-	let missing_servers = missing.len();
 	if missing_keys > 0 {
-		debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries");
-		warn!("did not obtain {missing_keys} of {requested_keys} keys; some events may not be accepted");
+		warn!(
+			"did not obtain {missing_keys} keys for {missing_servers} servers out of {requested_keys} total keys for \
+			 {requested_servers} total servers; some events may not be verifiable"
+		);
 	}
 }
 
diff --git a/src/service/server_keys/get.rs b/src/service/server_keys/get.rs
index 0f449b46..441e33d4 100644
--- a/src/service/server_keys/get.rs
+++ b/src/service/server_keys/get.rs
@@ -53,17 +53,40 @@ where
 
 #[implement(super::Service)]
 pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
+	let notary_first = self.services.server.config.query_trusted_key_servers_first;
+	let notary_only = self.services.server.config.only_query_trusted_key_servers;
+
 	if let Some(result) = self.verify_keys_for(origin).await.remove(key_id) {
 		return Ok(result);
 	}
 
-	if let Ok(server_key) = self.server_request(origin).await {
-		self.add_signing_keys(server_key.clone()).await;
-		if let Some(result) = extract_key(server_key, key_id) {
+	if notary_first {
+		if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await {
 			return Ok(result);
 		}
 	}
 
+	if !notary_only {
+		if let Ok(result) = self.get_verify_key_from_origin(origin, key_id).await {
+			return Ok(result);
+		}
+	}
+
+	if !notary_first {
+		if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await {
+			return Ok(result);
+		}
+	}
+
+	Err!(BadServerResponse(debug_error!(
+		?key_id,
+		?origin,
+		"Failed to fetch federation signing-key"
+	)))
+}
+
+#[implement(super::Service)]
+async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
 	for notary in self.services.globals.trusted_servers() {
 		if let Ok(server_keys) = self.notary_request(notary, origin).await {
 			for server_key in &server_keys {
@@ -78,9 +101,17 @@ pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKe
 		}
 	}
 
-	Err!(BadServerResponse(debug_error!(
-		?key_id,
-		?origin,
-		"Failed to fetch federation signing-key"
-	)))
+	Err!(Request(NotFound("Failed to fetch signing-key from notaries")))
+}
+
+#[implement(super::Service)]
+async fn get_verify_key_from_origin(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> {
+	if let Ok(server_key) = self.server_request(origin).await {
+		self.add_signing_keys(server_key.clone()).await;
+		if let Some(result) = extract_key(server_key, key_id) {
+			return Ok(result);
+		}
+	}
+
+	Err!(Request(NotFound("Failed to fetch signing-key from origin")))
 }
diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs
index c3b84cb3..dc09703c 100644
--- a/src/service/server_keys/mod.rs
+++ b/src/service/server_keys/mod.rs
@@ -7,7 +7,7 @@ mod verify;
 
 use std::{collections::BTreeMap, sync::Arc, time::Duration};
 
-use conduit::{implement, utils::time::timepoint_from_now, Result};
+use conduit::{implement, utils::time::timepoint_from_now, Result, Server};
 use database::{Deserialized, Json, Map};
 use ruma::{
 	api::federation::discovery::{ServerSigningKeys, VerifyKey},
@@ -30,6 +30,7 @@ pub struct Service {
 struct Services {
 	globals: Dep<globals::Service>,
 	sending: Dep<sending::Service>,
+	server: Arc<Server>,
 }
 
 struct Data {
@@ -52,6 +53,7 @@ impl crate::Service for Service {
 			services: Services {
 				globals: args.depend::<globals::Service>("globals"),
 				sending: args.depend::<sending::Service>("sending"),
+				server: args.server.clone(),
 			},
 			db: Data {
 				server_signingkeys: args.db["server_signingkeys"].clone(),