From acbe3bfbdab2a6d9928155520a9cd47e50f61a2d Mon Sep 17 00:00:00 2001 From: strawberry Date: Mon, 22 Apr 2024 01:52:48 -0400 Subject: [PATCH] use global `valid_cidr_range` everywhere else Signed-off-by: strawberry --- src/api/client_server/media.rs | 32 ++++---------------------------- src/service/pusher/mod.rs | 32 +++++++------------------------- 2 files changed, 11 insertions(+), 53 deletions(-) diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 40e2b093..0e544aa6 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -692,20 +692,8 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result Result { if let Ok(ip) = IPAddress::parse(url) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Requesting from this address is forbidden", - )); - } + if !services().globals.valid_cidr_range(&ip) { + return Err(Error::BadServerResponse("Requesting from this address is forbidden")); } } @@ -714,20 +702,8 @@ async fn request_url_preview(url: &str) -> Result { if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Requesting from this address is forbidden", - )); - } + if !services().globals.valid_cidr_range(&ip) { + return Err(Error::BadServerResponse("Requesting from this address is forbidden")); } } } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 70d303ca..287148db 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -20,7 +20,7 @@ use ruma::{ serde::Raw, uint, RoomId, UInt, UserId, }; -use tracing::{debug, info, warn}; +use tracing::{info, trace, warn}; use crate::{services, Error, PduEvent, Result}; @@ -66,19 +66,10 @@ impl Service { let url = reqwest_request.url().clone(); if let Some(url_host) = url.host_str() { - debug!("Checking request URL for IP"); + trace!("Checking request URL for IP"); if let Ok(ip) = IPAddress::parse(url_host) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } + if !services().globals.valid_cidr_range(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); } } } @@ -94,20 +85,11 @@ impl Service { Ok(mut response) => { // reqwest::Response -> http::Response conversion - debug!("Checking response destination's IP"); + trace!("Checking response destination's IP"); if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } + if !services().globals.valid_cidr_range(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); } } }