strong-type URL for URL previews to Url type

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-11-01 00:54:21 -04:00
parent 8ed9d49b73
commit 240c78e810
No known key found for this signature in database
3 changed files with 46 additions and 36 deletions

View file

@ -11,6 +11,7 @@ use conduit_service::{
media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH}, media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH},
Services, Services,
}; };
use reqwest::Url;
use ruma::{ use ruma::{
api::client::{ api::client::{
authenticated_media::{ authenticated_media::{
@ -165,23 +166,33 @@ pub(crate) async fn get_media_preview_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let url = &body.url; let url = &body.url;
if !services.media.url_preview_allowed(url) { let url = Url::parse(&body.url).map_err(|e| {
err!(Request(InvalidParam(
debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}")
)))
})?;
if !services.media.url_preview_allowed(&url) {
return Err!(Request(Forbidden( return Err!(Request(Forbidden(
debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
))); )));
} }
let preview = services.media.get_url_preview(url).await.map_err(|error| { let preview = services
err!(Request(Unknown( .media
debug_error!(%sender_user, %url, ?error, "Failed to fetch URL preview.") .get_url_preview(&url)
))) .await
})?; .map_err(|error| {
err!(Request(Unknown(
debug_error!(%sender_user, %url, "Failed to fetch URL preview: {error}")
)))
})?;
serde_json::value::to_raw_value(&preview) serde_json::value::to_raw_value(&preview)
.map(get_media_preview::v1::Response::from_raw_value) .map(get_media_preview::v1::Response::from_raw_value)
.map_err(|error| { .map_err(|error| {
err!(Request(Unknown( err!(Request(Unknown(
debug_error!(%sender_user, %url, ?error, "Failed to parse URL preview.") debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}")
))) )))
}) })
} }

View file

@ -8,6 +8,7 @@ use conduit::{
Err, Result, Err, Result,
}; };
use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN}; use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN};
use reqwest::Url;
use ruma::{ use ruma::{
api::client::media::{ api::client::media::{
create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config, create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
@ -55,25 +56,31 @@ pub(crate) async fn get_media_preview_legacy_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let url = &body.url; let url = &body.url;
if !services.media.url_preview_allowed(url) { let url = Url::parse(&body.url).map_err(|e| {
err!(Request(InvalidParam(
debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}")
)))
})?;
if !services.media.url_preview_allowed(&url) {
return Err!(Request(Forbidden( return Err!(Request(Forbidden(
debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
))); )));
} }
let preview = services.media.get_url_preview(url).await.map_err(|e| { let preview = services.media.get_url_preview(&url).await.map_err(|e| {
err!(Request(Unknown( err!(Request(Unknown(
debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}") debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}")
))) )))
})?; })?;
let res = serde_json::value::to_raw_value(&preview).map_err(|e| { serde_json::value::to_raw_value(&preview)
err!(Request(Unknown( .map(get_media_preview::v3::Response::from_raw_value)
debug_error!(%sender_user, %url, "Failed to parse a URL preview: {e}") .map_err(|error| {
))) err!(Request(Unknown(
})?; debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}")
)))
Ok(get_media_preview::v3::Response::from_raw_value(res)) })
} }
/// # `GET /_matrix/media/v1/preview_url` /// # `GET /_matrix/media/v1/preview_url`

View file

@ -1,6 +1,6 @@
use std::{io::Cursor, time::SystemTime}; use std::{io::Cursor, time::SystemTime};
use conduit::{debug, utils, warn, Err, Result}; use conduit::{debug, utils, Err, Result};
use conduit_core::implement; use conduit_core::implement;
use image::ImageReader as ImgReader; use image::ImageReader as ImgReader;
use ipaddress::IPAddress; use ipaddress::IPAddress;
@ -70,30 +70,30 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
} }
#[implement(Service)] #[implement(Service)]
pub async fn get_url_preview(&self, url: &str) -> Result<UrlPreviewData> { pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Ok(preview) = self.db.get_url_preview(url).await { if let Ok(preview) = self.db.get_url_preview(url.as_str()).await {
return Ok(preview); return Ok(preview);
} }
// ensure that only one request is made per URL // ensure that only one request is made per URL
let _request_lock = self.url_preview_mutex.lock(url).await; let _request_lock = self.url_preview_mutex.lock(url.as_str()).await;
match self.db.get_url_preview(url).await { match self.db.get_url_preview(url.as_str()).await {
Ok(preview) => Ok(preview), Ok(preview) => Ok(preview),
Err(_) => self.request_url_preview(url).await, Err(_) => self.request_url_preview(url).await,
} }
} }
#[implement(Service)] #[implement(Service)]
async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> { async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
if let Ok(ip) = IPAddress::parse(url) { if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) {
if !self.services.globals.valid_cidr_range(&ip) { if !self.services.globals.valid_cidr_range(&ip) {
return Err!(BadServerResponse("Requesting from this address is forbidden")); return Err!(BadServerResponse("Requesting from this address is forbidden"));
} }
} }
let client = &self.services.client.url_preview; let client = &self.services.client.url_preview;
let response = client.head(url).send().await?; let response = client.head(url.as_str()).send().await?;
if let Some(remote_addr) = response.remote_addr() { if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
@ -111,12 +111,12 @@ async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
return Err!(Request(Unknown("Unknown Content-Type"))); return Err!(Request(Unknown("Unknown Content-Type")));
}; };
let data = match content_type { let data = match content_type {
html if html.starts_with("text/html") => self.download_html(url).await?, html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
img if img.starts_with("image/") => self.download_image(url).await?, img if img.starts_with("image/") => self.download_image(url.as_str()).await?,
_ => return Err!(Request(Unknown("Unsupported Content-Type"))), _ => return Err!(Request(Unknown("Unsupported Content-Type"))),
}; };
self.set_url_preview(url, &data).await?; self.set_url_preview(url.as_str(), &data).await?;
Ok(data) Ok(data)
} }
@ -159,15 +159,7 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
} }
#[implement(Service)] #[implement(Service)]
pub fn url_preview_allowed(&self, url_str: &str) -> bool { pub fn url_preview_allowed(&self, url: &Url) -> bool {
let url: Url = match Url::parse(url_str) {
Ok(u) => u,
Err(e) => {
warn!("Failed to parse URL from a str: {}", e);
return false;
},
};
if ["http", "https"] if ["http", "https"]
.iter() .iter()
.all(|&scheme| scheme != url.scheme().to_lowercase()) .all(|&scheme| scheme != url.scheme().to_lowercase())