strong-type URL for URL previews to Url type
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
8ed9d49b73
commit
240c78e810
3 changed files with 46 additions and 36 deletions
|
@ -11,6 +11,7 @@ use conduit_service::{
|
||||||
media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH},
|
media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH},
|
||||||
Services,
|
Services,
|
||||||
};
|
};
|
||||||
|
use reqwest::Url;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
authenticated_media::{
|
authenticated_media::{
|
||||||
|
@ -165,23 +166,33 @@ pub(crate) async fn get_media_preview_route(
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let url = &body.url;
|
let url = &body.url;
|
||||||
if !services.media.url_preview_allowed(url) {
|
let url = Url::parse(&body.url).map_err(|e| {
|
||||||
|
err!(Request(InvalidParam(
|
||||||
|
debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}")
|
||||||
|
)))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if !services.media.url_preview_allowed(&url) {
|
||||||
return Err!(Request(Forbidden(
|
return Err!(Request(Forbidden(
|
||||||
debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
|
debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let preview = services.media.get_url_preview(url).await.map_err(|error| {
|
let preview = services
|
||||||
err!(Request(Unknown(
|
.media
|
||||||
debug_error!(%sender_user, %url, ?error, "Failed to fetch URL preview.")
|
.get_url_preview(&url)
|
||||||
)))
|
.await
|
||||||
})?;
|
.map_err(|error| {
|
||||||
|
err!(Request(Unknown(
|
||||||
|
debug_error!(%sender_user, %url, "Failed to fetch URL preview: {error}")
|
||||||
|
)))
|
||||||
|
})?;
|
||||||
|
|
||||||
serde_json::value::to_raw_value(&preview)
|
serde_json::value::to_raw_value(&preview)
|
||||||
.map(get_media_preview::v1::Response::from_raw_value)
|
.map(get_media_preview::v1::Response::from_raw_value)
|
||||||
.map_err(|error| {
|
.map_err(|error| {
|
||||||
err!(Request(Unknown(
|
err!(Request(Unknown(
|
||||||
debug_error!(%sender_user, %url, ?error, "Failed to parse URL preview.")
|
debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}")
|
||||||
)))
|
)))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ use conduit::{
|
||||||
Err, Result,
|
Err, Result,
|
||||||
};
|
};
|
||||||
use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN};
|
use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN};
|
||||||
|
use reqwest::Url;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::media::{
|
api::client::media::{
|
||||||
create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
|
create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
|
||||||
|
@ -55,25 +56,31 @@ pub(crate) async fn get_media_preview_legacy_route(
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let url = &body.url;
|
let url = &body.url;
|
||||||
if !services.media.url_preview_allowed(url) {
|
let url = Url::parse(&body.url).map_err(|e| {
|
||||||
|
err!(Request(InvalidParam(
|
||||||
|
debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}")
|
||||||
|
)))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if !services.media.url_preview_allowed(&url) {
|
||||||
return Err!(Request(Forbidden(
|
return Err!(Request(Forbidden(
|
||||||
debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
|
debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let preview = services.media.get_url_preview(url).await.map_err(|e| {
|
let preview = services.media.get_url_preview(&url).await.map_err(|e| {
|
||||||
err!(Request(Unknown(
|
err!(Request(Unknown(
|
||||||
debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}")
|
debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}")
|
||||||
)))
|
)))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
serde_json::value::to_raw_value(&preview)
|
||||||
err!(Request(Unknown(
|
.map(get_media_preview::v3::Response::from_raw_value)
|
||||||
debug_error!(%sender_user, %url, "Failed to parse a URL preview: {e}")
|
.map_err(|error| {
|
||||||
)))
|
err!(Request(Unknown(
|
||||||
})?;
|
debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}")
|
||||||
|
)))
|
||||||
Ok(get_media_preview::v3::Response::from_raw_value(res))
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/media/v1/preview_url`
|
/// # `GET /_matrix/media/v1/preview_url`
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{io::Cursor, time::SystemTime};
|
use std::{io::Cursor, time::SystemTime};
|
||||||
|
|
||||||
use conduit::{debug, utils, warn, Err, Result};
|
use conduit::{debug, utils, Err, Result};
|
||||||
use conduit_core::implement;
|
use conduit_core::implement;
|
||||||
use image::ImageReader as ImgReader;
|
use image::ImageReader as ImgReader;
|
||||||
use ipaddress::IPAddress;
|
use ipaddress::IPAddress;
|
||||||
|
@ -70,30 +70,30 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
pub async fn get_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
|
pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
|
||||||
if let Ok(preview) = self.db.get_url_preview(url).await {
|
if let Ok(preview) = self.db.get_url_preview(url.as_str()).await {
|
||||||
return Ok(preview);
|
return Ok(preview);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure that only one request is made per URL
|
// ensure that only one request is made per URL
|
||||||
let _request_lock = self.url_preview_mutex.lock(url).await;
|
let _request_lock = self.url_preview_mutex.lock(url.as_str()).await;
|
||||||
|
|
||||||
match self.db.get_url_preview(url).await {
|
match self.db.get_url_preview(url.as_str()).await {
|
||||||
Ok(preview) => Ok(preview),
|
Ok(preview) => Ok(preview),
|
||||||
Err(_) => self.request_url_preview(url).await,
|
Err(_) => self.request_url_preview(url).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
|
async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
|
||||||
if let Ok(ip) = IPAddress::parse(url) {
|
if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) {
|
||||||
if !self.services.globals.valid_cidr_range(&ip) {
|
if !self.services.globals.valid_cidr_range(&ip) {
|
||||||
return Err!(BadServerResponse("Requesting from this address is forbidden"));
|
return Err!(BadServerResponse("Requesting from this address is forbidden"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let client = &self.services.client.url_preview;
|
let client = &self.services.client.url_preview;
|
||||||
let response = client.head(url).send().await?;
|
let response = client.head(url.as_str()).send().await?;
|
||||||
|
|
||||||
if let Some(remote_addr) = response.remote_addr() {
|
if let Some(remote_addr) = response.remote_addr() {
|
||||||
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
|
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
|
||||||
|
@ -111,12 +111,12 @@ async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
|
||||||
return Err!(Request(Unknown("Unknown Content-Type")));
|
return Err!(Request(Unknown("Unknown Content-Type")));
|
||||||
};
|
};
|
||||||
let data = match content_type {
|
let data = match content_type {
|
||||||
html if html.starts_with("text/html") => self.download_html(url).await?,
|
html if html.starts_with("text/html") => self.download_html(url.as_str()).await?,
|
||||||
img if img.starts_with("image/") => self.download_image(url).await?,
|
img if img.starts_with("image/") => self.download_image(url.as_str()).await?,
|
||||||
_ => return Err!(Request(Unknown("Unsupported Content-Type"))),
|
_ => return Err!(Request(Unknown("Unsupported Content-Type"))),
|
||||||
};
|
};
|
||||||
|
|
||||||
self.set_url_preview(url, &data).await?;
|
self.set_url_preview(url.as_str(), &data).await?;
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
|
@ -159,15 +159,7 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
pub fn url_preview_allowed(&self, url_str: &str) -> bool {
|
pub fn url_preview_allowed(&self, url: &Url) -> bool {
|
||||||
let url: Url = match Url::parse(url_str) {
|
|
||||||
Ok(u) => u,
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to parse URL from a str: {}", e);
|
|
||||||
return false;
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
if ["http", "https"]
|
if ["http", "https"]
|
||||||
.iter()
|
.iter()
|
||||||
.all(|&scheme| scheme != url.scheme().to_lowercase())
|
.all(|&scheme| scheme != url.scheme().to_lowercase())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue