From 240c78e8101da122e35986c7c1414b4f2d655d31 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 1 Nov 2024 00:54:21 -0400 Subject: [PATCH] strong-type URL for URL previews to Url type Signed-off-by: strawberry --- src/api/client/media.rs | 25 ++++++++++++++++++------- src/api/client/media_legacy.rs | 25 ++++++++++++++++--------- src/service/media/preview.rs | 32 ++++++++++++-------------------- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 12012711..71693618 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -11,6 +11,7 @@ use conduit_service::{ media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH}, Services, }; +use reqwest::Url; use ruma::{ api::client::{ authenticated_media::{ @@ -165,23 +166,33 @@ pub(crate) async fn get_media_preview_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !services.media.url_preview_allowed(url) { + let url = Url::parse(&body.url).map_err(|e| { + err!(Request(InvalidParam( + debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}") + ))) + })?; + + if !services.media.url_preview_allowed(&url) { return Err!(Request(Forbidden( debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") ))); } - let preview = services.media.get_url_preview(url).await.map_err(|error| { - err!(Request(Unknown( - debug_error!(%sender_user, %url, ?error, "Failed to fetch URL preview.") - ))) - })?; + let preview = services + .media + .get_url_preview(&url) + .await + .map_err(|error| { + err!(Request(Unknown( + debug_error!(%sender_user, %url, "Failed to fetch URL preview: {error}") + ))) + })?; serde_json::value::to_raw_value(&preview) .map(get_media_preview::v1::Response::from_raw_value) .map_err(|error| { err!(Request(Unknown( - debug_error!(%sender_user, %url, ?error, "Failed to parse URL preview.") + debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}") ))) }) } diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index e87b9a2b..f6837462 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -8,6 +8,7 @@ use conduit::{ Err, Result, }; use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN}; +use reqwest::Url; use ruma::{ api::client::media::{ create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config, @@ -55,25 +56,31 @@ pub(crate) async fn get_media_preview_legacy_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !services.media.url_preview_allowed(url) { + let url = Url::parse(&body.url).map_err(|e| { + err!(Request(InvalidParam( + debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}") + ))) + })?; + + if !services.media.url_preview_allowed(&url) { return Err!(Request(Forbidden( debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") ))); } - let preview = services.media.get_url_preview(url).await.map_err(|e| { + let preview = services.media.get_url_preview(&url).await.map_err(|e| { err!(Request(Unknown( debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}") ))) })?; - let res = serde_json::value::to_raw_value(&preview).map_err(|e| { - err!(Request(Unknown( - debug_error!(%sender_user, %url, "Failed to parse a URL preview: {e}") - ))) - })?; - - Ok(get_media_preview::v3::Response::from_raw_value(res)) + serde_json::value::to_raw_value(&preview) + .map(get_media_preview::v3::Response::from_raw_value) + .map_err(|error| { + err!(Request(Unknown( + debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}") + ))) + }) } /// # `GET /_matrix/media/v1/preview_url` diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 6b147383..acc9d8ed 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -1,6 +1,6 @@ use std::{io::Cursor, time::SystemTime}; -use conduit::{debug, utils, warn, Err, Result}; +use conduit::{debug, utils, Err, Result}; use conduit_core::implement; use image::ImageReader as ImgReader; use ipaddress::IPAddress; @@ -70,30 +70,30 @@ pub async fn download_image(&self, url: &str) -> Result { } #[implement(Service)] -pub async fn get_url_preview(&self, url: &str) -> Result { - if let Ok(preview) = self.db.get_url_preview(url).await { +pub async fn get_url_preview(&self, url: &Url) -> Result { + if let Ok(preview) = self.db.get_url_preview(url.as_str()).await { return Ok(preview); } // ensure that only one request is made per URL - let _request_lock = self.url_preview_mutex.lock(url).await; + let _request_lock = self.url_preview_mutex.lock(url.as_str()).await; - match self.db.get_url_preview(url).await { + match self.db.get_url_preview(url.as_str()).await { Ok(preview) => Ok(preview), Err(_) => self.request_url_preview(url).await, } } #[implement(Service)] -async fn request_url_preview(&self, url: &str) -> Result { - if let Ok(ip) = IPAddress::parse(url) { +async fn request_url_preview(&self, url: &Url) -> Result { + if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) { if !self.services.globals.valid_cidr_range(&ip) { return Err!(BadServerResponse("Requesting from this address is forbidden")); } } let client = &self.services.client.url_preview; - let response = client.head(url).send().await?; + let response = client.head(url.as_str()).send().await?; if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { @@ -111,12 +111,12 @@ async fn request_url_preview(&self, url: &str) -> Result { return Err!(Request(Unknown("Unknown Content-Type"))); }; let data = match content_type { - html if html.starts_with("text/html") => self.download_html(url).await?, - img if img.starts_with("image/") => self.download_image(url).await?, + html if html.starts_with("text/html") => self.download_html(url.as_str()).await?, + img if img.starts_with("image/") => self.download_image(url.as_str()).await?, _ => return Err!(Request(Unknown("Unsupported Content-Type"))), }; - self.set_url_preview(url, &data).await?; + self.set_url_preview(url.as_str(), &data).await?; Ok(data) } @@ -159,15 +159,7 @@ async fn download_html(&self, url: &str) -> Result { } #[implement(Service)] -pub fn url_preview_allowed(&self, url_str: &str) -> bool { - let url: Url = match Url::parse(url_str) { - Ok(u) => u, - Err(e) => { - warn!("Failed to parse URL from a str: {}", e); - return false; - }, - }; - +pub fn url_preview_allowed(&self, url: &Url) -> bool { if ["http", "https"] .iter() .all(|&scheme| scheme != url.scheme().to_lowercase())