From c53cc03ff8db65b6b447a852eee85e540ad38cb1 Mon Sep 17 00:00:00 2001
From: Aiden McClelland <me@drbonez.dev>
Date: Thu, 1 Jul 2021 13:38:25 -0600
Subject: [PATCH] address pr comments

---
 conduit-example.toml  |   2 +
 src/database.rs       | 121 +---------------------------------
 src/database/proxy.rs | 146 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 151 insertions(+), 118 deletions(-)
 create mode 100644 src/database/proxy.rs

diff --git a/conduit-example.toml b/conduit-example.toml
index 66c105be..db0bbb77 100644
--- a/conduit-example.toml
+++ b/conduit-example.toml
@@ -41,3 +41,5 @@ trusted_servers = ["matrix.org"]
 #workers = 4 # default: cpu core count * 2
 
 address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy
+
+proxy = "none" # more examples can be found at src/database/proxy.rs:6
diff --git a/src/database.rs b/src/database.rs
index 64b5ee39..0ea4d784 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -6,6 +6,7 @@ pub mod appservice;
 pub mod globals;
 pub mod key_backups;
 pub mod media;
+pub mod proxy;
 pub mod pusher;
 pub mod rooms;
 pub mod sending;
@@ -28,6 +29,8 @@ use std::{
 };
 use tokio::sync::Semaphore;
 
