From 02bc81863324a7e96467858c3ab3f1dd678176b8 Mon Sep 17 00:00:00 2001
From: strawberry <strawberry@puppygock.gay>
Date: Sun, 21 Jan 2024 18:18:21 -0500
Subject: [PATCH] match explicit URI to see if we should authenticate the user

first attempt at forcing an endpoint to be authenticated

Signed-off-by: strawberry <strawberry@puppygock.gay>
---
 src/api/ruma_wrapper/axum.rs | 51 +++++++++++++++++++++++++++++++-----
 1 file changed, 44 insertions(+), 7 deletions(-)

diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs
index 9cf1fe16..39cabaf7 100644
--- a/src/api/ruma_wrapper/axum.rs
+++ b/src/api/ruma_wrapper/axum.rs
@@ -23,6 +23,12 @@ use tracing::{debug, error, warn};
 use super::{Ruma, RumaResponse};
 use crate::{services, Error, Result};
 
+#[derive(Deserialize)]
+struct QueryParams {
+    access_token: Option<String>,
+    user_id: Option<String>,
+}
+
 #[async_trait]
 impl<T, S, B> FromRequest<S, B> for Ruma<T>
 where
@@ -34,12 +40,6 @@ where
     type Rejection = Error;
 
     async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
-        #[derive(Deserialize)]
-        struct QueryParams {
-            access_token: Option<String>,
-            user_id: Option<String>,
-        }
-
         let (mut parts, mut body) = match req.with_limited_body() {
             Ok(limited_req) => {
                 let (parts, body) = limited_req.into_parts();
@@ -263,7 +263,44 @@ where
                             }
                         }
                     }
-                    AuthScheme::None => (None, None, None, false),
+                    AuthScheme::None => match parts.uri.path() {
+                        // allow_public_room_directory_without_auth
+                        "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
+                            if !services()
+                                .globals
+                                .config
+                                .allow_public_room_directory_without_auth
+                            {
+                                let token = match token {
+                                    Some(token) => token,
+                                    _ => {
+                                        return Err(Error::BadRequest(
+                                            ErrorKind::MissingToken,
+                                            "Missing access token.",
+                                        ))
+                                    }
+                                };
+
+                                match services().users.find_from_token(token).unwrap() {
+                                    None => {
+                                        return Err(Error::BadRequest(
+                                            ErrorKind::UnknownToken { soft_logout: false },
+                                            "Unknown access token.",
+                                        ))
+                                    }
+                                    Some((user_id, device_id)) => (
+                                        Some(user_id),
+                                        Some(OwnedDeviceId::from(device_id)),
+                                        None,
+                                        false,
+                                    ),
+                                }
+                            } else {
+                                (None, None, None, false)
+                            }
+                        }
+                        _ => (None, None, None, false),
+                    },
                 }
             };