From dd3dab39ae1c2f81aebef1b2041e12046abb61f1 Mon Sep 17 00:00:00 2001
From: timokoesters <timo@koesters.xyz>
Date: Mon, 27 Jul 2020 17:38:00 +0200
Subject: [PATCH] feat: whoami route

---
 src/client_server.rs | 184 +++++++++++++++++++------------------------
 src/main.rs          |   1 +
 src/ruma_wrapper.rs  |   2 +-
 3 files changed, 83 insertions(+), 104 deletions(-)

diff --git a/src/client_server.rs b/src/client_server.rs
index baeb8396..2a67a572 100644
--- a/src/client_server.rs
+++ b/src/client_server.rs
@@ -11,14 +11,14 @@ use log::warn;
 #[cfg(not(feature = "conduit_bin"))]
 use super::State;
 #[cfg(feature = "conduit_bin")]
-use rocket::{delete, get, options, post, put, State, tokio};
+use rocket::{delete, get, options, post, put, tokio, State};
 
 use ruma::{
     api::client::{
         error::ErrorKind,
         r0::{
             account::{
-                change_password, deactivate, get_username_availability, register,
+                change_password, deactivate, get_username_availability, register, whoami,
                 ThirdPartyIdRemovalStatus,
             },
             alias::{create_alias, delete_alias, get_alias},
@@ -304,6 +304,18 @@ pub fn login_route(
     .into())
 }
 
+#[cfg_attr(
+    feature = "conduit_bin",
+    get("/_matrix/client/r0/account/whoami", data = "<body>")
+)]
+pub fn whoami_route(body: Ruma<whoami::Request>) -> ConduitResult<whoami::Response> {
+    let sender_id = body.sender_id.as_ref().expect("user is authenticated");
+    Ok(whoami::Response {
+        user_id: sender_id.clone(),
+    }
+    .into())
+}
+
 #[cfg_attr(
     feature = "conduit_bin",
     post("/_matrix/client/r0/logout", data = "<body>")
@@ -361,9 +373,14 @@ pub fn change_password_route(
     };
 
     if let Some(auth) = &body.auth {
-        let (worked, uiaainfo) =
-            db.uiaa
-                .try_auth(&sender_id, device_id, auth, &uiaainfo, &db.users, &db.globals)?;
+        let (worked, uiaainfo) = db.uiaa.try_auth(
+            &sender_id,
+            device_id,
+            auth,
+            &uiaainfo,
+            &db.users,
+            &db.globals,
+        )?;
         if !worked {
             return Err(Error::Uiaa(uiaainfo));
         }
@@ -520,8 +537,7 @@ pub fn get_pushrules_all_route(
     "/_matrix/client/r0/pushrules/<_>/<_>/<_>",
     //data = "<body>"
 ))]
