add support for arbitrary proxies
This commit is contained in:
parent
cc9111059d
commit
b2d5516058
5 changed files with 168 additions and 5 deletions
121
src/database.rs
121
src/database.rs
|
@ -46,6 +46,8 @@ pub struct Config {
|
|||
allow_federation: bool,
|
||||
#[serde(default = "false_fn")]
|
||||
pub allow_jaeger: bool,
|
||||
#[serde(default)]
|
||||
proxy: ProxyConfig,
|
||||
jwt_secret: Option<String>,
|
||||
#[serde(default = "Vec::new")]
|
||||
trusted_servers: Vec<Box<ServerName>>,
|
||||
|
@ -83,6 +85,125 @@ pub type Engine = abstraction::SledEngine;
|
|||
#[cfg(feature = "rocksdb")]
|
||||
pub type Engine = abstraction::RocksDbEngine;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ProxyConfig {
|
||||
None,
|
||||
Global {
|
||||
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||
url: reqwest::Url,
|
||||
},
|
||||
ByDomain(Vec<PartialProxyConfig>),
|
||||
}
|
||||
impl ProxyConfig {
|
||||
pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> {
|
||||
Ok(match self.clone() {
|
||||
ProxyConfig::None => None,
|
||||
ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?),
|
||||
ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| {
|
||||
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
|
||||
})),
|
||||
})
|
||||
}
|
||||
}
|
||||
impl Default for ProxyConfig {
|
||||
fn default() -> Self {
|
||||
ProxyConfig::None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct PartialProxyConfig {
|
||||
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||
url: reqwest::Url,
|
||||
#[serde(default)]
|
||||
include: Vec<WildCardedDomain>,
|
||||
#[serde(default)]
|
||||
exclude: Vec<WildCardedDomain>,
|
||||
}
|
||||
impl PartialProxyConfig {
|
||||
pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> {
|
||||
let domain = url.domain()?;
|
||||
let mut included_because = None; // most specific reason it was included
|
||||
let mut excluded_because = None; // most specific reason it was excluded
|
||||
if self.include.is_empty() {
|
||||
// treat empty include list as `*`
|
||||
included_because = Some(&WildCardedDomain::WildCard)
|
||||
}
|
||||
for wc_domain in &self.include {
|
||||
if wc_domain.matches(domain) {
|
||||
match included_because {
|
||||
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||
_ => included_because = Some(wc_domain),
|
||||
}
|
||||
}
|
||||
}
|
||||
for wc_domain in &self.exclude {
|
||||
if wc_domain.matches(domain) {
|
||||
match excluded_because {
|
||||
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||
_ => excluded_because = Some(wc_domain),
|
||||
}
|
||||
}
|
||||
}
|
||||
match (included_because, excluded_because) {
|
||||
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
|
||||
(Some(_), None) => Some(&self.url),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A domain name, that optionally allows a * as its first subdomain.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum WildCardedDomain {
|
||||
WildCard,
|
||||
WildCarded(String),
|
||||
Exact(String),
|
||||
}
|
||||
impl WildCardedDomain {
|
||||
pub fn matches(&self, domain: &str) -> bool {
|
||||
match self {
|
||||
WildCardedDomain::WildCard => true,
|
||||
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
|
||||
WildCardedDomain::Exact(d) => domain == d,
|
||||
}
|
||||
}
|
||||
pub fn more_specific_than(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
||||
(_, WildCardedDomain::WildCard) => true,
|
||||
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
||||
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
|
||||
a != b && a.ends_with(b)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::str::FromStr for WildCardedDomain {
|
||||
type Err = std::convert::Infallible;
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
// maybe do some domain validation?
|
||||
Ok(if s.starts_with("*.") {
|
||||
WildCardedDomain::WildCarded(s[1..].to_owned())
|
||||
} else if s == "*" {
|
||||
WildCardedDomain::WildCarded("".to_owned())
|
||||
} else {
|
||||
WildCardedDomain::Exact(s.to_owned())
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
crate::utils::deserialize_from_str(deserializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
pub globals: globals::Globals,
|
||||
pub users: users::Users,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue