diff --git a/Cargo.toml b/Cargo.toml index e406c9e1..043790f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -213,7 +213,7 @@ features = [ [workspace.dependencies.futures] version = "0.3.30" default-features = false -features = ["std"] +features = ["std", "async-await"] [workspace.dependencies.tokio] version = "1.40.0" diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 088b891a..281bf2a2 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -2,13 +2,13 @@ use std::cmp; use axum::extract::State; use conduit::{ - is_equal_to, utils::{IterStream, ReadyExt}, - Err, PduCount, Result, + PduCount, Result, }; use futures::{FutureExt, StreamExt}; use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch}; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/backfill/` @@ -18,24 +18,14 @@ use crate::Ruma; pub(crate) async fn get_backfill_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let until = body .v @@ -70,7 +60,6 @@ pub(crate) async fn get_backfill_route( .state_accessor .server_can_see_event(origin, &pdu.room_id, &pdu.event_id) .await - .is_ok_and(is_equal_to!(true)) { return None; } diff --git a/src/api/server/event.rs b/src/api/server/event.rs index 64ce3e40..29d5d870 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,7 +1,8 @@ use axum::extract::State; -use conduit::{err, Err, Result}; +use conduit::{err, Result}; use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId}; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/event/{eventId}` @@ -20,35 +21,21 @@ pub(crate) async fn get_event_route( .await .map_err(|_| err!(Request(NotFound("Event not found."))))?; - let room_id_str = event + let room_id: &RoomId = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| err!(Database("Invalid event in database.")))?; + .ok_or_else(|| err!(Database("Invalid event in database.")))? + .try_into() + .map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - let room_id = - <&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - - if !services - .rooms - .state_accessor - .is_world_readable(room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); - } - - if !services - .rooms - .state_accessor - .server_can_see_event(body.origin(), room_id, &body.event_id) - .await? - { - return Err!(Request(Forbidden("Server is not allowed to see event."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id, + event_id: Some(&body.event_id), } + .check() + .await?; Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 8fe96f81..faeb2b99 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -8,6 +8,7 @@ use ruma::{ RoomId, }; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` @@ -18,24 +19,14 @@ use crate::Ruma; pub(crate) async fn get_event_authorization_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let event = services .rooms diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index aee4fbe9..7dff44dc 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -5,6 +5,7 @@ use ruma::{ CanonicalJsonValue, EventId, RoomId, }; +use super::AccessCheck; use crate::Ruma; /// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` @@ -13,29 +14,16 @@ use crate::Ruma; pub(crate) async fn get_missing_events_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; - let limit = body - .limit - .try_into() - .expect("UInt could not be converted to usize"); + let limit = body.limit.try_into()?; let mut queued_events = body.latest_events.clone(); // the vec will never have more entries the limit @@ -70,7 +58,7 @@ pub(crate) async fn get_missing_events_route( .rooms .state_accessor .server_can_see_event(body.origin(), &body.room_id, &queued_events[i]) - .await? + .await { i = i.saturating_add(1); continue; diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs index 9a184f23..9b7d91cb 100644 --- a/src/api/server/mod.rs +++ b/src/api/server/mod.rs @@ -41,3 +41,6 @@ pub(super) use state_ids::*; pub(super) use user::*; pub(super) use version::*; pub(super) use well_known::*; + +mod utils; +use utils::AccessCheck; diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 59bb6c7b..06a44a99 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,10 +1,11 @@ use std::borrow::Borrow; use axum::extract::State; -use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; +use conduit::{err, result::LogErr, utils::IterStream, Result}; use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::api::federation::event::get_room_state; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/state/{roomId}` @@ -13,24 +14,14 @@ use crate::Ruma; pub(crate) async fn get_room_state_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let shortstatehash = services .rooms diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 957a2a86..52d8e7cc 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,11 +1,12 @@ use std::borrow::Borrow; use axum::extract::State; -use conduit::{err, Err}; +use conduit::{err, Result}; use futures::StreamExt; use ruma::api::federation::event::get_room_state_ids; -use crate::{Result, Ruma}; +use super::AccessCheck; +use crate::Ruma; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// @@ -14,24 +15,14 @@ use crate::{Result, Ruma}; pub(crate) async fn get_room_state_ids_route( State(services): State, body: Ruma, ) -> Result { - services - .rooms - .event_handler - .acl_check(body.origin(), &body.room_id) - .await?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id) - .await && !services - .rooms - .state_cache - .server_in_room(body.origin(), &body.room_id) - .await - { - return Err!(Request(Forbidden("Server is not in room."))); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let shortstatehash = services .rooms diff --git a/src/api/server/utils.rs b/src/api/server/utils.rs new file mode 100644 index 00000000..278465ca --- /dev/null +++ b/src/api/server/utils.rs @@ -0,0 +1,60 @@ +use conduit::{implement, is_false, Err, Result}; +use conduit_service::Services; +use futures::{future::OptionFuture, join, FutureExt}; +use ruma::{EventId, RoomId, ServerName}; + +pub(super) struct AccessCheck<'a> { + pub(super) services: &'a Services, + pub(super) origin: &'a ServerName, + pub(super) room_id: &'a RoomId, + pub(super) event_id: Option<&'a EventId>, +} + +#[implement(AccessCheck, params = "<'_>")] +pub(super) async fn check(&self) -> Result { + let acl_check = self + .services + .rooms + .event_handler + .acl_check(self.origin, self.room_id) + .map(|result| result.is_ok()); + + let world_readable = self + .services + .rooms + .state_accessor + .is_world_readable(self.room_id); + + let server_in_room = self + .services + .rooms + .state_cache + .server_in_room(self.origin, self.room_id); + + let server_can_see: OptionFuture<_> = self + .event_id + .map(|event_id| { + self.services + .rooms + .state_accessor + .server_can_see_event(self.origin, self.room_id, event_id) + }) + .into(); + + let (world_readable, server_in_room, server_can_see, acl_check) = + join!(world_readable, server_in_room, server_can_see, acl_check); + + if !acl_check { + return Err!(Request(Forbidden("Server access denied."))); + } + + if !world_readable && !server_in_room { + return Err!(Request(Forbidden("Server is not in room."))); + } + + if server_can_see.is_some_and(is_false!()) { + return Err!(Request(Forbidden("Server is not allowed to see event."))); + } + + Ok(()) +}