+use self::proxy::ProxyConfig;
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct Config {
     server_name: Box<ServerName>,
@@ -85,124 +88,6 @@ pub type Engine = abstraction::SledEngine;
 #[cfg(feature = "rocksdb")]
 pub type Engine = abstraction::RocksDbEngine;
 
-#[derive(Clone, Debug, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum ProxyConfig {
-    None,
-    Global {
-        #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
-        url: reqwest::Url,
-    },
-    ByDomain(Vec<PartialProxyConfig>),
-}
-impl ProxyConfig {
-    pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> {
-        Ok(match self.clone() {
-            ProxyConfig::None => None,
-            ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?),
-            ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| {
-                proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
-            })),
-        })
-    }
-}
-impl Default for ProxyConfig {
-    fn default() -> Self {
-        ProxyConfig::None
-    }
-}
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct PartialProxyConfig {
-    #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
-    url: reqwest::Url,
-    #[serde(default)]
-    include: Vec<WildCardedDomain>,
-    #[serde(default)]
-    exclude: Vec<WildCardedDomain>,
-}
-impl PartialProxyConfig {
-    pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> {
-        let domain = url.domain()?;
-        let mut included_because = None; // most specific reason it was included
-        let mut excluded_because = None; // most specific reason it was excluded
-        if self.include.is_empty() {
-            // treat empty include list as `*`
-            included_because = Some(&WildCardedDomain::WildCard)
-        }
-        for wc_domain in &self.include {
-            if wc_domain.matches(domain) {
-                match included_because {
-                    Some(prev) if !wc_domain.more_specific_than(prev) => (),
-                    _ => included_because = Some(wc_domain),
-                }
-            }
-        }
-        for wc_domain in &self.exclude {
-            if wc_domain.matches(domain) {
-                match excluded_because {
-                    Some(prev) if !wc_domain.more_specific_than(prev) => (),
-                    _ => excluded_because = Some(wc_domain),
-                }
-            }
-        }
-        match (included_because, excluded_because) {
-            (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
-            (Some(_), None) => Some(&self.url),
-            _ => None,
-        }
-    }
-}
-
-/// A domain name, that optionally allows a * as its first subdomain.
-#[derive(Clone, Debug)]
-pub enum WildCardedDomain {
-    WildCard,
-    WildCarded(String),
-    Exact(String),
-}
-impl WildCardedDomain {
-    pub fn matches(&self, domain: &str) -> bool {
-        match self {
-            WildCardedDomain::WildCard => true,
-            WildCardedDomain::WildCarded(d) => domain.ends_with(d),
-            WildCardedDomain::Exact(d) => domain == d,
-        }
-    }
-    pub fn more_specific_than(&self, other: &Self) -> bool {
-        match (self, other) {
-            (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
-            (_, WildCardedDomain::WildCard) => true,
-            (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
-            (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
-                a != b && a.ends_with(b)
-            }
-            _ => false,
-        }
-    }
-}
-impl std::str::FromStr for WildCardedDomain {
-    type Err = std::convert::Infallible;
-    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        // maybe do some domain validation?
-        Ok(if s.starts_with("*.") {
-            WildCardedDomain::WildCarded(s[1..].to_owned())
-        } else if s == "*" {
-            WildCardedDomain::WildCarded("".to_owned())
-        } else {
-            WildCardedDomain::Exact(s.to_owned())
-        })
-    }
-}
-impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: serde::de::Deserializer<'de>,
-    {
-        crate::utils::deserialize_from_str(deserializer)
-    }
-}
-
 pub struct Database {
     pub globals: globals::Globals,
     pub users: users::Users,
diff --git a/src/database/proxy.rs b/src/database/proxy.rs
new file mode 100644
index 00000000..78e9d2bf
--- /dev/null
+++ b/src/database/proxy.rs
@@ -0,0 +1,146 @@
+use reqwest::{Proxy, Url};
+use serde::Deserialize;
+
+use crate::Result;
+
+/// ## Examples:
+/// - No proxy (default):
+/// ```toml
+/// proxy ="none"
+/// ```
+/// - Global proxy
+/// ```toml
+/// [proxy]
+/// global = { url = "socks5h://localhost:9050" }
+/// ```
+/// - Proxy some domains
+/// ```toml
+/// [proxy]
+/// [[proxy.by_domain]]
+/// url = "socks5h://localhost:9050"
+/// include = ["*.onion", "matrix.myspecial.onion"]
+/// exclude = ["*.myspecial.onion"]
+/// ```
+/// ## Include vs. Exclude
+/// If include is an empty list, it is assumed to be `["*"]`.
+///
+/// If a domain matches both the exclude and include list, the proxy will only be used if it was
+/// included because of a more specific rule than it was excluded. In the above example, the proxy
+/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
+#[derive(Clone, Debug, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ProxyConfig {
+    None,
+    Global {
+        #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
+        url: Url,
+    },
+    ByDomain(Vec<PartialProxyConfig>),
+}
+impl ProxyConfig {
+    pub fn to_proxy(&self) -> Result<Option<Proxy>> {
+        Ok(match self.clone() {
+            ProxyConfig::None => None,
+            ProxyConfig::Global { url } => Some(Proxy::all(url)?),
+            ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
+                proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
+            })),
+        })
+    }
+}
+impl Default for ProxyConfig {
+    fn default() -> Self {
+        ProxyConfig::None
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct PartialProxyConfig {
+    #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
+    url: Url,
+    #[serde(default)]
+    include: Vec<WildCardedDomain>,
+    #[serde(default)]
+    exclude: Vec<WildCardedDomain>,
+}
+impl PartialProxyConfig {
+    pub fn for_url(&self, url: &Url) -> Option<&Url> {
+        let domain = url.domain()?;
+        let mut included_because = None; // most specific reason it was included
+        let mut excluded_because = None; // most specific reason it was excluded
+        if self.include.is_empty() {
+            // treat empty include list as `*`
+            included_because = Some(&WildCardedDomain::WildCard)
+        }
+        for wc_domain in &self.include {
+            if wc_domain.matches(domain) {
+                match included_because {
+                    Some(prev) if !wc_domain.more_specific_than(prev) => (),
+                    _ => included_because = Some(wc_domain),
+                }
+            }
+        }
+        for wc_domain in &self.exclude {
+            if wc_domain.matches(domain) {
+                match excluded_because {
+                    Some(prev) if !wc_domain.more_specific_than(prev) => (),
+                    _ => excluded_because = Some(wc_domain),
+                }
+            }
+        }
+        match (included_because, excluded_because) {
+            (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
+            (Some(_), None) => Some(&self.url),
+            _ => None,
+        }
+    }
+}
+
+/// A domain name, that optionally allows a * as its first subdomain.
+#[derive(Clone, Debug)]
+pub enum WildCardedDomain {
+    WildCard,
+    WildCarded(String),
+    Exact(String),
+}
+impl WildCardedDomain {
+    pub fn matches(&self, domain: &str) -> bool {
+        match self {
+            WildCardedDomain::WildCard => true,
+            WildCardedDomain::WildCarded(d) => domain.ends_with(d),
+            WildCardedDomain::Exact(d) => domain == d,
+        }
+    }
+    pub fn more_specific_than(&self, other: &Self) -> bool {
+        match (self, other) {
+            (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
+            (_, WildCardedDomain::WildCard) => true,
+            (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
+            (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
+                a != b && a.ends_with(b)
+            }
+            _ => false,
+        }
+    }
+}
+impl std::str::FromStr for WildCardedDomain {
+    type Err = std::convert::Infallible;
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        // maybe do some domain validation?
+        Ok(if s.starts_with("*.") {
+            WildCardedDomain::WildCarded(s[1..].to_owned())
+        } else if s == "*" {
+            WildCardedDomain::WildCarded("".to_owned())
+        } else {
+            WildCardedDomain::Exact(s.to_owned())
+        })
+    }
+}
+impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::de::Deserializer<'de>,
+    {
+        crate::utils::deserialize_from_str(deserializer)
+    }
+}