move cidr_range_denylist from globals to client service

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-12-03 07:35:48 +00:00
parent 9d9f403ad5
commit c01b049910
6 changed files with 28 additions and 32 deletions

View file

@ -1,6 +1,7 @@
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use conduit::{Config, Result}; use conduit::{err, implement, trace, Config, Result};
use ipaddress::IPAddress;
use reqwest::redirect; use reqwest::redirect;
use crate::{resolver, service}; use crate::{resolver, service};
@ -15,6 +16,8 @@ pub struct Service {
pub sender: reqwest::Client, pub sender: reqwest::Client,
pub appservice: reqwest::Client, pub appservice: reqwest::Client,
pub pusher: reqwest::Client, pub pusher: reqwest::Client,
pub cidr_range_denylist: Vec<IPAddress>,
} }
impl crate::Service for Service { impl crate::Service for Service {
@ -86,6 +89,14 @@ impl crate::Service for Service {
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
.redirect(redirect::Policy::limited(2)) .redirect(redirect::Policy::limited(2))
.build()?, .build()?,
cidr_range_denylist: config
.ip_range_denylist
.iter()
.map(IPAddress::parse)
.inspect(|cidr| trace!("Denied CIDR range: {cidr:?}"))
.collect::<Result<_, String>>()
.map_err(|e| err!(Config("ip_range_denylist", e)))?,
})) }))
} }
@ -152,3 +163,12 @@ fn base(config: &Config) -> Result<reqwest::ClientBuilder> {
Ok(builder) Ok(builder)
} }
} }
#[inline]
#[must_use]
#[implement(Service)]
pub fn valid_cidr_range(&self, ip: &IPAddress) -> bool {
self.cidr_range_denylist
.iter()
.all(|cidr| !cidr.includes(ip))
}

View file

@ -7,9 +7,8 @@ use std::{
time::Instant, time::Instant,
}; };
use conduit::{err, error, trace, Config, Result}; use conduit::{error, Config, Result};
use data::Data; use data::Data;
use ipaddress::IPAddress;
use regex::RegexSet; use regex::RegexSet;
use ruma::{ use ruma::{
OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId,
@ -22,7 +21,6 @@ pub struct Service {
pub db: Data, pub db: Data,
pub config: Config, pub config: Config,
pub cidr_range_denylist: Vec<IPAddress>,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey>, jwt_decoding_key: Option<jsonwebtoken::DecodingKey>,
pub stable_room_versions: Vec<RoomVersionId>, pub stable_room_versions: Vec<RoomVersionId>,
pub unstable_room_versions: Vec<RoomVersionId>, pub unstable_room_versions: Vec<RoomVersionId>,
@ -59,14 +57,6 @@ impl crate::Service for Service {
// Experimental, partially supported room versions // Experimental, partially supported room versions
let unstable_room_versions = vec![RoomVersionId::V2, RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5]; let unstable_room_versions = vec![RoomVersionId::V2, RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
let cidr_range_denylist: Vec<_> = config
.ip_range_denylist
.iter()
.map(IPAddress::parse)
.inspect(|cidr| trace!("Denied CIDR range: {cidr:?}"))
.collect::<Result<_, String>>()
.map_err(|e| err!(Config("ip_range_denylist", e)))?;
let turn_secret = config let turn_secret = config
.turn_secret_file .turn_secret_file
.as_ref() .as_ref()
@ -95,7 +85,6 @@ impl crate::Service for Service {
let mut s = Self { let mut s = Self {
db, db,
config: config.clone(), config: config.clone(),
cidr_range_denylist,
jwt_decoding_key, jwt_decoding_key,
stable_room_versions, stable_room_versions,
unstable_room_versions, unstable_room_versions,
@ -255,17 +244,6 @@ impl Service {
} }
} }
#[inline]
pub fn valid_cidr_range(&self, ip: &IPAddress) -> bool {
for cidr in &self.cidr_range_denylist {
if cidr.includes(ip) {
return false;
}
}
true
}
/// checks if `user_id` is local to us via server_name comparison /// checks if `user_id` is local to us via server_name comparison
#[inline] #[inline]
pub fn user_is_local(&self, user_id: &UserId) -> bool { self.server_is_ours(user_id.server_name()) } pub fn user_is_local(&self, user_id: &UserId) -> bool { self.server_is_ours(user_id.server_name()) }

View file

@ -87,7 +87,7 @@ pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
#[implement(Service)] #[implement(Service)]
async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> { async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) { if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) {
if !self.services.globals.valid_cidr_range(&ip) { if !self.services.client.valid_cidr_range(&ip) {
return Err!(BadServerResponse("Requesting from this address is forbidden")); return Err!(BadServerResponse("Requesting from this address is forbidden"));
} }
} }
@ -97,7 +97,7 @@ async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Some(remote_addr) = response.remote_addr() { if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !self.services.globals.valid_cidr_range(&ip) { if !self.services.client.valid_cidr_range(&ip) {
return Err!(BadServerResponse("Requesting from this address is forbidden")); return Err!(BadServerResponse("Requesting from this address is forbidden"));
} }
} }

View file

@ -151,7 +151,7 @@ impl Service {
if let Some(url_host) = reqwest_request.url().host_str() { if let Some(url_host) = reqwest_request.url().host_str() {
trace!("Checking request URL for IP"); trace!("Checking request URL for IP");
if let Ok(ip) = IPAddress::parse(url_host) { if let Ok(ip) = IPAddress::parse(url_host) {
if !self.services.globals.valid_cidr_range(&ip) { if !self.services.client.valid_cidr_range(&ip) {
return Err!(BadServerResponse("Not allowed to send requests to this IP")); return Err!(BadServerResponse("Not allowed to send requests to this IP"));
} }
} }
@ -166,7 +166,7 @@ impl Service {
trace!("Checking response destination's IP"); trace!("Checking response destination's IP");
if let Some(remote_addr) = response.remote_addr() { if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !self.services.globals.valid_cidr_range(&ip) { if !self.services.client.valid_cidr_range(&ip) {
return Err!(BadServerResponse("Not allowed to send requests to this IP")); return Err!(BadServerResponse("Not allowed to send requests to this IP"));
} }
} }

View file

@ -358,7 +358,7 @@ impl super::Service {
} }
pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result<()> { pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result<()> {
if !self.services.globals.valid_cidr_range(ip) { if !self.services.client.valid_cidr_range(ip) {
return Err!(BadServerResponse("Not allowed to send requests to this IP")); return Err!(BadServerResponse("Not allowed to send requests to this IP"));
} }

View file

@ -9,7 +9,7 @@ use std::{fmt::Write, sync::Arc};
use conduit::{Result, Server}; use conduit::{Result, Server};
use self::{cache::Cache, dns::Resolver}; use self::{cache::Cache, dns::Resolver};
use crate::{client, globals, Dep}; use crate::{client, Dep};
pub struct Service { pub struct Service {
pub cache: Arc<Cache>, pub cache: Arc<Cache>,
@ -20,7 +20,6 @@ pub struct Service {
struct Services { struct Services {
server: Arc<Server>, server: Arc<Server>,
client: Dep<client::Service>, client: Dep<client::Service>,
globals: Dep<globals::Service>,
} }
impl crate::Service for Service { impl crate::Service for Service {
@ -33,7 +32,6 @@ impl crate::Service for Service {
services: Services { services: Services {
server: args.server.clone(), server: args.server.clone(),
client: args.depend::<client::Service>("client"), client: args.depend::<client::Service>("client"),
globals: args.depend::<globals::Service>("globals"),
}, },
})) }))
} }