-pub fn set_pushrule_route(
-    //db: State<'_, Database>,
+pub fn set_pushrule_route(//db: State<'_, Database>,
     //body: Ruma<set_pushrule::Request>,
 ) -> ConduitResult<set_pushrule::Response> {
     // TODO
@@ -533,19 +549,14 @@ pub fn set_pushrule_route(
     feature = "conduit_bin",
     put("/_matrix/client/r0/pushrules/<_>/<_>/<_>/enabled")
 )]
-pub fn set_pushrule_enabled_route(
-) -> ConduitResult<set_pushrule_enabled::Response> {
+pub fn set_pushrule_enabled_route() -> ConduitResult<set_pushrule_enabled::Response> {
     // TODO
     warn!("TODO: set_pushrule_enabled_route");
     Ok(set_pushrule_enabled::Response.into())
 }
 
-#[cfg_attr(
-    feature = "conduit_bin",
-    get("/_matrix/client/r0/user/<_>/filter/<_>")
-)]
-pub fn get_filter_route(
-) -> ConduitResult<get_filter::Response> {
+#[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/user/<_>/filter/<_>"))]
+pub fn get_filter_route() -> ConduitResult<get_filter::Response> {
     // TODO
     Ok(get_filter::Response {
         filter: filter::FilterDefinition {
@@ -559,10 +570,7 @@ pub fn get_filter_route(
     .into())
 }
 
-#[cfg_attr(
-    feature = "conduit_bin",
-    post("/_matrix/client/r0/user/<_>/filter")
-)]
+#[cfg_attr(feature = "conduit_bin", post("/_matrix/client/r0/user/<_>/filter"))]
 pub fn create_filter_route() -> ConduitResult<create_filter::Response> {
     // TODO
     Ok(create_filter::Response {
@@ -573,10 +581,7 @@ pub fn create_filter_route() -> ConduitResult<create_filter::Response> {
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/user/<_>/account_data/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/user/<_>/account_data/<_>", data = "<body>")
 )]
 pub fn set_global_account_data_route(
     db: State<'_, Database>,
@@ -607,10 +612,7 @@ pub fn set_global_account_data_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/r0/user/<_>/account_data/<_>",
-        data = "<body>"
-    )
+    get("/_matrix/client/r0/user/<_>/account_data/<_>", data = "<body>")
 )]
 pub fn get_global_account_data_route(
     db: State<'_, Database>,
@@ -737,7 +739,8 @@ pub fn set_avatar_url_route(
         // TODO also make sure this is valid mxc:// format (not only starting with it)
     }
 
-    db.users.set_avatar_url(&sender_id, body.avatar_url.clone())?;
+    db.users
+        .set_avatar_url(&sender_id, body.avatar_url.clone())?;
 
     // Send a new membership event into all joined rooms
     for room_id in db.rooms.rooms_joined(&sender_id) {
@@ -1027,10 +1030,7 @@ pub fn create_backup_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/unstable/room_keys/version/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/unstable/room_keys/version/<_>", data = "<body>")
 )]
 pub fn update_backup_route(
     db: State<'_, Database>,
@@ -1072,23 +1072,20 @@ pub fn get_latest_backup_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/unstable/room_keys/version/<_>",
-        data = "<body>"
-    )
+    get("/_matrix/client/unstable/room_keys/version/<_>", data = "<body>")
 )]
 pub fn get_backup_route(
     db: State<'_, Database>,
     body: Ruma<get_backup::Request>,
 ) -> ConduitResult<get_backup::Response> {
     let sender_id = body.sender_id.as_ref().expect("user is authenticated");
-    let algorithm =
-        db.key_backups
-            .get_backup(&sender_id, &body.version)?
-            .ok_or(Error::BadRequest(
-                ErrorKind::NotFound,
-                "Key backup does not exist.",
-            ))?;
+    let algorithm = db
+        .key_backups
+        .get_backup(&sender_id, &body.version)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "Key backup does not exist.",
+        ))?;
 
     Ok(get_backup::Response {
         algorithm,
@@ -1211,10 +1208,7 @@ pub fn set_read_marker_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/rooms/<_>/typing/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/rooms/<_>/typing/<_>", data = "<body>")
 )]
 pub fn create_typing_event_route(
     db: State<'_, Database>,
@@ -1529,10 +1523,7 @@ pub fn joined_rooms_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/rooms/<_>/redact/<_>/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/rooms/<_>/redact/<_>/<_>", data = "<body>")
 )]
 pub fn redact_event_route(
     db: State<'_, Database>,
@@ -1695,7 +1686,11 @@ pub fn leave_room_route(
 
     let mut event = serde_json::from_value::<Raw<member::MemberEventContent>>(
         db.rooms
-            .room_state_get(&body.room_id, &EventType::RoomMember, &sender_id.to_string())?
+            .room_state_get(
+                &body.room_id,
+                &EventType::RoomMember,
+                &sender_id.to_string(),
+            )?
             .ok_or(Error::BadRequest(
                 ErrorKind::BadState,
                 "Cannot leave a room you are not a member of.",
@@ -1735,7 +1730,11 @@ pub fn kick_user_route(
 
     let mut event = serde_json::from_value::<Raw<ruma::events::room::member::MemberEventContent>>(
         db.rooms
-            .room_state_get(&body.room_id, &EventType::RoomMember, &body.user_id.to_string())?
+            .room_state_get(
+                &body.room_id,
+                &EventType::RoomMember,
+                &body.user_id.to_string(),
+            )?
             .ok_or(Error::BadRequest(
                 ErrorKind::BadState,
                 "Cannot kick member that's not in the room.",
@@ -1774,7 +1773,11 @@ pub fn joined_members_route(
 ) -> ConduitResult<joined_members::Response> {
     let sender_id = body.sender_id.as_ref().expect("user is authenticated");
 
-    if !db.rooms.is_joined(&sender_id, &body.room_id).unwrap_or(false) {
+    if !db
+        .rooms
+        .is_joined(&sender_id, &body.room_id)
+        .unwrap_or(false)
+    {
         return Err(Error::BadRequest(
             ErrorKind::Forbidden,
             "You aren't a member of the room.",
@@ -1812,7 +1815,11 @@ pub fn ban_user_route(
 
     let event = db
         .rooms
-        .room_state_get(&body.room_id, &EventType::RoomMember, &body.user_id.to_string())?
+        .room_state_get(
+            &body.room_id,
+            &EventType::RoomMember,
+            &body.user_id.to_string(),
+        )?
         .map_or(
             Ok::<_, Error>(member::MemberEventContent {
                 membership: member::MembershipState::Ban,
@@ -1859,7 +1866,11 @@ pub fn unban_user_route(
 
     let mut event = serde_json::from_value::<Raw<ruma::events::room::member::MemberEventContent>>(
         db.rooms
-            .room_state_get(&body.room_id, &EventType::RoomMember, &body.user_id.to_string())?
+            .room_state_get(
+                &body.room_id,
+                &EventType::RoomMember,
+                &body.user_id.to_string(),
+            )?
             .ok_or(Error::BadRequest(
                 ErrorKind::BadState,
                 "Cannot unban a user who is not banned.",
@@ -2223,10 +2234,7 @@ pub fn get_protocols_route() -> ConduitResult<get_protocols::Response> {
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/r0/rooms/<_>/event/<_>",
-        data = "<body>"
-    )
+    get("/_matrix/client/r0/rooms/<_>/event/<_>", data = "<body>")
 )]
 pub fn get_room_event_route(
     db: State<'_, Database>,
@@ -2253,10 +2261,7 @@ pub fn get_room_event_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/rooms/<_>/send/<_>/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/rooms/<_>/send/<_>/<_>", data = "<body>")
 )]
 pub fn create_message_event_route(
     db: State<'_, Database>,
@@ -2288,10 +2293,7 @@ pub fn create_message_event_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/rooms/<_>/state/<_>/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/rooms/<_>/state/<_>/<_>", data = "<body>")
 )]
 pub fn create_state_event_for_key_route(
     db: State<'_, Database>,
@@ -2350,10 +2352,7 @@ pub fn create_state_event_for_key_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/rooms/<_>/state/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/rooms/<_>/state/<_>", data = "<body>")
 )]
 pub fn create_state_event_for_empty_key_route(
     db: State<'_, Database>,
@@ -2423,10 +2422,7 @@ pub fn get_state_events_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/r0/rooms/<_>/state/<_>/<_>",
-        data = "<body>"
-    )
+    get("/_matrix/client/r0/rooms/<_>/state/<_>/<_>", data = "<body>")
 )]
 pub fn get_state_events_for_key_route(
     db: State<'_, Database>,
@@ -2458,10 +2454,7 @@ pub fn get_state_events_for_key_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/r0/rooms/<_>/state/<_>",
-        data = "<body>"
-    )
+    get("/_matrix/client/r0/rooms/<_>/state/<_>", data = "<body>")
 )]
 pub fn get_state_events_for_empty_key_route(
     db: State<'_, Database>,
@@ -2877,7 +2870,8 @@ pub async fn sync_events_route(
     };
 
     // TODO: Retry the endpoint instead of returning (waiting for #118)
-    if !body.full_state && response.rooms.is_empty()
+    if !body.full_state
+        && response.rooms.is_empty()
         && response.presence.is_empty()
         && response.account_data.is_empty()
         && response.device_lists.is_empty()
@@ -2902,10 +2896,7 @@ pub async fn sync_events_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/r0/rooms/<_>/context/<_>",
-        data = "<body>"
-    )
+    get("/_matrix/client/r0/rooms/<_>/context/<_>", data = "<body>")
 )]
 pub fn get_context_route(
     db: State<'_, Database>,
@@ -3089,10 +3080,7 @@ pub fn publicised_groups_route() -> ConduitResult<create_message_event::Response
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/sendToDevice/<_>/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/sendToDevice/<_>/<_>", data = "<body>")
 )]
 pub fn send_event_to_device_route(
     db: State<'_, Database>,
@@ -3507,10 +3495,7 @@ pub fn set_pushers_route() -> ConduitResult<get_pushers::Response> {
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put(
-        "/_matrix/client/r0/user/<_>/rooms/<_>/tags/<_>",
-        data = "<body>"
-    )
+    put("/_matrix/client/r0/user/<_>/rooms/<_>/tags/<_>", data = "<body>")
 )]
 pub fn update_tag_route(
     db: State<'_, Database>,
@@ -3544,10 +3529,7 @@ pub fn update_tag_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    delete(
-        "/_matrix/client/r0/user/<_>/rooms/<_>/tags/<_>",
-        data = "<body>"
-    )
+    delete("/_matrix/client/r0/user/<_>/rooms/<_>/tags/<_>", data = "<body>")
 )]
 pub fn delete_tag_route(
     db: State<'_, Database>,
@@ -3578,10 +3560,7 @@ pub fn delete_tag_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    get(
-        "/_matrix/client/r0/user/<_>/rooms/<_>/tags",
-        data = "<body>"
-    )
+    get("/_matrix/client/r0/user/<_>/rooms/<_>/tags", data = "<body>")
 )]
 pub fn get_tags_route(
     db: State<'_, Database>,
@@ -3606,7 +3585,6 @@ pub fn get_tags_route(
 
 #[cfg(feature = "conduit_bin")]
 #[options("/<_..>")]
-pub fn options_route(
-) -> ConduitResult<send_event_to_device::Response> {
+pub fn options_route() -> ConduitResult<send_event_to_device::Response> {
     Ok(send_event_to_device::Response.into())
 }
diff --git a/src/main.rs b/src/main.rs
index 2caee4cc..cc30ff64 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -28,6 +28,7 @@ fn setup_rocket() -> rocket::Rocket {
                 client_server::register_route,
                 client_server::get_login_route,
                 client_server::login_route,
+                client_server::whoami_route,
                 client_server::logout_route,
                 client_server::logout_all_route,
                 client_server::change_password_route,
diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs
index 66f4d4c1..05934365 100644
--- a/src/ruma_wrapper.rs
+++ b/src/ruma_wrapper.rs
@@ -11,9 +11,9 @@ use {
             Data, FromDataFuture, FromTransformedData, Transform, TransformFuture, Transformed,
         },
         http::Status,
+        outcome::Outcome::*,
         response::{self, Responder},
         tokio::io::AsyncReadExt,
-        outcome::Outcome::*,
         Request, State,
     },
     ruma::api::Endpoint,