de-global services() from api

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-16 08:05:25 +00:00
parent 463f1a1287
commit 8b6018d77d
61 changed files with 1485 additions and 1320 deletions

View file

@ -2,6 +2,7 @@
use std::{io::Cursor, sync::Arc, time::Duration};
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduit::{debug, error, utils::math::ruma_from_usize, warn};
use image::io::Reader as ImgReader;
@ -20,9 +21,8 @@ use crate::{
debug_warn,
service::{
media::{FileMeta, UrlPreviewData},
server_is_ours,
server_is_ours, Services,
},
services,
utils::{
self,
content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename},
@ -42,10 +42,10 @@ const CORP_CROSS_ORIGIN: &str = "cross-origin";
///
/// Returns max upload size.
pub(crate) async fn get_media_config_route(
_body: Ruma<get_media_config::v3::Request>,
State(services): State<crate::State>, _body: Ruma<get_media_config::v3::Request>,
) -> Result<get_media_config::v3::Response> {
Ok(get_media_config::v3::Response {
upload_size: ruma_from_usize(services().globals.config.max_request_size),
upload_size: ruma_from_usize(services.globals.config.max_request_size),
})
}
@ -57,9 +57,11 @@ pub(crate) async fn get_media_config_route(
///
/// Returns max upload size.
pub(crate) async fn get_media_config_v1_route(
body: Ruma<get_media_config::v3::Request>,
State(services): State<crate::State>, body: Ruma<get_media_config::v3::Request>,
) -> Result<RumaResponse<get_media_config::v3::Response>> {
get_media_config_route(body).await.map(RumaResponse)
get_media_config_route(State(services), body)
.await
.map(RumaResponse)
}
/// # `GET /_matrix/media/v3/preview_url`
@ -67,17 +69,18 @@ pub(crate) async fn get_media_config_v1_route(
/// Returns URL preview.
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")]
pub(crate) async fn get_media_preview_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_media_preview::v3::Request>,
) -> Result<get_media_preview::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let url = &body.url;
if !url_preview_allowed(url) {
if !url_preview_allowed(services, url) {
warn!(%sender_user, "URL is not allowed to be previewed: {url}");
return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed"));
}
match get_url_preview(url).await {
match get_url_preview(services, url).await {
Ok(preview) => {
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}");
@ -115,9 +118,10 @@ pub(crate) async fn get_media_preview_route(
/// Returns URL preview.
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")]
pub(crate) async fn get_media_preview_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_media_preview::v3::Request>,
) -> Result<RumaResponse<get_media_preview::v3::Response>> {
get_media_preview_route(InsecureClientIp(client), body)
get_media_preview_route(State(services), InsecureClientIp(client), body)
.await
.map(RumaResponse)
}
@ -130,17 +134,14 @@ pub(crate) async fn get_media_preview_v1_route(
/// - Media will be saved in the media/ directory
#[tracing::instrument(skip_all, fields(%client), name = "media_upload")]
pub(crate) async fn create_content_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_content::v3::Request>,
) -> Result<create_content::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mxc = format!(
"mxc://{}/{}",
services().globals.server_name(),
utils::random_string(MXC_LENGTH)
);
let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
services()
services
.media
.create(
Some(sender_user.clone()),
@ -178,9 +179,10 @@ pub(crate) async fn create_content_route(
/// - Media will be saved in the media/ directory
#[tracing::instrument(skip_all, fields(%client), name = "media_upload")]
pub(crate) async fn create_content_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_content::v3::Request>,
) -> Result<RumaResponse<create_content::v3::Response>> {
create_content_route(InsecureClientIp(client), body)
create_content_route(State(services), InsecureClientIp(client), body)
.await
.map(RumaResponse)
}
@ -195,7 +197,8 @@ pub(crate) async fn create_content_v1_route(
/// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content::v3::Request>,
) -> Result<get_content::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -203,7 +206,7 @@ pub(crate) async fn get_content_route(
content,
content_type,
content_disposition,
}) = services().media.get(&mxc).await?
}) = services.media.get(&mxc).await?
{
let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None));
let file = content.expect("content");
@ -217,6 +220,7 @@ pub(crate) async fn get_content_route(
})
} else if !server_is_ours(&body.server_name) && body.allow_remote {
let response = get_remote_content(
services,
&mxc,
&body.server_name,
body.media_id.clone(),
@ -261,9 +265,10 @@ pub(crate) async fn get_content_route(
/// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content::v3::Request>,
) -> Result<RumaResponse<get_content::v3::Response>> {
get_content_route(InsecureClientIp(client), body)
get_content_route(State(services), InsecureClientIp(client), body)
.await
.map(RumaResponse)
}
@ -278,7 +283,8 @@ pub(crate) async fn get_content_v1_route(
/// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_as_filename_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_as_filename::v3::Request>,
) -> Result<get_content_as_filename::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -286,7 +292,7 @@ pub(crate) async fn get_content_as_filename_route(
content,
content_type,
content_disposition,
}) = services().media.get(&mxc).await?
}) = services.media.get(&mxc).await?
{
let content_disposition = Some(make_content_disposition(
&content_type,
@ -304,6 +310,7 @@ pub(crate) async fn get_content_as_filename_route(
})
} else if !server_is_ours(&body.server_name) && body.allow_remote {
match get_remote_content(
services,
&mxc,
&body.server_name,
body.media_id.clone(),
@ -351,9 +358,10 @@ pub(crate) async fn get_content_as_filename_route(
/// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_as_filename_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_as_filename::v3::Request>,
) -> Result<RumaResponse<get_content_as_filename::v3::Response>> {
get_content_as_filename_route(InsecureClientIp(client), body)
get_content_as_filename_route(State(services), InsecureClientIp(client), body)
.await
.map(RumaResponse)
}
@ -368,7 +376,8 @@ pub(crate) async fn get_content_as_filename_v1_route(
/// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")]
pub(crate) async fn get_content_thumbnail_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_thumbnail::v3::Request>,
) -> Result<get_content_thumbnail::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -376,7 +385,7 @@ pub(crate) async fn get_content_thumbnail_route(
content,
content_type,
content_disposition,
}) = services()
}) = services
.media
.get_thumbnail(
&mxc,
@ -400,7 +409,7 @@ pub(crate) async fn get_content_thumbnail_route(
content_disposition,
})
} else if !server_is_ours(&body.server_name) && body.allow_remote {
if services()
if services
.globals
.prevent_media_downloads_from()
.contains(&body.server_name)
@ -411,7 +420,7 @@ pub(crate) async fn get_content_thumbnail_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
}
match services()
match services
.sending
.send_federation_request(
&body.server_name,
@ -430,7 +439,7 @@ pub(crate) async fn get_content_thumbnail_route(
.await
{
Ok(get_thumbnail_response) => {
services()
services
.media
.upload_thumbnail(
None,
@ -481,17 +490,19 @@ pub(crate) async fn get_content_thumbnail_route(
/// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")]
pub(crate) async fn get_content_thumbnail_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>,
State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_thumbnail::v3::Request>,
) -> Result<RumaResponse<get_content_thumbnail::v3::Response>> {
get_content_thumbnail_route(InsecureClientIp(client), body)
get_content_thumbnail_route(State(services), InsecureClientIp(client), body)
.await
.map(RumaResponse)
}
async fn get_remote_content(
mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
services: &Services, mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool,
timeout_ms: Duration,
) -> Result<get_content::v3::Response, Error> {
if services()
if services
.globals
.prevent_media_downloads_from()
.contains(&server_name.to_owned())
@ -502,7 +513,7 @@ async fn get_remote_content(
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
}
let content_response = services()
let content_response = services
.sending
.send_federation_request(
server_name,
@ -522,7 +533,7 @@ async fn get_remote_content(
None,
));
services()
services
.media
.create(
None,
@ -542,15 +553,11 @@ async fn get_remote_content(
})
}
async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
async fn download_image(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
let image = client.get(url).send().await?.bytes().await?;
let mxc = format!(
"mxc://{}/{}",
services().globals.server_name(),
utils::random_string(MXC_LENGTH)
);
let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
services()
services
.media
.create(None, &mxc, None, None, &image)
.await?;
@ -572,18 +579,18 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie
})
}
async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
async fn download_html(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
let mut response = client.get(url).send().await?;
let mut bytes: Vec<u8> = Vec::new();
while let Some(chunk) = response.chunk().await? {
bytes.extend_from_slice(&chunk);
if bytes.len() > services().globals.url_preview_max_spider_size() {
if bytes.len() > services.globals.url_preview_max_spider_size() {
debug!(
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
response body and assuming our necessary data is in this range.",
url,
services().globals.url_preview_max_spider_size()
services.globals.url_preview_max_spider_size()
);
break;
}
@ -595,7 +602,7 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
let mut data = match html.opengraph.images.first() {
None => UrlPreviewData::default(),
Some(obj) => download_image(client, &obj.url).await?,
Some(obj) => download_image(services, client, &obj.url).await?,
};
let props = html.opengraph.properties;
@ -607,19 +614,19 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
Ok(data)
}
async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
async fn request_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
if let Ok(ip) = IPAddress::parse(url) {
if !services().globals.valid_cidr_range(&ip) {
if !services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
}
}
let client = &services().globals.client.url_preview;
let client = &services.globals.client.url_preview;
let response = client.head(url).send().await?;
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !services().globals.valid_cidr_range(&ip) {
if !services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
}
}
@ -633,24 +640,24 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type"));
};
let data = match content_type {
html if html.starts_with("text/html") => download_html(client, url).await?,
img if img.starts_with("image/") => download_image(client, url).await?,
html if html.starts_with("text/html") => download_html(services, client, url).await?,
img if img.starts_with("image/") => download_image(services, client, url).await?,
_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
};
services().media.set_url_preview(url, &data).await?;
services.media.set_url_preview(url, &data).await?;
Ok(data)
}
async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
if let Some(preview) = services().media.get_url_preview(url).await {
async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
if let Some(preview) = services.media.get_url_preview(url).await {
return Ok(preview);
}
// ensure that only one request is made per URL
let mutex_request = Arc::clone(
services()
services
.media
.url_preview_mutex
.write()
@ -660,13 +667,13 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
);
let _request_lock = mutex_request.lock().await;
match services().media.get_url_preview(url).await {
match services.media.get_url_preview(url).await {
Some(preview) => Ok(preview),
None => request_url_preview(url).await,
None => request_url_preview(services, url).await,
}
}
fn url_preview_allowed(url_str: &str) -> bool {
fn url_preview_allowed(services: &Services, url_str: &str) -> bool {
let url: Url = match Url::parse(url_str) {
Ok(u) => u,
Err(e) => {
@ -691,10 +698,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
Some(h) => h.to_owned(),
};
let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist();
let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist();
let denylist_domain_explicit = services().globals.url_preview_domain_explicit_denylist();
let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist();
let allowlist_domain_contains = services.globals.url_preview_domain_contains_allowlist();
let allowlist_domain_explicit = services.globals.url_preview_domain_explicit_allowlist();
let denylist_domain_explicit = services.globals.url_preview_domain_explicit_denylist();
let allowlist_url_contains = services.globals.url_preview_url_contains_allowlist();
if allowlist_domain_contains.contains(&"*".to_owned())
|| allowlist_domain_explicit.contains(&"*".to_owned())
@ -735,7 +742,7 @@ fn url_preview_allowed(url_str: &str) -> bool {
}
// check root domain if available and if user has root domain checks
if services().globals.url_preview_check_root_domain() {
if services.globals.url_preview_check_root_domain() {
debug!("Checking root domain");
match host.split_once('.') {
None => return false,