de-global services() from api
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
463f1a1287
commit
8b6018d77d
61 changed files with 1485 additions and 1320 deletions
|
@ -2,6 +2,7 @@
|
|||
|
||||
use std::{io::Cursor, sync::Arc, time::Duration};
|
||||
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use conduit::{debug, error, utils::math::ruma_from_usize, warn};
|
||||
use image::io::Reader as ImgReader;
|
||||
|
@ -20,9 +21,8 @@ use crate::{
|
|||
debug_warn,
|
||||
service::{
|
||||
media::{FileMeta, UrlPreviewData},
|
||||
server_is_ours,
|
||||
server_is_ours, Services,
|
||||
},
|
||||
services,
|
||||
utils::{
|
||||
self,
|
||||
content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename},
|
||||
|
@ -42,10 +42,10 @@ const CORP_CROSS_ORIGIN: &str = "cross-origin";
|
|||
///
|
||||
/// Returns max upload size.
|
||||
pub(crate) async fn get_media_config_route(
|
||||
_body: Ruma<get_media_config::v3::Request>,
|
||||
State(services): State<crate::State>, _body: Ruma<get_media_config::v3::Request>,
|
||||
) -> Result<get_media_config::v3::Response> {
|
||||
Ok(get_media_config::v3::Response {
|
||||
upload_size: ruma_from_usize(services().globals.config.max_request_size),
|
||||
upload_size: ruma_from_usize(services.globals.config.max_request_size),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -57,9 +57,11 @@ pub(crate) async fn get_media_config_route(
|
|||
///
|
||||
/// Returns max upload size.
|
||||
pub(crate) async fn get_media_config_v1_route(
|
||||
body: Ruma<get_media_config::v3::Request>,
|
||||
State(services): State<crate::State>, body: Ruma<get_media_config::v3::Request>,
|
||||
) -> Result<RumaResponse<get_media_config::v3::Response>> {
|
||||
get_media_config_route(body).await.map(RumaResponse)
|
||||
get_media_config_route(State(services), body)
|
||||
.await
|
||||
.map(RumaResponse)
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/media/v3/preview_url`
|
||||
|
@ -67,17 +69,18 @@ pub(crate) async fn get_media_config_v1_route(
|
|||
/// Returns URL preview.
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")]
|
||||
pub(crate) async fn get_media_preview_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_media_preview::v3::Request>,
|
||||
) -> Result<get_media_preview::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let url = &body.url;
|
||||
if !url_preview_allowed(url) {
|
||||
if !url_preview_allowed(services, url) {
|
||||
warn!(%sender_user, "URL is not allowed to be previewed: {url}");
|
||||
return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed"));
|
||||
}
|
||||
|
||||
match get_url_preview(url).await {
|
||||
match get_url_preview(services, url).await {
|
||||
Ok(preview) => {
|
||||
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
||||
error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}");
|
||||
|
@ -115,9 +118,10 @@ pub(crate) async fn get_media_preview_route(
|
|||
/// Returns URL preview.
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")]
|
||||
pub(crate) async fn get_media_preview_v1_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_media_preview::v3::Request>,
|
||||
) -> Result<RumaResponse<get_media_preview::v3::Response>> {
|
||||
get_media_preview_route(InsecureClientIp(client), body)
|
||||
get_media_preview_route(State(services), InsecureClientIp(client), body)
|
||||
.await
|
||||
.map(RumaResponse)
|
||||
}
|
||||
|
@ -130,17 +134,14 @@ pub(crate) async fn get_media_preview_v1_route(
|
|||
/// - Media will be saved in the media/ directory
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_upload")]
|
||||
pub(crate) async fn create_content_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<create_content::v3::Request>,
|
||||
) -> Result<create_content::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
|
||||
|
||||
services()
|
||||
services
|
||||
.media
|
||||
.create(
|
||||
Some(sender_user.clone()),
|
||||
|
@ -178,9 +179,10 @@ pub(crate) async fn create_content_route(
|
|||
/// - Media will be saved in the media/ directory
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_upload")]
|
||||
pub(crate) async fn create_content_v1_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<create_content::v3::Request>,
|
||||
) -> Result<RumaResponse<create_content::v3::Response>> {
|
||||
create_content_route(InsecureClientIp(client), body)
|
||||
create_content_route(State(services), InsecureClientIp(client), body)
|
||||
.await
|
||||
.map(RumaResponse)
|
||||
}
|
||||
|
@ -195,7 +197,8 @@ pub(crate) async fn create_content_v1_route(
|
|||
/// seconds
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
|
||||
pub(crate) async fn get_content_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_content::v3::Request>,
|
||||
) -> Result<get_content::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
|
@ -203,7 +206,7 @@ pub(crate) async fn get_content_route(
|
|||
content,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}) = services().media.get(&mxc).await?
|
||||
}) = services.media.get(&mxc).await?
|
||||
{
|
||||
let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None));
|
||||
let file = content.expect("content");
|
||||
|
@ -217,6 +220,7 @@ pub(crate) async fn get_content_route(
|
|||
})
|
||||
} else if !server_is_ours(&body.server_name) && body.allow_remote {
|
||||
let response = get_remote_content(
|
||||
services,
|
||||
&mxc,
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
|
@ -261,9 +265,10 @@ pub(crate) async fn get_content_route(
|
|||
/// seconds
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
|
||||
pub(crate) async fn get_content_v1_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_content::v3::Request>,
|
||||
) -> Result<RumaResponse<get_content::v3::Response>> {
|
||||
get_content_route(InsecureClientIp(client), body)
|
||||
get_content_route(State(services), InsecureClientIp(client), body)
|
||||
.await
|
||||
.map(RumaResponse)
|
||||
}
|
||||
|
@ -278,7 +283,8 @@ pub(crate) async fn get_content_v1_route(
|
|||
/// seconds
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
|
||||
pub(crate) async fn get_content_as_filename_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_content_as_filename::v3::Request>,
|
||||
) -> Result<get_content_as_filename::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
|
@ -286,7 +292,7 @@ pub(crate) async fn get_content_as_filename_route(
|
|||
content,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}) = services().media.get(&mxc).await?
|
||||
}) = services.media.get(&mxc).await?
|
||||
{
|
||||
let content_disposition = Some(make_content_disposition(
|
||||
&content_type,
|
||||
|
@ -304,6 +310,7 @@ pub(crate) async fn get_content_as_filename_route(
|
|||
})
|
||||
} else if !server_is_ours(&body.server_name) && body.allow_remote {
|
||||
match get_remote_content(
|
||||
services,
|
||||
&mxc,
|
||||
&body.server_name,
|
||||
body.media_id.clone(),
|
||||
|
@ -351,9 +358,10 @@ pub(crate) async fn get_content_as_filename_route(
|
|||
/// seconds
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
|
||||
pub(crate) async fn get_content_as_filename_v1_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_content_as_filename::v3::Request>,
|
||||
) -> Result<RumaResponse<get_content_as_filename::v3::Response>> {
|
||||
get_content_as_filename_route(InsecureClientIp(client), body)
|
||||
get_content_as_filename_route(State(services), InsecureClientIp(client), body)
|
||||
.await
|
||||
.map(RumaResponse)
|
||||
}
|
||||
|
@ -368,7 +376,8 @@ pub(crate) async fn get_content_as_filename_v1_route(
|
|||
/// seconds
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")]
|
||||
pub(crate) async fn get_content_thumbnail_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
) -> Result<get_content_thumbnail::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
|
@ -376,7 +385,7 @@ pub(crate) async fn get_content_thumbnail_route(
|
|||
content,
|
||||
content_type,
|
||||
content_disposition,
|
||||
}) = services()
|
||||
}) = services
|
||||
.media
|
||||
.get_thumbnail(
|
||||
&mxc,
|
||||
|
@ -400,7 +409,7 @@ pub(crate) async fn get_content_thumbnail_route(
|
|||
content_disposition,
|
||||
})
|
||||
} else if !server_is_ours(&body.server_name) && body.allow_remote {
|
||||
if services()
|
||||
if services
|
||||
.globals
|
||||
.prevent_media_downloads_from()
|
||||
.contains(&body.server_name)
|
||||
|
@ -411,7 +420,7 @@ pub(crate) async fn get_content_thumbnail_route(
|
|||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
|
||||
match services()
|
||||
match services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&body.server_name,
|
||||
|
@ -430,7 +439,7 @@ pub(crate) async fn get_content_thumbnail_route(
|
|||
.await
|
||||
{
|
||||
Ok(get_thumbnail_response) => {
|
||||
services()
|
||||
services
|
||||
.media
|
||||
.upload_thumbnail(
|
||||
None,
|
||||
|
@ -481,17 +490,19 @@ pub(crate) async fn get_content_thumbnail_route(
|
|||
/// seconds
|
||||
#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")]
|
||||
pub(crate) async fn get_content_thumbnail_v1_route(
|
||||
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
) -> Result<RumaResponse<get_content_thumbnail::v3::Response>> {
|
||||
get_content_thumbnail_route(InsecureClientIp(client), body)
|
||||
get_content_thumbnail_route(State(services), InsecureClientIp(client), body)
|
||||
.await
|
||||
.map(RumaResponse)
|
||||
}
|
||||
|
||||
async fn get_remote_content(
|
||||
mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
|
||||
services: &Services, mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool,
|
||||
timeout_ms: Duration,
|
||||
) -> Result<get_content::v3::Response, Error> {
|
||||
if services()
|
||||
if services
|
||||
.globals
|
||||
.prevent_media_downloads_from()
|
||||
.contains(&server_name.to_owned())
|
||||
|
@ -502,7 +513,7 @@ async fn get_remote_content(
|
|||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
|
||||
let content_response = services()
|
||||
let content_response = services
|
||||
.sending
|
||||
.send_federation_request(
|
||||
server_name,
|
||||
|
@ -522,7 +533,7 @@ async fn get_remote_content(
|
|||
None,
|
||||
));
|
||||
|
||||
services()
|
||||
services
|
||||
.media
|
||||
.create(
|
||||
None,
|
||||
|
@ -542,15 +553,11 @@ async fn get_remote_content(
|
|||
})
|
||||
}
|
||||
|
||||
async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
|
||||
async fn download_image(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
|
||||
let image = client.get(url).send().await?.bytes().await?;
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
|
||||
|
||||
services()
|
||||
services
|
||||
.media
|
||||
.create(None, &mxc, None, None, &image)
|
||||
.await?;
|
||||
|
@ -572,18 +579,18 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie
|
|||
})
|
||||
}
|
||||
|
||||
async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
|
||||
async fn download_html(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
|
||||
let mut response = client.get(url).send().await?;
|
||||
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
while let Some(chunk) = response.chunk().await? {
|
||||
bytes.extend_from_slice(&chunk);
|
||||
if bytes.len() > services().globals.url_preview_max_spider_size() {
|
||||
if bytes.len() > services.globals.url_preview_max_spider_size() {
|
||||
debug!(
|
||||
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
|
||||
response body and assuming our necessary data is in this range.",
|
||||
url,
|
||||
services().globals.url_preview_max_spider_size()
|
||||
services.globals.url_preview_max_spider_size()
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
@ -595,7 +602,7 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
|
|||
|
||||
let mut data = match html.opengraph.images.first() {
|
||||
None => UrlPreviewData::default(),
|
||||
Some(obj) => download_image(client, &obj.url).await?,
|
||||
Some(obj) => download_image(services, client, &obj.url).await?,
|
||||
};
|
||||
|
||||
let props = html.opengraph.properties;
|
||||
|
@ -607,19 +614,19 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
|
|||
Ok(data)
|
||||
}
|
||||
|
||||
async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
|
||||
async fn request_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
|
||||
if let Ok(ip) = IPAddress::parse(url) {
|
||||
if !services().globals.valid_cidr_range(&ip) {
|
||||
if !services.globals.valid_cidr_range(&ip) {
|
||||
return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
|
||||
}
|
||||
}
|
||||
|
||||
let client = &services().globals.client.url_preview;
|
||||
let client = &services.globals.client.url_preview;
|
||||
let response = client.head(url).send().await?;
|
||||
|
||||
if let Some(remote_addr) = response.remote_addr() {
|
||||
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
|
||||
if !services().globals.valid_cidr_range(&ip) {
|
||||
if !services.globals.valid_cidr_range(&ip) {
|
||||
return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
|
||||
}
|
||||
}
|
||||
|
@ -633,24 +640,24 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
|
|||
return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type"));
|
||||
};
|
||||
let data = match content_type {
|
||||
html if html.starts_with("text/html") => download_html(client, url).await?,
|
||||
img if img.starts_with("image/") => download_image(client, url).await?,
|
||||
html if html.starts_with("text/html") => download_html(services, client, url).await?,
|
||||
img if img.starts_with("image/") => download_image(services, client, url).await?,
|
||||
_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
|
||||
};
|
||||
|
||||
services().media.set_url_preview(url, &data).await?;
|
||||
services.media.set_url_preview(url, &data).await?;
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
|
||||
if let Some(preview) = services().media.get_url_preview(url).await {
|
||||
async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
|
||||
if let Some(preview) = services.media.get_url_preview(url).await {
|
||||
return Ok(preview);
|
||||
}
|
||||
|
||||
// ensure that only one request is made per URL
|
||||
let mutex_request = Arc::clone(
|
||||
services()
|
||||
services
|
||||
.media
|
||||
.url_preview_mutex
|
||||
.write()
|
||||
|
@ -660,13 +667,13 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
|
|||
);
|
||||
let _request_lock = mutex_request.lock().await;
|
||||
|
||||
match services().media.get_url_preview(url).await {
|
||||
match services.media.get_url_preview(url).await {
|
||||
Some(preview) => Ok(preview),
|
||||
None => request_url_preview(url).await,
|
||||
None => request_url_preview(services, url).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn url_preview_allowed(url_str: &str) -> bool {
|
||||
fn url_preview_allowed(services: &Services, url_str: &str) -> bool {
|
||||
let url: Url = match Url::parse(url_str) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
|
@ -691,10 +698,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
|||
Some(h) => h.to_owned(),
|
||||
};
|
||||
|
||||
let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist();
|
||||
let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist();
|
||||
let denylist_domain_explicit = services().globals.url_preview_domain_explicit_denylist();
|
||||
let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist();
|
||||
let allowlist_domain_contains = services.globals.url_preview_domain_contains_allowlist();
|
||||
let allowlist_domain_explicit = services.globals.url_preview_domain_explicit_allowlist();
|
||||
let denylist_domain_explicit = services.globals.url_preview_domain_explicit_denylist();
|
||||
let allowlist_url_contains = services.globals.url_preview_url_contains_allowlist();
|
||||
|
||||
if allowlist_domain_contains.contains(&"*".to_owned())
|
||||
|| allowlist_domain_explicit.contains(&"*".to_owned())
|
||||
|
@ -735,7 +742,7 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
|||
}
|
||||
|
||||
// check root domain if available and if user has root domain checks
|
||||
if services().globals.url_preview_check_root_domain() {
|
||||
if services.globals.url_preview_check_root_domain() {
|
||||
debug!("Checking root domain");
|
||||
match host.split_once('.') {
|
||||
None => return false,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue