From 29f5b5809861478a28a59ab5caf76e9d1262b575 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 16 Mar 2024 13:21:56 -0400 Subject: [PATCH 01/20] remove rocksdb optimize_level_style_compaction Signed-off-by: strawberry --- src/database/abstraction/rocksdb.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 1511e980..584c6e70 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -101,7 +101,6 @@ fn db_options(rocksdb_cache: &rust_rocksdb::Cache, config: &Config) -> rust_rock threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"), ); db_opts.set_compression_type(rocksdb_compression_algo); - db_opts.optimize_level_style_compaction(10 * 1024 * 1024); // https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning db_opts.set_level_compaction_dynamic_level_bytes(true); From 41f27dc94936b0b6d6c977885f6817010f5043aa Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 16 Mar 2024 13:22:15 -0400 Subject: [PATCH 02/20] slight wording updates Signed-off-by: strawberry --- src/api/client_server/state.rs | 2 +- src/service/rooms/timeline/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index d18608a2..f0457bc8 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -75,7 +75,7 @@ pub async fn send_state_event_for_empty_key_route( .into()) } -/// # `GET /_matrix/client/r0/rooms/{roomid}/state` +/// # `GET /_matrix/client/v3/rooms/{roomid}/state` /// /// Get all state events for a room. 
/// diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 7ec8320c..fa40b85f 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1007,12 +1007,12 @@ impl Service { return Ok(()); }, Err(e) => { - warn!("{backfill_server} could not provide backfill: {e}"); + warn!("{backfill_server} failed to provide backfill: {e}"); }, } } - info!("No servers could backfill"); + info!("No servers could backfill, but backfill was needed"); Ok(()) } From 94b4d584a667b53f55bb301846698fe20f1fac53 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 16 Mar 2024 13:23:26 -0400 Subject: [PATCH 03/20] admin command to see a room's full state from our database Signed-off-by: strawberry --- src/service/admin/mod.rs | 50 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index c8d772a4..bf6d2bc9 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -396,6 +396,21 @@ enum DebugCommand { server: Box, }, + /// - Gets all the room state events for the specified room. + /// + /// This is functionally equivalent to `GET + /// /_matrix/client/v3/rooms/{roomid}/state`, except the admin command does + /// *not* check if the sender user is allowed to see state events. This is + /// done because it's implied that server admins here have database access + /// and can see/get room info themselves anyways if they were malicious + /// admins. + /// + /// Of course the check is still done on the actual client API. + GetRoomState { + /// Room ID + room_id: Box, + }, + /// - Forces device lists for all local and remote users to be updated (as /// having new keys available) ForceDeviceListUpdates, @@ -2061,6 +2076,41 @@ impl Service { }, } }, + DebugCommand::GetRoomState { + room_id, + } => { + let room_state = services() + .rooms + .state_accessor + .room_state_full(&room_id) + .await? 
+ .values() + .map(|pdu| pdu.to_state_event()) + .collect::>(); + + if room_state.is_empty() { + return Ok(RoomMessageEventContent::text_plain( + "Unable to find room state in our database (vector is empty)", + )); + } + + let json_text = serde_json::to_string_pretty(&room_state).map_err(|e| { + error!("Failed converting room state vector in our database to pretty JSON: {e}"); + Error::bad_database( + "Failed to convert room state events to pretty JSON, possible invalid room state events \ + in our database", + ) + })?; + + return Ok(RoomMessageEventContent::text_html( + format!("{}\n```json\n{}\n```", "Found full room state", json_text), + format!( + "

{}

\n
{}\n
\n", + "Found full room state", + HtmlEscape(&json_text) + ), + )); + }, DebugCommand::ForceDeviceListUpdates => { // Force E2EE device list updates for all users for user_id in services().users.iter().filter_map(std::result::Result::ok) { From 72182f3714e9a6e1f82361e08e71a8b320976bd6 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 16 Mar 2024 14:11:03 -0400 Subject: [PATCH 04/20] fix: avoid panics when admin room is not available Co-authored-by: strawberry Signed-off-by: strawberry --- src/api/client_server/account.rs | 10 +- src/service/admin/mod.rs | 310 +++++++++++++++--------------- src/service/rooms/timeline/mod.rs | 154 ++++++++------- 3 files changed, 233 insertions(+), 241 deletions(-) diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index fa327f56..26016802 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -275,10 +275,14 @@ pub async fn register_route(body: Ruma) -> Result { + let (mut message_content, reply) = match event { + AdminRoomEvent::SendMessage(content) => (content, None), + AdminRoomEvent::ProcessMessage(room_message, reply_id) => { + (self.process_admin_message(room_message).await, Some(reply_id)) + } + }; - loop { - tokio::select! 
{ - Some(event) = receiver.recv() => { - let (mut message_content, reply) = match event { - AdminRoomEvent::SendMessage(content) => (content, None), - AdminRoomEvent::ProcessMessage(room_message, reply_id) => { - (self.process_admin_message(room_message).await, Some(reply_id)) + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .await + .entry(conduit_room.clone()) + .or_default(), + ); + + let state_lock = mutex_state.lock().await; + + if let Some(reply) = reply { + message_content.relates_to = Some(Reply { in_reply_to: InReplyTo { event_id: reply.into() } }); } - }; - let mutex_state = Arc::clone( - services().globals - .roomid_mutex_state - .write() - .await - .entry(conduit_room.clone()) - .or_default(), - ); + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMessage, + content: to_raw_value(&message_content) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &state_lock) + .await + .unwrap(); - let state_lock = mutex_state.lock().await; - if let Some(reply) = reply { - message_content.relates_to = Some(Reply { in_reply_to: InReplyTo { event_id: reply.into() } }); + drop(state_lock); } - - services().rooms.timeline.build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&message_content) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &state_lock) - .await - .unwrap(); - - - drop(state_lock); } } } @@ -1111,14 +1101,13 @@ impl Service { format!("#admins:{}", services().globals.server_name()) .try_into() .expect("#admins:server_name is a valid alias name"); - let admin_room_id = services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias)? 
- .expect("Admin room must exist"); - if room.to_string().eq(&admin_room_id) || room.to_string().eq(&admin_room_alias) { - return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); + if let Some(admin_room_id) = services().admin.get_admin_room()? { + if room.to_string().eq(&admin_room_id) || room.to_string().eq(&admin_room_alias) { + return Ok(RoomMessageEventContent::text_plain( + "Not allowed to ban the admin room.", + )); + } } let room_id = if room.is_room_id() { @@ -1282,23 +1271,15 @@ impl Service { let mut room_ban_count = 0; let mut room_ids: Vec<&RoomId> = Vec::new(); - let admin_room_alias: Box = - format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let admin_room_id = services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias)? - .expect("Admin room must exist"); - for &room_id in &rooms_s { match <&RoomId>::try_from(room_id) { Ok(owned_room_id) => { // silently ignore deleting admin room - if owned_room_id.eq(&admin_room_id) { - info!("User specified admin room in bulk ban list, ignoring"); - continue; + if let Some(admin_room_id) = services().admin.get_admin_room()? { + if owned_room_id.eq(&admin_room_id) { + info!("User specified admin room in bulk ban list, ignoring"); + continue; + } } room_ids.push(owned_room_id); @@ -2443,105 +2424,113 @@ impl Service { Ok(()) } + /// Gets the room ID of the admin room + /// + /// Errors are propagated from the database, and will have None if there is + /// no admin room + pub(crate) fn get_admin_room(&self) -> Result> { + let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + + services().rooms.alias.resolve_local_alias(&admin_room_alias) + } + /// Invite the user to the conduit admin room. /// /// In conduit, this is equivalent to granting admin privileges. 
pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { - let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias)?.expect("Admin room must exist"); + if let Some(room_id) = services().admin.get_admin_room()? { + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); - let state_lock = mutex_state.lock().await; + // Use the server user to grant the new admin's power level + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); - // Use the server user to grant the new admin's power level - let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is valid"); + // Invite and join the real user + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: 
Some(displayname), + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + &room_id, + &state_lock, + ) + .await?; - // Invite and join the real user - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: Some(displayname), - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - user_id, - &room_id, - &state_lock, - ) - .await?; + // Set power level + let mut users = BTreeMap::new(); + users.insert(conduit_user.clone(), 100.into()); + users.insert(user_id.to_owned(), 100.into()); - // Set power level - let mut users = BTreeMap::new(); - users.insert(conduit_user.clone(), 100.into()); - users.insert(user_id.to_owned(), 100.into()); + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: 
to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // Send welcome message - services().rooms.timeline.build_and_append_pdu( + // Send welcome message + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( @@ -2558,7 +2547,10 @@ impl Service { &state_lock, ).await?; - Ok(()) + Ok(()) + } else { + Ok(()) + } } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index fa40b85f..92f0a09d 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -26,7 +26,7 @@ use ruma::{ state_res, state_res::{Event, RoomVersion}, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, - RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, + RoomId, RoomVersionId, ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; @@ -461,10 +461,6 @@ impl Service { if let Some(body) = content.body { services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?; - let admin_room = services().rooms.alias.resolve_local_alias( - <&RoomAliasId>::try_from(format!("#admins:{}", services().globals.server_name()).as_str()) - .expect("#admins:server_name is a valid room alias"), - )?; let server_user = 
format!("@conduit:{}", services().globals.server_name()); let to_conduit = body.starts_with(&format!("{server_user}: ")) @@ -477,8 +473,10 @@ impl Service { // the administrator can execute commands as conduit let from_conduit = pdu.sender == server_user && services().globals.emergency_password().is_none(); - if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { - services().admin.process_message(body, pdu.event_id.clone()); + if let Some(admin_room) = services().admin.get_admin_room()? { + if to_conduit && !from_conduit && admin_room == pdu.room_id { + services().admin.process_message(body, pdu.event_id.clone()); + } } } }, @@ -720,84 +718,82 @@ impl Service { ) -> Result> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - let admin_room = services().rooms.alias.resolve_local_alias( - <&RoomAliasId>::try_from(format!("#admins:{}", services().globals.server_name()).as_str()) - .expect("#admins:server_name is a valid room alias"), - )?; - if admin_room.filter(|v| v == room_id).is_some() { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Encryption is not allowed in the admins room.", - )); - }, - TimelineEventType::RoomMember => { - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - let target = pdu.state_key().filter(|v| v.starts_with('@')).unwrap_or(sender.as_str()); - let server_name = services().globals.server_name(); - let server_user = format!("@conduit:{server_name}"); - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Conduit user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Conduit user cannot leave from 
admins room.", - )); + if let Some(admin_room) = services().admin.get_admin_room()? { + if admin_room == room_id { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + warn!("Encryption is not allowed in the admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Encryption is not allowed in the admins room.", + )); + }, + TimelineEventType::RoomMember => { + #[derive(Deserialize)] + struct ExtractMembership { + membership: MembershipState, } - let count = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(std::result::Result::ok) - .filter(|m| m.server_name() == server_name) - .filter(|m| m != target) - .count(); - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Last admin cannot leave from admins room.", - )); - } - } + let target = pdu.state_key().filter(|v| v.starts_with('@')).unwrap_or(sender.as_str()); + let server_name = services().globals.server_name(); + let server_user = format!("@conduit:{server_name}"); + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Conduit user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Conduit user cannot be banned in admins room.", - )); + if content.membership == MembershipState::Leave { + if target == server_user { + warn!("Conduit user cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Conduit user cannot leave from admins room.", + )); + } + + let count = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(std::result::Result::ok) + .filter(|m| m.server_name() == server_name) + .filter(|m| m != target) + .count(); + if count < 2 { + warn!("Last admin cannot leave from admins room"); + return 
Err(Error::BadRequest( + ErrorKind::Forbidden, + "Last admin cannot leave from admins room.", + )); + } } - let count = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(std::result::Result::ok) - .filter(|m| m.server_name() == server_name) - .filter(|m| m != target) - .count(); - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Last admin cannot be banned in admins room.", - )); + if content.membership == MembershipState::Ban && pdu.state_key().is_some() { + if target == server_user { + warn!("Conduit user cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Conduit user cannot be banned in admins room.", + )); + } + + let count = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(std::result::Result::ok) + .filter(|m| m.server_name() == server_name) + .filter(|m| m != target) + .count(); + if count < 2 { + warn!("Last admin cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Last admin cannot be banned in admins room.", + )); + } } - } - }, - _ => {}, + }, + _ => {}, + } } } From aec63c29e1d7c14e9a8c99b96a6ff60cf6708645 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 16 Mar 2024 16:05:52 -0400 Subject: [PATCH 05/20] refactor: check if federation is disabled inside the authcheck where possible Signed-off-by: strawberry --- src/api/ruma_wrapper/axum.rs | 4 ++ src/api/server_server.rs | 72 ++---------------------------------- 2 files changed, 8 insertions(+), 68 deletions(-) diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index a208a359..716f01e7 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -153,6 +153,10 @@ where // treat non-appservice registrations as None authentication AuthScheme::AppserviceToken => (None, None, None, false), AuthScheme::ServerSignatures => { + if 
!services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + let TypedHeader(Authorization(x_matrix)) = parts.extract::>>().await.map_err(|e| { warn!("Missing or invalid Authorization header: {}", e); diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 70413f22..5158ca45 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -619,10 +619,6 @@ pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_serve pub async fn get_public_rooms_filtered_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if !services().globals.allow_public_room_directory_over_federation() { return Err(Error::bad_config("Room directory is not public.")); } @@ -650,10 +646,6 @@ pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if !services().globals.allow_public_room_directory_over_federation() { return Err(Error::bad_config("Room directory is not public.")); } @@ -707,10 +699,6 @@ pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, Canonical pub async fn send_transaction_message_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let mut resolved_map = BTreeMap::new(); @@ -946,10 +934,6 @@ pub async fn send_transaction_message_route( /// - Only works if a user of this server is currently invited or joined the /// room pub async fn get_event_route(body: Ruma) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = 
body.sender_servername.as_ref().expect("server is authenticated"); let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { @@ -985,10 +969,6 @@ pub async fn get_event_route(body: Ruma) -> Result) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); debug!("Got backfill request from: {}", sender_servername); @@ -1041,10 +1021,6 @@ pub async fn get_backfill_route(body: Ruma) -> Result pub async fn get_missing_events_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { @@ -1118,10 +1094,6 @@ pub async fn get_missing_events_route( pub async fn get_event_authorization_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { @@ -1157,10 +1129,6 @@ pub async fn get_event_authorization_route( /// /// Retrieves the current state of the room. pub async fn get_room_state_route(body: Ruma) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? 
{ @@ -1211,10 +1179,6 @@ pub async fn get_room_state_route(body: Ruma) -> Re pub async fn get_room_state_ids_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { @@ -1253,10 +1217,6 @@ pub async fn get_room_state_ids_route( pub async fn create_join_event_template_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if !services().rooms.metadata.exists(&body.room_id)? { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -1343,10 +1303,6 @@ pub async fn create_join_event_template_route( async fn create_join_event( sender_servername: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if !services().rooms.metadata.exists(room_id)? { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -1500,10 +1456,6 @@ pub async fn create_join_event_v2_route( /// /// Invites a remote user to a room. pub async fn create_invite_route(body: Ruma) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; @@ -1622,10 +1574,6 @@ pub async fn create_invite_route(body: Ruma) -> Resu /// /// Gets information on all devices of the user. 
pub async fn get_devices_route(body: Ruma) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if body.user_id.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -1673,10 +1621,6 @@ pub async fn get_devices_route(body: Ruma) -> Result, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - let room_id = services() .rooms .alias @@ -1695,10 +1639,6 @@ pub async fn get_room_information_route( pub async fn get_profile_information_route( body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if body.user_id.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -1738,10 +1678,6 @@ pub async fn get_profile_information_route( /// /// Gets devices and identity keys for the given users. 
pub async fn get_keys_route(body: Ruma) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if body.device_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -1768,10 +1704,6 @@ pub async fn get_keys_route(body: Ruma) -> Result) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - if body.one_time_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -1788,6 +1720,10 @@ pub async fn claim_keys_route(body: Ruma) -> Result Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + let server_url = match services().globals.well_known_server() { Some(url) => url.clone(), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), From 8972487691f1d761c2a679ecf9089be696d5e3ce Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 16 Mar 2024 16:07:42 -0400 Subject: [PATCH 06/20] check allow_federation in send_federation_request Signed-off-by: strawberry --- src/service/sending/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 86de6f1c..529e1631 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -647,6 +647,10 @@ impl Service { where T: OutgoingRequest + Debug, { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { info!( "Destination {} is an IP literal, checking against IP range denylist.", From c14b28b4088d3bf35e39e504e9029b25cb54f82c Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 16 Mar 2024 16:09:11 -0400 Subject: [PATCH 07/20] feat(spaces): 
hierarchy over federation Signed-off-by: strawberry --- src/api/client_server/space.rs | 38 +- src/api/server_server.rs | 15 + src/database/key_value/rooms/state_cache.rs | 2 + src/main.rs | 1 + src/service/mod.rs | 8 +- src/service/rooms/spaces/mod.rs | 1385 +++++++++++++++---- src/service/rooms/state/mod.rs | 2 +- src/service/rooms/timeline/mod.rs | 2 +- 8 files changed, 1138 insertions(+), 315 deletions(-) diff --git a/src/api/client_server/space.rs b/src/api/client_server/space.rs index ac6139a5..0cbd6057 100644 --- a/src/api/client_server/space.rs +++ b/src/api/client_server/space.rs @@ -1,6 +1,11 @@ -use ruma::api::client::space::get_hierarchy; +use std::str::FromStr; -use crate::{services, Result, Ruma}; +use ruma::{ + api::client::{error::ErrorKind, space::get_hierarchy}, + UInt, +}; + +use crate::{service::rooms::spaces::PagnationToken, services, Error, Result, Ruma}; /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`` /// @@ -9,11 +14,32 @@ use crate::{services, Result, Ruma}; pub async fn get_hierarchy_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let skip = body.from.as_ref().and_then(|s| s.parse::().ok()).unwrap_or(0); + let limit = body.limit.unwrap_or_else(|| UInt::from(10_u32)).min(UInt::from(100_u32)); - let limit = body.limit.map_or(10, u64::from).min(100) as usize; + let max_depth = body.max_depth.unwrap_or_else(|| UInt::from(3_u32)).min(UInt::from(10_u32)); - let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself + let key = body.from.as_ref().and_then(|s| PagnationToken::from_str(s).ok()); - services().rooms.spaces.get_hierarchy(sender_user, &body.room_id, limit, skip, max_depth, body.suggested_only).await + // Should prevent unexpeded behaviour in (bad) clients + if let Some(ref token) = key { + if token.suggested_only != body.suggested_only || token.max_depth != max_depth { + return Err(Error::BadRequest( + 
ErrorKind::InvalidParam, + "suggested_only and max_depth cannot change on paginated requests", + )); + } + } + + services() + .rooms + .spaces + .get_client_hierarchy( + sender_user, + &body.room_id, + u64::from(limit) as usize, + key.map_or(0, |token| u64::from(token.skip) as usize), + u64::from(max_depth) as usize, + body.suggested_only, + ) + .await } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 5158ca45..4f28e271 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -28,6 +28,7 @@ use ruma::{ keys::{claim_keys, get_keys}, membership::{create_invite, create_join_event, prepare_join_event}, query::{get_profile_information, get_room_information}, + space::get_hierarchy, transactions::{ edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, send_transaction_message, @@ -1734,6 +1735,20 @@ pub async fn well_known_server_route() -> Result { }))) } +/// # `GET /_matrix/federation/v1/hierarchy/{roomId}` +/// +/// Gets the space tree in a depth-first manner to locate child rooms of a given +/// space. +pub async fn get_hierarchy_route(body: Ruma) -> Result { + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + + if services().rooms.metadata.exists(&body.room_id)? 
{ + services().rooms.spaces.get_federation_hierarchy(&body.room_id, sender_servername, body.suggested_only).await + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) + } +} + #[cfg(test)] mod tests { use super::{add_port_to_hostname, get_ip_with_port, FedDest}; diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 6c19dbe8..1645c874 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -268,6 +268,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { })) } + /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self))] fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.roomid_joinedcount @@ -276,6 +277,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .transpose() } + /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self))] fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.roomid_invitedcount diff --git a/src/main.rs b/src/main.rs index ed163519..03c9bdf6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -739,6 +739,7 @@ fn routes() -> Router { .ruma_route(server_server::get_profile_information_route) .ruma_route(server_server::get_keys_route) .ruma_route(server_server::claim_keys_route) + .ruma_route(server_server::get_hierarchy_route) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) .route("/client/server.json", get(client_server::syncv3_client_server_json)) diff --git a/src/service/mod.rs b/src/service/mod.rs index 25413647..8d37061c 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -134,7 +134,7 @@ impl Services<'_> { db, }, spaces: rooms::spaces::Service { - roomid_spacechunk_cache: Mutex::new(LruCache::new( + roomid_spacehierarchy_cache: Mutex::new(LruCache::new( (100.0 * 
config.conduit_cache_capacity_modifier) as usize, )), }, @@ -175,7 +175,7 @@ impl Services<'_> { let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len(); let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len(); let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().await.len(); - let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len(); + let roomid_spacehierarchy_cache = self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.len(); format!( "\ @@ -184,7 +184,7 @@ server_visibility_cache: {server_visibility_cache} user_visibility_cache: {user_visibility_cache} stateinfo_cache: {stateinfo_cache} lasttimelinecount_cache: {lasttimelinecount_cache} -roomid_spacechunk_cache: {roomid_spacechunk_cache}" +roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}" ) } @@ -205,7 +205,7 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}" self.rooms.timeline.lasttimelinecount_cache.lock().await.clear(); } if amount > 5 { - self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear(); + self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.clear(); } } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 4e7ffcca..20a50d71 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,13 +1,13 @@ -use std::sync::Arc; +use std::str::FromStr; use lru_cache::LruCache; use ruma::{ api::{ - client::{ - error::ErrorKind, - space::{get_hierarchy, SpaceHierarchyRoomsChunk}, + client::{self, error::ErrorKind, space::SpaceHierarchyRoomsChunk}, + federation::{ + self, + space::{SpaceHierarchyChildSummary, SpaceHierarchyParentSummary}, }, - federation, }, events::{ room::{ @@ -16,255 +16,503 @@ use ruma::{ create::RoomCreateEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - 
join_rules::{self, AllowRule, JoinRule, RoomJoinRulesEventContent},
+            join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent, RoomMembership},
             topic::RoomTopicEventContent,
         },
-        space::child::SpaceChildEventContent,
+        space::child::{HierarchySpaceChildEvent, SpaceChildEventContent},
         StateEventType,
     },
+    serde::Raw,
     space::SpaceRoomJoinRule,
-    OwnedRoomId, RoomId, UserId,
+    OwnedRoomId, RoomId, ServerName, UInt, UserId,
 };
 use tokio::sync::Mutex;
-use tracing::{debug, error, warn};
+use tracing::{debug, error, info, warn};
 
-use crate::{services, Error, PduEvent, Result};
+use crate::{services, Error, Result};
 
-pub enum CachedJoinRule {
-    //Simplified(SpaceRoomJoinRule),
-    Full(JoinRule),
+pub struct CachedSpaceHierarchySummary {
+    summary: SpaceHierarchyParentSummary,
 }
 
-pub struct CachedSpaceChunk {
-    chunk: SpaceHierarchyRoomsChunk,
-    children: Vec<OwnedRoomId>,
-    join_rule: CachedJoinRule,
+pub enum SummaryAccessibility {
+    Accessible(Box<SpaceHierarchyParentSummary>),
+    Inaccessible,
+}
+
+pub struct Arena {
+    nodes: Vec<Node>,
+    max_depth: usize,
+    first_untraversed: Option<NodeId>,
+}
+
+pub struct Node {
+    parent: Option<NodeId>,
+    // Next meaning:
+    // -->
+    // o o o o
+    next_sibling: Option<NodeId>,
+    // First meaning:
+    // |
+    // v
+    // o o o o
+    first_child: Option<NodeId>,
+    pub room_id: OwnedRoomId,
+    traversed: bool,
+}
+
+#[derive(Clone, Copy, PartialEq, Debug, PartialOrd)]
+pub struct NodeId {
+    index: usize,
+}
+
+impl Arena {
+    /// Checks if a given node is traversed
+    fn traversed(&self, id: NodeId) -> Option<bool> { Some(self.get(id)?.traversed) }
+
+    /// Gets the next sibling of a given node
+    fn next_sibling(&self, id: NodeId) -> Option<NodeId> { self.get(id)?.next_sibling }
+
+    /// Gets the parent of a given node
+    fn parent(&self, id: NodeId) -> Option<NodeId> { self.get(id)?.parent }
+
+    /// Gets the first child of a given node
+    fn first_child(&self, id: NodeId) -> Option<NodeId> { self.get(id)?.first_child }
+
+    /// Sets traversed to true for a given node
+    fn traverse(&mut self, id: NodeId) { self.nodes[id.index].traversed = true; }
+
+    
/// Gets the node of a given id + fn get(&self, id: NodeId) -> Option<&Node> { self.nodes.get(id.index) } + + /// Gets a mutable reference of a node of a given id + fn get_mut(&mut self, id: NodeId) -> Option<&mut Node> { self.nodes.get_mut(id.index) } + + /// Returns the first untraversed node, marking it as traversed in the + /// process + pub fn first_untraversed(&mut self) -> Option { + if self.nodes.is_empty() { + None + } else if let Some(untraversed) = self.first_untraversed { + let mut current = untraversed; + + self.traverse(untraversed); + + // Possible paths: + // 1) Next child exists, and hence is not traversed + // 2) Next child does not exist, so go to the parent, then repeat + // 3) If both the parent and child do not exist, then we have just traversed the + // whole space tree. + // + // You should only ever encounter a traversed node when going up through parents + while let Some(true) = self.traversed(current) { + if let Some(next) = self.next_sibling(current) { + current = next; + } else if let Some(parent) = self.parent(current) { + current = parent + } else { + break; + } + } + + // Traverses down the children until it reaches one without children + while let Some(child) = self.first_child(current) { + current = child; + } + + if self.traversed(current)? 
{ + self.first_untraversed = None; + } else { + self.first_untraversed = Some(current); + } + + Some(untraversed) + } else { + None + } + } + + /// Adds all the given nodes as children of the parent node + pub fn push(&mut self, parent: NodeId, mut children: Vec) { + if children.is_empty() { + self.traverse(parent); + } else if self.nodes.get(parent.index).is_some() { + let mut parents = vec![( + parent, + self.get(parent) + .expect("It is some, as above") + .room_id + // Cloning cause otherwise when iterating over the parents, below, there would + // be a mutable and immutable reference to self.nodes + .clone(), + )]; + + while let Some(parent) = self.parent(parents.last().expect("Has at least one value, as above").0) { + parents.push((parent, self.get(parent).expect("It is some, as above").room_id.clone())) + } + + // If at max_depth, don't add new rooms + if self.max_depth < parents.len() { + return; + } + + children.reverse(); + + let mut next_id = None; + + for child in children { + // Prevent adding a child which is a parent (recursion) + if !parents.iter().any(|parent| parent.1 == child) { + self.nodes.push(Node { + parent: Some(parent), + next_sibling: next_id, + first_child: None, + room_id: child, + traversed: false, + }); + + next_id = Some(NodeId { + index: self.nodes.len() - 1, + }); + } + } + + if self.first_untraversed.is_none() + || parent >= self.first_untraversed.expect("Should have already continued if none") + { + self.first_untraversed = next_id; + } + + self.traverse(parent); + + // This is done as if we use an if-let above, we cannot reference self.nodes + // above as then we would have multiple mutable references + let node = self.get_mut(parent).expect("Must be some, as inside this block"); + + node.first_child = next_id; + } + } + + pub fn new(root: OwnedRoomId, max_depth: usize) -> Self { + let zero_depth = max_depth == 0; + + Arena { + nodes: vec![Node { + parent: None, + next_sibling: None, + first_child: None, + room_id: root, + 
traversed: zero_depth, + }], + max_depth, + first_untraversed: if zero_depth { + None + } else { + Some(NodeId { + index: 0, + }) + }, + } + } +} + +// Note: perhaps use some better form of token rather than just room count +#[derive(Debug, PartialEq)] +pub struct PagnationToken { + pub skip: UInt, + pub limit: UInt, + pub max_depth: UInt, + pub suggested_only: bool, +} + +impl FromStr for PagnationToken { + type Err = Error; + + fn from_str(value: &str) -> Result { + let mut values = value.split('_'); + + let mut pag_tok = || { + Some(PagnationToken { + skip: UInt::from_str(values.next()?).ok()?, + limit: UInt::from_str(values.next()?).ok()?, + max_depth: UInt::from_str(values.next()?).ok()?, + suggested_only: { + let slice = values.next()?; + + if values.next().is_none() { + if slice == "true" { + true + } else if slice == "false" { + false + } else { + None? + } + } else { + None? + } + }, + }) + }; + + if let Some(token) = pag_tok() { + Ok(token) + } else { + Err(Error::BadRequest(ErrorKind::InvalidParam, "invalid token")) + } + } +} + +impl std::fmt::Display for PagnationToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}_{}_{}_{}", self.skip, self.limit, self.max_depth, self.suggested_only) + } +} + +/// Identifier used to check if rooms are accessible +/// +/// None is used if you want to return the room, no matter if accessible or not +pub enum Identifier<'a> { + UserId(&'a UserId), + ServerName(&'a ServerName), + None, } pub struct Service { - pub roomid_spacechunk_cache: Mutex>>, + pub roomid_spacehierarchy_cache: Mutex>>, +} + +// Here because cannot implement `From` across ruma-federation-api and +// ruma-client-api types +impl From for SpaceHierarchyRoomsChunk { + fn from(value: CachedSpaceHierarchySummary) -> Self { + let SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + 
children_state, + .. + } = value.summary; + + SpaceHierarchyRoomsChunk { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + } + } } impl Service { - pub async fn get_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, skip: usize, max_depth: usize, - suggested_only: bool, - ) -> Result { - let mut left_to_skip = skip; + ///Gets the response for the space hierarchy over federation request + /// + ///Panics if the room does not exist, so a check if the room exists should + /// be done + pub async fn get_federation_hierarchy( + &self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool, + ) -> Result { + match self.get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::None).await? { + Some(SummaryAccessibility::Accessible(room)) => { + let mut children = Vec::new(); + let mut inaccessible_children = Vec::new(); - let mut rooms_in_path = Vec::new(); - let mut stack = vec![vec![room_id.to_owned()]]; - let mut results = Vec::new(); + for child in get_parent_children(*room.clone(), suggested_only) { + match self + .get_summary_and_children(&child, suggested_only, Identifier::ServerName(server_name)) + .await? 
+ { + Some(SummaryAccessibility::Accessible(summary)) => { + children.push((*summary).into()); + }, + Some(SummaryAccessibility::Inaccessible) => { + inaccessible_children.push(child); + }, + None => (), + } + } - while let Some(current_room) = { - while stack.last().map_or(false, Vec::is_empty) { - stack.pop(); - } - if !stack.is_empty() { - stack.last_mut().and_then(Vec::pop) + Ok(federation::space::get_hierarchy::v1::Response { + room: *room, + children, + inaccessible_children, + }) + }, + Some(SummaryAccessibility::Inaccessible) => { + Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible")) + }, + None => Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")), + } + } + + async fn get_summary_and_children( + &self, current_room: &OwnedRoomId, suggested_only: bool, identifier: Identifier<'_>, + ) -> Result> { + if let Some(cached) = self.roomid_spacehierarchy_cache.lock().await.get_mut(¤t_room.to_owned()).as_ref() { + return Ok(if let Some(cached) = cached { + if is_accessable_child( + current_room, + &cached.summary.join_rule, + &identifier, + &cached.summary.allowed_room_ids, + )? { + Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) + } else { + Some(SummaryAccessibility::Inaccessible) + } } else { None - } - } { - rooms_in_path.push(current_room.clone()); - if results.len() >= limit { - break; - } + }); + } - if let Some(cached) = self.roomid_spacechunk_cache.lock().await.get_mut(¤t_room.clone()).as_ref() { - if let Some(cached) = cached { - let allowed = match &cached.join_rule { - //CachedJoinRule::Simplified(s) => { - //self.handle_simplified_join_rule(s, sender_user, ¤t_room)? 
- //} - CachedJoinRule::Full(f) => self.handle_join_rule(f, sender_user, ¤t_room)?, - }; - if allowed { - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(cached.chunk.clone()); - } - if rooms_in_path.len() < max_depth { - stack.push(cached.children.clone()); - } - } - } - continue; - } - - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(¤t_room)? { - let state = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; - - let mut children_ids = Vec::new(); - let mut children_pdus = Vec::new(); - for (key, id) in state { - let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; - if event_type != StateEventType::SpaceChild { - continue; - } - - let pdu = services() - .rooms - .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; - - if serde_json::from_str::(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) - { - continue; - } - - if let Ok(room_id) = OwnedRoomId::try_from(state_key) { - children_ids.push(room_id); - children_pdus.push(pdu); - } - } - - // TODO: Sort children - children_ids.reverse(); - - let chunk = self.get_room_chunk(sender_user, ¤t_room, children_pdus).await; - if let Ok(chunk) = chunk { - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(chunk.clone()); - } - let join_rule = services() - .rooms - .state_accessor - .room_state_get(¤t_room, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) - }) - .transpose()? - .unwrap_or(JoinRule::Invite); - - self.roomid_spacechunk_cache.lock().await.insert( + Ok( + if let Some(children_pdus) = get_stripped_space_child_events(current_room).await? 
{ + let summary = self.get_room_summary(current_room, children_pdus, identifier); + if let Ok(summary) = summary { + self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), - Some(CachedSpaceChunk { - chunk, - children: children_ids.clone(), - join_rule: CachedJoinRule::Full(join_rule), + Some(CachedSpaceHierarchySummary { + summary: summary.clone(), }), ); - } - if rooms_in_path.len() < max_depth { - stack.push(children_ids); + Some(SummaryAccessibility::Accessible(Box::new(summary))) + } else { + None } - } else if let Some(server) = current_room.server_name() { + // Federation requests should not request information from other + // servers + } else if let Identifier::UserId(_) = identifier { + let server = current_room.server_name().expect("Room IDs should always have a server name"); if server == services().globals.server_name() { - continue; + return Ok(None); } - if !results.is_empty() { - // Early return so the client can see some data already - break; - } - - debug!("Asking {server} for /hierarchy"); + info!("Asking {server} for /hierarchy"); if let Ok(response) = services() .sending .send_federation_request( server, federation::space::get_hierarchy::v1::Request { - room_id: current_room.clone(), + room_id: current_room.to_owned(), suggested_only, }, ) .await { - debug!("Got response from {server} for /hierarchy\n{response:?}"); - let chunk = SpaceHierarchyRoomsChunk { - canonical_alias: response.room.canonical_alias, - name: response.room.name, - num_joined_members: response.room.num_joined_members, - room_id: response.room.room_id, - topic: response.room.topic, - world_readable: response.room.world_readable, - guest_can_join: response.room.guest_can_join, - avatar_url: response.room.avatar_url, - join_rule: response.room.join_rule.clone(), - room_type: response.room.room_type, - children_state: response.room.children_state, - }; - let children = response.children.iter().map(|c| c.room_id.clone()).collect::>(); + info!("Got response 
from {server} for /hierarchy\n{response:?}"); + let summary = response.room.clone(); - let join_rule = match response.room.join_rule { - SpaceRoomJoinRule::Invite => JoinRule::Invite, - SpaceRoomJoinRule::Knock => JoinRule::Knock, - SpaceRoomJoinRule::Private => JoinRule::Private, - SpaceRoomJoinRule::Restricted => JoinRule::Restricted(join_rules::Restricted { - allow: response.room.allowed_room_ids.into_iter().map(AllowRule::room_membership).collect(), + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: summary.clone(), }), - SpaceRoomJoinRule::KnockRestricted => JoinRule::KnockRestricted(join_rules::Restricted { - allow: response.room.allowed_room_ids.into_iter().map(AllowRule::room_membership).collect(), - }), - SpaceRoomJoinRule::Public => JoinRule::Public, - _ => return Err(Error::BadServerResponse("Unknown join rule")), - }; - if self.handle_join_rule(&join_rule, sender_user, ¤t_room)? { - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(chunk.clone()); - } - if rooms_in_path.len() < max_depth { - stack.push(children.clone()); + ); + + for child in response.children { + let mut guard = self.roomid_spacehierarchy_cache.lock().await; + if !guard.contains_key(current_room) { + guard.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: { + let SpaceHierarchyChildSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + allowed_room_ids, + } = child; + + SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id: room_id.clone(), + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), + allowed_room_ids, + } + }, + }), + ); } } - - self.roomid_spacechunk_cache.lock().await.insert( - current_room.clone(), - 
Some(CachedSpaceChunk { - chunk, - children, - join_rule: CachedJoinRule::Full(join_rule), - }), - ); - - /* TODO: - for child in response.children { - roomid_spacechunk_cache.insert( - current_room.clone(), - CachedSpaceChunk { - chunk: child.chunk, - children, - join_rule, - }, - ); - } - */ + if is_accessable_child( + current_room, + &response.room.join_rule, + &identifier, + &response.room.allowed_room_ids, + )? { + Some(SummaryAccessibility::Accessible(Box::new(summary.clone()))) + } else { + Some(SummaryAccessibility::Inaccessible) + } } else { - self.roomid_spacechunk_cache.lock().await.insert(current_room.clone(), None); - } - } - } + self.roomid_spacehierarchy_cache.lock().await.insert(current_room.clone(), None); - Ok(get_hierarchy::v1::Response { - next_batch: if results.is_empty() { - None + None + } } else { - Some((skip + results.len()).to_string()) + None }, - rooms: results, - }) + ) } - async fn get_room_chunk( - &self, sender_user: &UserId, room_id: &RoomId, children: Vec>, - ) -> Result { - Ok(SpaceHierarchyRoomsChunk { + fn get_room_summary( + &self, current_room: &OwnedRoomId, children_state: Vec>, + identifier: Identifier<'_>, + ) -> Result { + let room_id: &RoomId = current_room; + + let join_rule = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()).map(|c: RoomJoinRulesEventContent| c.join_rule).map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? + .unwrap_or(JoinRule::Invite); + + let allowed_room_ids = allowed_room_ids(join_rule.clone()); + + if !is_accessable_child(current_room, &join_rule.clone().into(), &identifier, &allowed_room_ids)? 
{ + debug!("User is not allowed to see room {room_id}"); + // This error will be caught later + return Err(Error::BadRequest(ErrorKind::Forbidden, "User is not allowed to see the room")); + } + + let join_rule = join_rule.into(); + + Ok(SpaceHierarchyParentSummary { canonical_alias: services() .rooms .state_accessor @@ -286,73 +534,30 @@ impl Service { .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), - topic: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomTopic, "")? - .map_or(Ok(None), |s| { + topic: services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomTopic, "")?.map_or( + Ok(None), + |s| { serde_json::from_str(s.content.get()).map(|c: RoomTopicEventContent| Some(c.topic)).map_err(|_| { error!("Invalid room topic event in database for room {}", room_id); Error::bad_database("Invalid room topic event in database.") }) - }) - .unwrap_or(None), - world_readable: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|_| Error::bad_database("Invalid room history visibility event in database.")) - })?, - guest_can_join: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) - .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) - })?, + }, + )?, + world_readable: world_readable(room_id)?, + guest_can_join: guest_can_join(room_id)?, avatar_url: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomAvatar, "")? 
- .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomAvatarEventContent| c.url) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - }) - .transpose()? - // url is now an Option so we must flatten - .flatten(), - join_rule: { - let join_rule = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()).map(|c: RoomJoinRulesEventContent| c.join_rule).map_err( - |e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }, - ) - }) - .transpose()? - .unwrap_or(JoinRule::Invite); - - if !self.handle_join_rule(&join_rule, sender_user, room_id)? { - debug!("User is not allowed to see room {room_id}"); - // This error will be caught later - return Err(Error::BadRequest(ErrorKind::Forbidden, "User is not allowed to see the room")); - } - - self.translate_joinrule(&join_rule)? - }, + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomAvatar, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomAvatarEventContent| c.url) + .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + }) + .transpose()? + // url is now an Option so we must flatten + .flatten(), + join_rule, room_type: services() .rooms .state_accessor @@ -365,57 +570,631 @@ impl Service { }) .transpose()? 
.and_then(|e| e.room_type), - children_state: children.into_iter().map(|pdu| pdu.to_stripped_spacechild_state_event()).collect(), + children_state, + allowed_room_ids, }) } - fn translate_joinrule(&self, join_rule: &JoinRule) -> Result { - match join_rule { - JoinRule::Invite => Ok(SpaceRoomJoinRule::Invite), - JoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), - JoinRule::Private => Ok(SpaceRoomJoinRule::Private), - JoinRule::Restricted(_) => Ok(SpaceRoomJoinRule::Restricted), - JoinRule::KnockRestricted(_) => Ok(SpaceRoomJoinRule::KnockRestricted), - JoinRule::Public => Ok(SpaceRoomJoinRule::Public), - _ => Err(Error::BadServerResponse("Unknown join rule")), + pub async fn get_client_hierarchy( + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, skip: usize, max_depth: usize, + suggested_only: bool, + ) -> Result { + match self + .get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::UserId(sender_user)) + .await? + { + Some(SummaryAccessibility::Accessible(summary)) => { + let mut left_to_skip = skip; + let mut arena = Arena::new(summary.room_id.clone(), max_depth); + + let mut results = Vec::new(); + let root = arena.first_untraversed().expect("The node just added is not traversed"); + + arena.push(root, get_parent_children(*summary.clone(), suggested_only)); + results.push(summary_to_chunk(*summary.clone())); + + while let Some(current_room) = arena.first_untraversed() { + if limit > results.len() { + if let Some(SummaryAccessibility::Accessible(summary)) = self + .get_summary_and_children( + &arena.get(current_room).expect("We added this node, it must exist").room_id, + suggested_only, + Identifier::UserId(sender_user), + ) + .await? 
+ { + let children = get_parent_children(*summary.clone(), suggested_only); + arena.push(current_room, children); + + if left_to_skip > 0 { + left_to_skip -= 1 + } else { + results.push(summary_to_chunk(*summary.clone())) + } + } + } else { + break; + } + } + + Ok(client::space::get_hierarchy::v1::Response { + next_batch: if results.len() < limit { + None + } else { + let skip = UInt::new((skip + limit) as u64); + + skip.map(|skip| { + PagnationToken { + skip, + limit: UInt::new(max_depth as u64) + .expect("When sent in request it must have been valid UInt"), + max_depth: UInt::new(max_depth as u64) + .expect("When sent in request it must have been valid UInt"), + suggested_only, + } + .to_string() + }) + }, + rooms: results, + }) + }, + Some(SummaryAccessibility::Inaccessible) => { + Err(Error::BadRequest(ErrorKind::Forbidden, "The requested room is inaccessible")) + }, + None => Err(Error::BadRequest(ErrorKind::Forbidden, "The requested room was not found")), } } +} - fn handle_simplified_join_rule( - &self, join_rule: &SpaceRoomJoinRule, sender_user: &UserId, room_id: &RoomId, - ) -> Result { - let allowed = match join_rule { - SpaceRoomJoinRule::Public => true, - SpaceRoomJoinRule::Knock => true, - SpaceRoomJoinRule::Invite => services().rooms.state_cache.is_joined(sender_user, room_id)?, - _ => false, - }; +/// Simply returns the stripped m.space.child events of a room +async fn get_stripped_space_child_events( + room_id: &RoomId, +) -> Result>>, Error> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + let state = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; + let mut children_pdus = Vec::new(); + for (key, id) in state { + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; + if event_type != StateEventType::SpaceChild { + continue; + } - Ok(allowed) - } + let pdu = services() + .rooms + .timeline + .get_pdu(&id)? 
+ .ok_or_else(|| Error::bad_database("Event in space state not found"))?; - fn handle_join_rule(&self, join_rule: &JoinRule, sender_user: &UserId, room_id: &RoomId) -> Result { - if self.handle_simplified_join_rule(&self.translate_joinrule(join_rule)?, sender_user, room_id)? { - return Ok(true); + if serde_json::from_str::(pdu.content.get()) + .ok() + .map(|c| c.via) + .map_or(true, |v| v.is_empty()) + { + continue; + } + + if OwnedRoomId::try_from(state_key).is_ok() { + children_pdus.push(pdu.to_stripped_spacechild_state_event()); + } } + Ok(Some(children_pdus)) + } else { + Ok(None) + } +} - match join_rule { - JoinRule::Restricted(r) => { - for rule in &r.allow { - if let AllowRule::RoomMembership(rm) = rule { - if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, &rm.room_id) { +/// With the given identifier, checks if a room is accessable +fn is_accessable_child( + current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, + allowed_room_ids: &Vec, +) -> Result { + is_accessable_child_recurse(current_room, join_rule, identifier, allowed_room_ids, 0) +} + +fn is_accessable_child_recurse( + current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, + allowed_room_ids: &Vec, recurse_num: usize, +) -> Result { + // Set limit at 10, as we cannot keep going up parents forever + if recurse_num < 10 { + match identifier { + Identifier::ServerName(server_name) => { + let room_id: &RoomId = current_room; + + // Checks if ACLs allow for the server to participate + if services().rooms.event_handler.acl_check(server_name, room_id).is_err() { + return Ok(false); + } + }, + Identifier::UserId(user_id) => { + if services().rooms.state_cache.is_joined(user_id, current_room)? + || services().rooms.state_cache.is_invited(user_id, current_room)? 
+ { + return Ok(true); + } + }, + _ => (), + } // Takes care of joinrules + Ok(match join_rule { + SpaceRoomJoinRule::KnockRestricted | SpaceRoomJoinRule::Restricted => { + for room in allowed_room_ids { + if let Ok((join_rule, allowed_room_ids)) = get_join_rule(room) { + // Recursive, get rid of if possible + if let Ok(true) = is_accessable_child_recurse( + room, + &join_rule, + identifier, + &allowed_room_ids, + recurse_num + 1, + ) { return Ok(true); } } } - - Ok(false) + false }, - JoinRule::KnockRestricted(_) => { - // TODO: Check rules - Ok(false) - }, - _ => Ok(false), - } + SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock => true, + SpaceRoomJoinRule::Invite | SpaceRoomJoinRule::Private => false, + // Custom join rule + _ => false, + }) + } else { + // If you need to go up 10 parents, we just assume it is inaccessable + Ok(false) + } +} + +/// Checks if guests are able to join a given room +fn guest_can_join(room_id: &RoomId) -> Result { + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")?.map_or( + Ok(false), + |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) + }, + ) +} + +/// Checks if guests are able to view room content without joining +fn world_readable(room_id: &RoomId) -> Result { + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?.map_or( + Ok(false), + |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) + .map_err(|_| Error::bad_database("Invalid room history visibility event in database.")) + }, + ) +} + +/// Returns the join rule for a given room +fn get_join_rule(current_room: &RoomId) -> Result<(SpaceRoomJoinRule, Vec), Error> { + Ok(services() + .rooms + .state_accessor + 
.room_state_get(current_room, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), allowed_room_ids(c.join_rule))) + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? + .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) +} + +// Here because cannot implement `From` across ruma-federation-api and +// ruma-client-api types +fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRoomsChunk { + let SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + .. + } = summary; + + SpaceHierarchyRoomsChunk { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + } +} + +/// Returns an empty vec if not a restricted room +fn allowed_room_ids(join_rule: JoinRule) -> Vec { + let mut room_ids = vec![]; + if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule { + for rule in r.allow { + if let AllowRule::RoomMembership(RoomMembership { + room_id: membership, + }) = rule + { + room_ids.push(membership.to_owned()); + } + } + } + room_ids +} + +/// Returns the children of a SpaceHierarchyParentSummary, making use of the +/// children_state field +fn get_parent_children(parent: SpaceHierarchyParentSummary, suggested_only: bool) -> Vec { + parent + .children_state + .iter() + .filter_map(|raw_ce| { + raw_ce.deserialize().map_or(None, |ce| { + if suggested_only && !ce.content.suggested { + None + } else { + Some(ce.state_key) + } + }) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use ruma::{ + api::federation::space::SpaceHierarchyParentSummaryInit, 
events::room::join_rules::Restricted, owned_room_id, + }; + + use super::*; + + fn first(arena: &mut Arena, room_id: OwnedRoomId) { + let first_untrav = arena.first_untraversed().unwrap(); + + assert_eq!(arena.get(first_untrav).unwrap().room_id, room_id); + } + + #[test] + fn zero_depth() { + let mut arena = Arena::new(owned_room_id!("!foo:example.org"), 0); + + assert_eq!(arena.first_untraversed(), None); + } + + #[test] + fn two_depth() { + let mut arena = Arena::new(owned_room_id!("!root:example.org"), 2); + + let root = arena.first_untraversed().unwrap(); + arena.push( + root, + vec![ + owned_room_id!("!subspace1:example.org"), + owned_room_id!("!subspace2:example.org"), + owned_room_id!("!foo:example.org"), + ], + ); + + let subspace1 = arena.first_untraversed().unwrap(); + let subspace2 = arena.first_untraversed().unwrap(); + + arena.push( + subspace1, + vec![owned_room_id!("!room1:example.org"), owned_room_id!("!room2:example.org")], + ); + + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room2:example.org")); + + arena.push( + subspace2, + vec![owned_room_id!("!room3:example.org"), owned_room_id!("!room4:example.org")], + ); + + first(&mut arena, owned_room_id!("!room3:example.org")); + first(&mut arena, owned_room_id!("!room4:example.org")); + + let foo_node = NodeId { + index: 1, + }; + + assert_eq!(arena.first_untraversed(), Some(foo_node)); + assert_eq!( + arena.get(foo_node).map(|node| node.room_id.clone()), + Some(owned_room_id!("!foo:example.org")) + ); + } + + #[test] + fn empty_push() { + let mut arena = Arena::new(owned_room_id!("!root:example.org"), 5); + + let root = arena.first_untraversed().unwrap(); + arena.push( + root, + vec![owned_room_id!("!room1:example.org"), owned_room_id!("!room2:example.org")], + ); + + let room1 = arena.first_untraversed().unwrap(); + arena.push(room1, vec![]); + + first(&mut arena, owned_room_id!("!room2:example.org")); + assert!(arena.first_untraversed().is_none()); + 
} + + #[test] + fn beyond_max_depth() { + let mut arena = Arena::new(owned_room_id!("!root:example.org"), 0); + + let root = NodeId { + index: 0, + }; + + arena.push(root, vec![owned_room_id!("!too_deep:example.org")]); + + assert_eq!(arena.first_child(root), None); + assert_eq!(arena.nodes.len(), 1); + } + + #[test] + fn order_check() { + let mut arena = Arena::new(owned_room_id!("!root:example.org"), 3); + + let root = arena.first_untraversed().unwrap(); + arena.push( + root, + vec![ + owned_room_id!("!subspace1:example.org"), + owned_room_id!("!subspace2:example.org"), + owned_room_id!("!foo:example.org"), + ], + ); + + let subspace1 = arena.first_untraversed().unwrap(); + arena.push( + subspace1, + vec![ + owned_room_id!("!room1:example.org"), + owned_room_id!("!room3:example.org"), + owned_room_id!("!room5:example.org"), + ], + ); + + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room3:example.org")); + first(&mut arena, owned_room_id!("!room5:example.org")); + + let subspace2 = arena.first_untraversed().unwrap(); + + assert_eq!(arena.get(subspace2).unwrap().room_id, owned_room_id!("!subspace2:example.org")); + + arena.push( + subspace2, + vec![owned_room_id!("!room1:example.org"), owned_room_id!("!room2:example.org")], + ); + + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room2:example.org")); + first(&mut arena, owned_room_id!("!foo:example.org")); + + assert_eq!(arena.first_untraversed(), None); + } + + #[test] + fn get_summary_children() { + let summary: SpaceHierarchyParentSummary = SpaceHierarchyParentSummaryInit { + num_joined_members: UInt::from(1_u32), + room_id: owned_room_id!("!root:example.org"), + world_readable: true, + guest_can_join: true, + join_rule: SpaceRoomJoinRule::Public, + children_state: vec![ + serde_json::from_str( + r#"{ + "content": { + "via": [ + "example.org" + ], + "suggested": false + }, + "origin_server_ts": 1629413349153, + 
"sender": "@alice:example.org", + "state_key": "!foo:example.org", + "type": "m.space.child" + }"#, + ) + .unwrap(), + serde_json::from_str( + r#"{ + "content": { + "via": [ + "example.org" + ], + "suggested": true + }, + "origin_server_ts": 1629413349157, + "sender": "@alice:example.org", + "state_key": "!bar:example.org", + "type": "m.space.child" + }"#, + ) + .unwrap(), + serde_json::from_str( + r#"{ + "content": { + "via": [ + "example.org" + ] + }, + "origin_server_ts": 1629413349160, + "sender": "@alice:example.org", + "state_key": "!baz:example.org", + "type": "m.space.child" + }"#, + ) + .unwrap(), + ], + allowed_room_ids: vec![], + } + .into(); + + assert_eq!( + get_parent_children(summary.clone(), false), + vec![ + owned_room_id!("!foo:example.org"), + owned_room_id!("!bar:example.org"), + owned_room_id!("!baz:example.org") + ] + ); + assert_eq!(get_parent_children(summary, true), vec![owned_room_id!("!bar:example.org")]); + } + + #[test] + fn allowed_room_ids_rom_join_rule() { + let restricted_join_rule = JoinRule::Restricted(Restricted { + allow: vec![ + AllowRule::RoomMembership(RoomMembership { + room_id: owned_room_id!("!foo:example.org"), + }), + AllowRule::RoomMembership(RoomMembership { + room_id: owned_room_id!("!bar:example.org"), + }), + AllowRule::RoomMembership(RoomMembership { + room_id: owned_room_id!("!baz:example.org"), + }), + ], + }); + + let invite_join_rule = JoinRule::Invite; + + assert_eq!( + allowed_room_ids(restricted_join_rule), + vec![ + owned_room_id!("!foo:example.org"), + owned_room_id!("!bar:example.org"), + owned_room_id!("!baz:example.org") + ] + ); + + let empty_vec: Vec = vec![]; + + assert_eq!(allowed_room_ids(invite_join_rule), empty_vec); + } + + #[test] + fn invalid_pagnation_tokens() { + fn token_is_err(token: &str) { + let token: Result = PagnationToken::from_str(token); + assert!(token.is_err()); + } + + token_is_err("231_2_noabool"); + token_is_err(""); + token_is_err("111_3_"); + token_is_err("foo_not_int"); + 
token_is_err("11_4_true_"); + token_is_err("___"); + token_is_err("__false"); + } + + #[test] + fn valid_pagnation_tokens() { + assert_eq!( + PagnationToken { + skip: UInt::from(40_u32), + limit: UInt::from(20_u32), + max_depth: UInt::from(1_u32), + suggested_only: true + }, + PagnationToken::from_str("40_20_1_true").unwrap() + ); + + assert_eq!( + PagnationToken { + skip: UInt::from(27645_u32), + limit: UInt::from(97_u32), + max_depth: UInt::from(10539_u32), + suggested_only: false + }, + PagnationToken::from_str("27645_97_10539_false").unwrap() + ); + } + + #[test] + fn pagnation_token_to_string() { + assert_eq!( + PagnationToken { + skip: UInt::from(27645_u32), + limit: UInt::from(97_u32), + max_depth: UInt::from(9420_u32), + suggested_only: false + } + .to_string(), + "27645_97_9420_false" + ); + + assert_eq!( + PagnationToken { + skip: UInt::from(12_u32), + limit: UInt::from(3_u32), + max_depth: UInt::from(1_u32), + suggested_only: true + } + .to_string(), + "12_3_1_true" + ); + } + + #[test] + fn forbid_recursion() { + let mut arena = Arena::new(owned_room_id!("!root:example.org"), 5); + let root_node_id = arena.first_untraversed().unwrap(); + + arena.push( + root_node_id, + vec![ + owned_room_id!("!subspace1:example.org"), + owned_room_id!("!room1:example.org"), + owned_room_id!("!subspace2:example.org"), + ], + ); + + let subspace1_node_id = arena.first_untraversed().unwrap(); + arena.push( + subspace1_node_id, + vec![owned_room_id!("!subspace2:example.org"), owned_room_id!("!room1:example.org")], + ); + + let subspace2_node_id = arena.first_untraversed().unwrap(); + // Here, both subspaces should be ignored and not added, as they are both + // parents of subspace2 + arena.push( + subspace2_node_id, + vec![ + owned_room_id!("!subspace1:example.org"), + owned_room_id!("!subspace2:example.org"), + owned_room_id!("!room1:example.org"), + ], + ); + + assert_eq!(arena.nodes.len(), 7); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, 
owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!subspace2:example.org")); + assert!(arena.first_untraversed().is_none()); } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 3273d9c2..db2c1921 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -74,7 +74,7 @@ impl Service { .await?; }, TimelineEventType::SpaceChild => { - services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id); + services().rooms.spaces.roomid_spacehierarchy_cache.lock().await.remove(&pdu.room_id); }, _ => continue, } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 92f0a09d..0d9ed3be 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -418,7 +418,7 @@ impl Service { }, TimelineEventType::SpaceChild => { if let Some(_state_key) = &pdu.state_key { - services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id); + services().rooms.spaces.roomid_spacehierarchy_cache.lock().await.remove(&pdu.room_id); } }, TimelineEventType::RoomMember => { From a33b33cab5becafade100faea6da85cd21329492 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 16 Mar 2024 18:34:54 -0400 Subject: [PATCH 08/20] document forbidden room aliases and usernames Signed-off-by: strawberry --- conduwuit-example.toml | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 6bc7793b..3fdf279d 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -161,7 +161,19 @@ registration_token = "change this token for something specific to your server" # controls whether non-admin local users are forbidden from sending room invites (local and remote), # and if non-admin users can receive remote room invites. admins are always allowed to send and receive all room invites. 
# defaults to false -# block_non_admin_invites = falsse +# block_non_admin_invites = false + +# List of forbidden username patterns/strings. Values in this list are matched as *contains*. +# This is checked upon username availability check, registration, and startup as warnings if any local users in your database +# have a forbidden username. +# No default. +# forbidden_usernames = [] + +# List of forbidden room aliases and room IDs as patterns/strings. Values in this list are matched as *contains*. +# This is checked upon room alias creation, custom room ID creation if used, and startup as warnings if any room aliases +# in your database have a forbidden room alias/ID. +# No default. +# forbidden_room_names = [] # Set this to true to allow your server's public room directory to be federated. # Set this to false to protect against /publicRooms spiders, but will forbid external users @@ -387,4 +399,4 @@ url_preview_check_root_domain = false # Whether to listen and allow for HTTP and HTTPS connections (insecure!) 
# This config option is only available if conduwuit was built with `axum_dual_protocol` feature (not default feature) # Defaults to false -#dual_protocol = false \ No newline at end of file +#dual_protocol = false From 6d4163d4105cc5765efb5bb7be1e4d4d019f6e11 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 01:42:30 -0400 Subject: [PATCH 09/20] track media uploads by user Signed-off-by: strawberry --- src/api/client_server/media.rs | 11 ++++++++++- src/database/key_value/media.rs | 27 ++++++++++++++++++++++----- src/database/mod.rs | 2 ++ src/service/media/data.rs | 3 ++- src/service/media/mod.rs | 24 ++++++++++++++++++------ 5 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 380d4e3c..bb98814b 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -138,6 +138,8 @@ pub async fn get_media_preview_v1_route( /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory pub async fn create_content_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let mxc = format!( "mxc://{}/{}", services().globals.server_name(), @@ -147,6 +149,7 @@ pub async fn create_content_route(body: Ruma) -> Re services() .media .create( + Some(sender_user.clone()), mxc.clone(), body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(), body.content_type.as_deref(), @@ -175,6 +178,8 @@ pub async fn create_content_route(body: Ruma) -> Re pub async fn create_content_v1_route( body: Ruma, ) -> Result> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let mxc = format!( "mxc://{}/{}", services().globals.server_name(), @@ -184,6 +189,7 @@ pub async fn create_content_v1_route( services() .media .create( + Some(sender_user.clone()), mxc.clone(), body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + 
filename).as_deref(), body.content_type.as_deref(), @@ -231,6 +237,7 @@ pub async fn get_remote_content( services() .media .create( + None, mxc.to_owned(), content_response.content_disposition.as_deref(), content_response.content_type.as_deref(), @@ -484,6 +491,7 @@ pub async fn get_content_thumbnail_route( services() .media .upload_thumbnail( + None, mxc, None, get_thumbnail_response.content_type.as_deref(), @@ -566,6 +574,7 @@ pub async fn get_content_thumbnail_v1_route( services() .media .upload_thumbnail( + None, mxc, None, get_thumbnail_response.content_type.as_deref(), @@ -589,7 +598,7 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result (None, None), diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index f00f6b55..af7a883a 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -4,12 +4,14 @@ use tracing::debug; use crate::{ database::KeyValueDatabase, service::{self, media::UrlPreviewData}, - utils, Error, Result, + utils::string_from_bytes, + Error, Result, }; impl service::media::Data for KeyValueDatabase { fn create_file_metadata( - &self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, + &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, + content_type: Option<&str>, ) -> Result> { let mut key = mxc.as_bytes().to_vec(); key.push(0xFF); @@ -22,6 +24,12 @@ impl service::media::Data for KeyValueDatabase { self.mediaid_file.insert(&key, &[])?; + if let Some(user) = sender_user { + let key = mxc.as_bytes().to_vec(); + let user = user.as_bytes().to_vec(); + self.mediaid_user.insert(&key, &user)?; + } + Ok(key) } @@ -31,13 +39,22 @@ impl service::media::Data for KeyValueDatabase { let mut prefix = mxc.as_bytes().to_vec(); prefix.push(0xFF); - debug!("MXC db prefix: {:?}", prefix); + debug!("MXC db prefix: {prefix:?}"); for (key, _) in self.mediaid_file.scan_prefix(prefix) { 
debug!("Deleting key: {:?}", key); self.mediaid_file.remove(&key)?; } + for (key, value) in self.mediaid_user.scan_prefix(mxc.as_bytes().to_vec()) { + if key == mxc.as_bytes().to_vec() { + let user = string_from_bytes(&value).unwrap_or_default(); + + debug!("Deleting key \"{key:?}\" which was uploaded by user {user}"); + self.mediaid_user.remove(&key)?; + } + } + Ok(()) } @@ -85,7 +102,7 @@ impl service::media::Data for KeyValueDatabase { let content_type = parts .next() .map(|bytes| { - utils::string_from_bytes(bytes) + string_from_bytes(bytes) .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode.")) }) .transpose()?; @@ -97,7 +114,7 @@ impl service::media::Data for KeyValueDatabase { None } else { Some( - utils::string_from_bytes(content_disposition_bytes) + string_from_bytes(content_disposition_bytes) .map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?, ) }; diff --git a/src/database/mod.rs b/src/database/mod.rs index 994b9273..e2b25bf0 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -157,6 +157,7 @@ pub struct KeyValueDatabase { //pub media: media::Media, pub(super) mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType pub(super) url_previews: Arc, + pub(super) mediaid_user: Arc, //pub key_backups: key_backups::KeyBackups, pub(super) backupid_algorithm: Arc, // BackupId = UserId + Version(Count) pub(super) backupid_etag: Arc, // BackupId = UserId + Version(Count) @@ -365,6 +366,7 @@ impl KeyValueDatabase { roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, mediaid_file: builder.open_tree("mediaid_file")?, url_previews: builder.open_tree("url_previews")?, + mediaid_user: builder.open_tree("mediaid_user")?, backupid_algorithm: builder.open_tree("backupid_algorithm")?, backupid_etag: builder.open_tree("backupid_etag")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, diff --git a/src/service/media/data.rs 
b/src/service/media/data.rs index 9da50860..7cbde755 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -2,7 +2,8 @@ use crate::Result; pub trait Data: Send + Sync { fn create_file_metadata( - &self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, + &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, + content_type: Option<&str>, ) -> Result>; fn delete_file_mxc(&self, mxc: String) -> Result<()>; diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index a4e78378..696fa9f0 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime}; pub(crate) use data::Data; use image::imageops::FilterType; -use ruma::OwnedMxcUri; +use ruma::{OwnedMxcUri, OwnedUserId}; use serde::Serialize; use tokio::{ fs::{self, File}, @@ -45,10 +45,15 @@ pub struct Service { impl Service { /// Uploads a file. pub async fn create( - &self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8], + &self, sender_user: Option, mxc: String, content_disposition: Option<&str>, + content_type: Option<&str>, file: &[u8], ) -> Result<()> { // Width, Height = 0 if it's not a thumbnail - let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?; + let key = if let Some(user) = sender_user { + self.db.create_file_metadata(Some(user.as_str()), mxc, 0, 0, content_disposition, content_type)? + } else { + self.db.create_file_metadata(None, mxc, 0, 0, content_disposition, content_type)? + }; let path; @@ -106,11 +111,17 @@ impl Service { } /// Uploads or replaces a file thumbnail. 
+ #[allow(clippy::too_many_arguments)] pub async fn upload_thumbnail( - &self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32, - file: &[u8], + &self, sender_user: Option, mxc: String, content_disposition: Option<&str>, + content_type: Option<&str>, width: u32, height: u32, file: &[u8], ) -> Result<()> { - let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?; + let key = if let Some(user) = sender_user { + self.db.create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)? + } else { + self.db.create_file_metadata(None, mxc, width, height, content_disposition, content_type)? + }; + let path; #[allow(clippy::unnecessary_operation)] // error[E0658]: attributes on expressions are experimental @@ -403,6 +414,7 @@ impl Service { // Save thumbnail in database so we don't have to generate it again next time let thumbnail_key = self.db.create_file_metadata( + None, mxc, width, height, From 70b1bdd65588876ec1d1aad3f19caa838103dc91 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 01:55:09 -0400 Subject: [PATCH 10/20] slight inclusive wording changes Signed-off-by: strawberry --- src/api/server_server.rs | 2 +- src/database/key_value/globals.rs | 2 +- src/service/media/mod.rs | 4 ++-- src/service/sending/mod.rs | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 4f28e271..4edfd09d 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1540,7 +1540,7 @@ pub async fn create_invite_route(body: Ruma) -> Resu let mut event: JsonObject = serde_json::from_str(body.event.get()) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; - event.insert("event_id".to_owned(), "$dummy".into()); + event.insert("event_id".to_owned(), "$placeholder".into()); let pdu: PduEvent = 
serde_json::from_value(event.into()).map_err(|e| { warn!("Invalid invite event: {}", e); diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index be5796ed..48e55578 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -52,7 +52,7 @@ impl service::globals::Data for KeyValueDatabase { let mut futures = FuturesUnordered::new(); - // Return when *any* user changed his key + // Return when *any* user changed their key // TODO: only send for user they share a room with futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 696fa9f0..767fee9a 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -194,8 +194,8 @@ impl Service { debug!("Full MXC key from database: {:?}", key); // we need to get the MXC URL from the first part of the key (the first 0xff / - // 255 push) this code does look kinda crazy but blame conduit for using magic - // keys + // 255 push). 
this is all necessary because of conduit using magic keys for + // media let mut parts = key.split(|&b| b == 0xFF); let mxc = parts .next() diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 529e1631..fae05404 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -360,11 +360,11 @@ impl Service { for user_id in device_list_changes { // Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767 - // Because synapse resyncs, we can just insert dummy data + // Because synapse resyncs, we can just insert placeholder data let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, - device_id: device_id!("dummy").to_owned(), - device_display_name: Some("Dummy".to_owned()), + device_id: device_id!("placeholder").to_owned(), + device_display_name: Some("Placeholder".to_owned()), stream_id: uint!(1), prev_id: Vec::new(), deleted: None, From e982428f073f527f472e1f1fcec2b9d15785baac Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 01:57:16 -0400 Subject: [PATCH 11/20] bump async-trait and ruma Signed-off-by: strawberry --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 20c83ab1..68f8d9af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,9 +96,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 821cee14..c241b1fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ serde_html_form = "0.2.5" hmac = "0.12.1" sha-1 = "0.10.1" -async-trait = "0.1.77" +async-trait = "0.1.78" # used for 
checking if an IP is in specific subnets / CIDR ranges easier ipaddress = "0.1.3" From 73c42991e910c831bb22de1abaa760189d367952 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 17 Mar 2024 02:02:14 -0400 Subject: [PATCH 12/20] clear dns and tls-override caches from !admin command. Signed-off-by: Jason Volk Signed-off-by: strawberry --- src/service/mod.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/service/mod.rs b/src/service/mod.rs index 8d37061c..1667dc81 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -207,5 +207,11 @@ roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}" if amount > 5 { self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.clear(); } + if amount > 6 { + self.globals.tls_name_override.write().unwrap().clear(); + } + if amount > 7 { + self.globals.dns_resolver().clear_cache(); + } } } From 72c97434b050ea6b9775dc1753c23d3f4d68252e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 17 Mar 2024 02:04:05 -0400 Subject: [PATCH 13/20] add remove_batch with transaction to database abstraction. 
adjusted to make building sqlite happy again Signed-off-by: Jason Volk Signed-off-by: strawberry --- src/database/abstraction.rs | 4 ++++ src/database/abstraction/rocksdb.rs | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index dd331ed3..fdf93089 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -43,6 +43,10 @@ pub(crate) trait KvTree: Send + Sync { fn remove(&self, key: &[u8]) -> Result<()>; + #[allow(dead_code)] + #[cfg(feature = "rocksdb")] + fn remove_batch(&self, _iter: &mut dyn Iterator>) -> Result<()> { unimplemented!() } + fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box, Vec)> + 'a>; diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 584c6e70..729c34d2 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -251,6 +251,18 @@ impl KvTree for RocksDbEngineTree<'_> { Ok(self.db.rocks.delete_cf_opt(&self.cf(), key, &writeoptions)?) } + fn remove_batch(&self, iter: &mut dyn Iterator>) -> Result<()> { + let writeoptions = rust_rocksdb::WriteOptions::default(); + + let mut batch = WriteBatchWithTransaction::::default(); + + for key in iter { + batch.delete_cf(&self.cf(), key); + } + + Ok(self.db.rocks.write_opt(batch, &writeoptions)?) + } + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { let mut readoptions = rust_rocksdb::ReadOptions::default(); readoptions.set_total_order_seek(true); From 3af303e52bc2908d3c69fecf62e98159745e3610 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 17 Mar 2024 02:20:23 -0400 Subject: [PATCH 14/20] complete federation destination caching preempting getaddrinfo(3). 
fixed some clippy lints and spacing adjusted Signed-off-by: Jason Volk Signed-off-by: strawberry --- Cargo.lock | 1 + Cargo.toml | 1 + src/api/server_server.rs | 78 ++++++++++++++++++++++---------------- src/service/globals/mod.rs | 2 + 4 files changed, 50 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 68f8d9af..ddc84c1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2062,6 +2062,7 @@ dependencies = [ "tokio-rustls", "tokio-socks", "tower-service", + "trust-dns-resolver", "url", "wasm-bindgen", "wasm-bindgen-futures", diff --git a/Cargo.toml b/Cargo.toml index c241b1fc..0c84c897 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,7 @@ default-features = false features = [ "rustls-tls-native-roots", "socks", + "trust-dns", ] # all the serde stuff diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 4edfd09d..07be27f4 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -365,7 +365,10 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe None => { if let Some(pos) = destination_str.find(':') { debug!("2: Hostname with included port"); + let (host, port) = destination_str.split_at(pos); + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; + FedDest::Named(host.to_owned(), port.to_owned()) } else { debug!("Requesting well known for {destination}"); @@ -378,30 +381,23 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe None => { if let Some(pos) = delegated_hostname.find(':') { debug!("3.2: Hostname with port in .well-known file"); + let (host, port) = delegated_hostname.split_at(pos); + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; + FedDest::Named(host.to_owned(), port.to_owned()) } else { debug!("Delegated hostname has no port in this branch"); if let Some(hostname_override) = query_srv_record(&delegated_hostname).await { debug!("3.3: SRV lookup successful"); - let force_port = 
hostname_override.port(); - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services().globals.tls_name_override.write().unwrap().insert( - delegated_hostname.clone(), - (override_ip.iter().collect(), force_port.unwrap_or(8448)), - ); - } else { - debug!( - "Using SRV record {}, but could not resolve to IP", - hostname_override.hostname() - ); - } + let force_port = hostname_override.port(); + query_and_cache_override( + &delegated_hostname, + &hostname_override.hostname(), + force_port.unwrap_or(8448), + ) + .await; if let Some(port) = force_port { FedDest::Named(delegated_hostname, format!(":{port}")) @@ -410,6 +406,7 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe } } else { debug!("3.4: No SRV records, just use the hostname from .well-known"); + query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await; add_port_to_hostname(&delegated_hostname) } } @@ -421,21 +418,14 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe match query_srv_record(&destination_str).await { Some(hostname_override) => { debug!("4: SRV record found"); - let force_port = hostname_override.port(); - if let Ok(override_ip) = - services().globals.dns_resolver().lookup_ip(hostname_override.hostname()).await - { - services().globals.tls_name_override.write().unwrap().insert( - hostname.clone(), - (override_ip.iter().collect(), force_port.unwrap_or(8448)), - ); - } else { - debug!( - "Using SRV record {}, but could not resolve to IP", - hostname_override.hostname() - ); - } + let force_port = hostname_override.port(); + query_and_cache_override( + &hostname, + &hostname_override.hostname(), + force_port.unwrap_or(8448), + ) + .await; if let Some(port) = force_port { FedDest::Named(hostname.clone(), format!(":{port}")) @@ -445,6 +435,7 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe }, None => { 
debug!("5: No SRV record found"); + query_and_cache_override(&destination_str, &destination_str, 8448).await; add_port_to_hostname(&destination_str) }, } @@ -453,7 +444,6 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe } }, }; - debug!("Actual destination: {actual_destination:?}"); // Can't use get_ip_with_port here because we don't want to add a port // to an IP address if it wasn't specified @@ -467,9 +457,29 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe } else { FedDest::Named(hostname, ":8448".to_owned()) }; + + debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}"); (actual_destination, hostname) } +async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { + match services().globals.dns_resolver().lookup_ip(hostname.to_owned()).await { + Ok(override_ip) => { + debug!("Caching result of {:?} overriding {:?}", hostname, overname); + + services() + .globals + .tls_name_override + .write() + .unwrap() + .insert(overname.to_owned(), (override_ip.iter().collect(), port)); + }, + Err(e) => { + debug!("Got {:?} for {:?} to override {:?}", e.kind(), hostname, overname); + }, + } +} + async fn query_srv_record(hostname: &'_ str) -> Option { fn handle_successful_srv(srv: &SrvLookup) -> Option { srv.iter().next().map(|result| { @@ -501,6 +511,10 @@ async fn query_srv_record(hostname: &'_ str) -> Option { } async fn request_well_known(destination: &str) -> Option { + if !services().globals.tls_name_override.read().unwrap().contains_key(destination) { + query_and_cache_override(destination, destination, 8448).await; + } + let response = services() .globals .default_client() diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 8313dc0e..f54f0686 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -495,6 +495,7 @@ fn reqwest_client_builder(config: &Config) -> Result { }); let mut reqwest_client_builder = 
reqwest::Client::builder() + .trust_dns(true) .pool_max_idle_per_host(0) .connect_timeout(Duration::from_secs(60)) .timeout(Duration::from_secs(60 * 5)) @@ -522,6 +523,7 @@ fn url_preview_reqwest_client_builder(config: &Config) -> Result Date: Sun, 17 Mar 2024 02:25:50 -0400 Subject: [PATCH 15/20] add flush suite to sending service; trigger on read receipts. Signed-off-by: Jason Volk Signed-off-by: strawberry --- src/api/client_server/read_marker.rs | 4 +++ src/database/key_value/sending.rs | 4 +++ src/service/sending/mod.rs | 45 +++++++++++++++++++++++++--- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index 182748d6..a6097c17 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -81,6 +81,8 @@ pub async fn set_read_marker_route(body: Ruma) -> room_id: body.room_id.clone(), }, )?; + + services().sending.flush_room(&body.room_id)?; } Ok(set_read_marker::v3::Response {}) @@ -136,6 +138,8 @@ pub async fn create_receipt_route(body: Ruma) -> Re room_id: body.room_id.clone(), }, )?; + + services().sending.flush_room(&body.room_id)?; }, create_receipt::v3::ReceiptType::ReadPrivate => { let count = services() diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index a3ede405..5087cbe7 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -90,6 +90,10 @@ impl service::sending::Data for KeyValueDatabase { fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()> { for (e, key) in events { + if key.is_empty() { + continue; + } + let value = if let SendingEventType::Edu(value) = &e { &**value } else { diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index fae05404..8c8d0243 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -25,7 +25,7 @@ use ruma::{ }, device_id, events::{push_rules::PushRulesEvent, 
receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, - push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName, UInt, UserId, + push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, RoomId, ServerName, UInt, UserId, }; use tokio::{ select, @@ -80,6 +80,7 @@ impl OutgoingKind { pub enum SendingEventType { Pdu(Vec), // pduid Edu(Vec), // pdu json + Flush, // none } pub struct Service { @@ -237,9 +238,11 @@ impl Service { events.push(e); } } else { - self.db.mark_as_active(&new_events)?; - for (e, _) in new_events { - events.push(e); + if !new_events.is_empty() { + self.db.mark_as_active(&new_events)?; + for (e, _) in new_events { + events.push(e); + } } if let OutgoingKind::Normal(server_name) = outgoing_kind { @@ -421,6 +424,29 @@ impl Service { Ok(()) } + #[tracing::instrument(skip(self, room_id))] + pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + let servers: HashSet = + services().rooms.state_cache.room_servers(room_id).filter_map(std::result::Result::ok).collect(); + + self.flush_servers(servers.into_iter()) + } + + #[tracing::instrument(skip(self, servers))] + pub fn flush_servers>(&self, servers: I) -> Result<()> { + let requests = servers + .into_iter() + .filter(|server| server != services().globals.server_name()) + .map(OutgoingKind::Normal) + .collect::>(); + + for outgoing_kind in requests.into_iter() { + self.sender.send((outgoing_kind, SendingEventType::Flush, Vec::::new())).unwrap(); + } + + Ok(()) + } + /// Cleanup event data /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self))] @@ -461,6 +487,9 @@ impl Service { SendingEventType::Edu(_) => { // Appservices don't need EDUs (?) 
}, + SendingEventType::Flush => { + // flush only; no new content + }, } } @@ -480,6 +509,7 @@ impl Service { .iter() .map(|e| match e { SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + SendingEventType::Flush => &[], }) .collect::>(), ))) @@ -521,6 +551,9 @@ impl Service { SendingEventType::Edu(_) => { // Push gateways don't need EDUs (?) }, + SendingEventType::Flush => { + // flush only; no new content + }, } } @@ -601,6 +634,9 @@ impl Service { edu_jsons.push(raw); } }, + SendingEventType::Flush => { + // flush only; no new content + }, } } @@ -618,6 +654,7 @@ impl Service { .iter() .map(|e| match e { SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + SendingEventType::Flush => &[], }) .collect::>(), ))) From 3ac536857864797c6e0f7734228874e5f4d7d907 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 02:42:49 -0400 Subject: [PATCH 16/20] bump conduwuit version to 0.1.8 Signed-off-by: strawberry --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddc84c1f..5e042beb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -412,7 +412,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "conduit" -version = "0.7.0-alpha+conduwuit-0.1.7" +version = "0.7.0-alpha+conduwuit-0.1.8" dependencies = [ "argon2", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 0c84c897..56087707 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ authors = ["strawberry ", "timokoesters Date: Sun, 17 Mar 2024 12:11:24 -0400 Subject: [PATCH 17/20] ignore deactivated users and remote user profiles with forbidden_usernames Signed-off-by: strawberry --- src/database/mod.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/database/mod.rs b/src/database/mod.rs index e2b25bf0..d61dfad2 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -933,8 +933,13 @@ impl KeyValueDatabase { { let patterns =
&services().globals.config.forbidden_usernames; if !patterns.is_empty() { - for user in services().users.iter() { - let user_id = user?; + for user_id in services() + .users + .iter() + .filter_map(std::result::Result::ok) + .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)) + .filter(|user| user.server_name() == services().globals.server_name()) + { let matches = patterns.matches(user_id.localpart()); if matches.matched_any() { warn!( From a0161ed7c1a8b62434650ab2912a4005852c259b Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 12:16:04 -0400 Subject: [PATCH 18/20] config option to allow incoming remote read receipts Signed-off-by: strawberry --- conduwuit-example.toml | 5 ++++- src/api/server_server.rs | 4 ++++ src/config/mod.rs | 7 +++++++ src/service/globals/mod.rs | 2 ++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 3fdf279d..089088f4 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -357,7 +357,7 @@ url_preview_check_root_domain = false -### Presence +### Presence / Typing Indicators / Read Receipts # Config option to control local (your server only) presence updates/requests. Defaults to false. # Note that presence on conduwuit is very fast unlike Synapse's. @@ -385,6 +385,9 @@ url_preview_check_root_domain = false # Config option to control how many seconds before presence updates that you are offline. Defaults to 30 minutes. #presence_offline_timeout_s = 1800 +# Config option to control whether we should receive remote incoming read receipts. +# Defaults to true. 
+#allow_incoming_read_receipts = true # Other options not in [global]: diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 07be27f4..a2395526 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -823,6 +823,10 @@ pub async fn send_transaction_message_route( } }, Edu::Receipt(receipt) => { + if !services().globals.allow_incoming_read_receipts() { + continue; + } + for (room_id, room_updates) in receipt.receipts { for (user_id, user_updates) in room_updates.read { if let Some((event_id, _)) = user_updates diff --git a/src/config/mod.rs b/src/config/mod.rs index 19731bbe..c6252a83 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -144,6 +144,9 @@ pub struct Config { #[serde(default = "default_presence_offline_timeout_s")] pub presence_offline_timeout_s: u64, + #[serde(default = "true_fn")] + pub allow_incoming_read_receipts: bool, + #[serde(default)] pub zstd_compression: bool, @@ -282,6 +285,10 @@ impl fmt::Display for Config { "Allow local presence requests (updates)", &self.allow_local_presence.to_string(), ), + ( + "Allow incoming remote read receipts", + &self.allow_incoming_read_receipts.to_string(), + ), ( "Block non-admin room invites (local and remote, admins can still send and receive invites)", &self.block_non_admin_invites.to_string(), diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index f54f0686..6f30e1b2 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -359,6 +359,8 @@ impl Service<'_> { pub fn presence_offline_timeout_s(&self) -> u64 { self.config.presence_offline_timeout_s } + pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_read_receipts } + pub fn rocksdb_log_level(&self) -> &String { &self.config.rocksdb_log_level } pub fn rocksdb_max_log_file_size(&self) -> usize { self.config.rocksdb_max_log_file_size } From 7f22f0e3a6732902c5ba72dbc0081b535594f0ac Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 15:09:36 
-0400 Subject: [PATCH 19/20] keypair logging adjustments Signed-off-by: strawberry --- src/database/key_value/globals.rs | 10 ++++++++-- src/service/globals/mod.rs | 2 +- src/utils/mod.rs | 5 +++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 48e55578..63d0e2b6 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -8,6 +8,7 @@ use ruma::{ signatures::Ed25519KeyPair, DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, }; +use tracing::debug; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; @@ -185,7 +186,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" fn load_keypair(&self) -> Result { let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { + debug!("No keypair found in database, assuming this is a new deployment and generating one."); let keypair = utils::generate_keypair(); + debug!("Generated keypair bytes: {:?}", keypair); self.global.insert(b"keypair", &keypair)?; Ok::<_, Error>(keypair) }, @@ -200,6 +203,7 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" ) .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) .and_then(|version| { + debug!("Keypair version: {version}"); // 2. 
key parts .next() @@ -207,8 +211,10 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" .map(|key| (version, key)) }) .and_then(|(version, key)| { - Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + let keypair = Ed25519KeyPair::from_der(key, version) + .map_err(|_| Error::bad_database("Private or public keys are invalid.")); + debug!("Private and public key bytes: {keypair:?}"); + keypair }) } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 6f30e1b2..21aaeb14 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -152,7 +152,7 @@ impl Service<'_> { let keypair = match keypair { Ok(k) => k, Err(e) => { - error!("Keypair invalid. Deleting..."); + error!("Homeserver signing keypair in database is invalid. Deleting..."); db.remove_keypair()?; return Err(e); }, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 672224f6..3f054d7d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -11,6 +11,7 @@ use argon2::{password_hash::SaltString, PasswordHasher}; use rand::prelude::*; use ring::digest; use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId}; +use tracing::debug; use crate::{services, Error, Result}; @@ -30,8 +31,11 @@ pub(crate) fn increment(old: Option<&[u8]>) -> Option> { Some(number.to_be_bytes().to_vec()) } +/// Generates a new homeserver signing key. 
First 8 bytes are the version (a +/// random alphanumeric string), the rest are generated by Ed25519KeyPair pub fn generate_keypair() -> Vec { let mut value = random_string(8).as_bytes().to_vec(); + debug!("Keypair version bytes: {value:?}"); value.push(0xFF); value.extend_from_slice( &ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"), ); @@ -58,6 +62,7 @@ pub fn user_id_from_bytes(bytes: &[u8]) -> Result { .map_err(|_| Error::bad_database("Failed to parse user id from bytes")) } +/// Generates a random *alphanumeric* string pub fn random_string(length: usize) -> String { thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(length).map(char::from).collect() } From 2fad03597a57d35ad40baec0aab0bc5cdd1c38cc Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 17 Mar 2024 17:38:59 -0400 Subject: [PATCH 20/20] a Signed-off-by: strawberry --- src/clap.rs | 39 ++++++++++++++++++- src/database/key_value/globals.rs | 3 +- src/main.rs | 64 +++++++++++++++++++++++++++---- 3 files changed, 96 insertions(+), 10 deletions(-) diff --git a/src/clap.rs b/src/clap.rs index 446398c7..1057bb96 100644 --- a/src/clap.rs +++ b/src/clap.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; -use clap::Parser; +use clap::{Parser, Subcommand}; /// Commandline arguments #[derive(Parser, Debug)] @@ -11,6 +11,43 @@ pub struct Args { #[arg(short, long)] /// Optional argument to the path of a conduwuit config TOML file pub config: Option, + + #[clap(subcommand)] + /// Optional subcommand to export the homeserver signing key and exit + pub signing_key: Option, +} + +#[derive(Debug, Subcommand)] +pub enum SigningKey { + /// Filesystem path to export the homeserver signing key to. + /// The output will be: `ed25519 ` which + /// is Synapse's format + ExportPath { + path: PathBuf, + }, + + /// Filesystem path for conduwuit to attempt to read and import the + /// homeserver signing key.
The expected format is Synapse's format: + /// `ed25519 ` + ImportPath { + path: PathBuf, + + #[arg(long)] + /// Optional argument to import the key but don't overwrite our signing + /// key, and instead add it to `old_verify_keys`. This field tells other + /// servers that this is our old public key that can still be used to + /// sign old events. + /// + /// See https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2server for more details. + add_to_old_public_keys: bool, + + #[arg(long)] + /// Timestamp (`expired_ts`) in seconds since UNIX epoch that the old + /// homeserver signing key stopped being used. + /// + /// See https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2server for more details. + timestamp: u64, + }, } /// Parse commandline arguments into structured data diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 63d0e2b6..058a1a04 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -211,9 +211,10 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" .map(|key| (version, key)) }) .and_then(|(version, key)| { + debug!("Keypair bytes: {:?}", key); let keypair = Ed25519KeyPair::from_der(key, version) .map_err(|_| Error::bad_database("Private or public keys are invalid.")); - debug!("Private and public key bytes: {keypair:?}"); + debug!("Private and public key: {keypair:?}"); keypair }) } diff --git a/src/main.rs b/src/main.rs index 03c9bdf6..31e4dd20 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,8 +15,12 @@ use axum::{ use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; #[cfg(feature = "axum_dual_protocol")] use axum_server_dual_protocol::ServerExt; -use conduit::api::{client_server, server_server}; +use base64::{engine::general_purpose, Engine as _}; pub use conduit::*; // Re-export everything from the library crate +use conduit::{ + api::{client_server, server_server}, + clap::{Args, SigningKey}, +}; use 
either::Either::{Left, Right}; use figment::{ providers::{Env, Format, Toml}, @@ -28,12 +32,15 @@ use http::{ }; #[cfg(unix)] use hyperlocal::SocketIncoming; -use ruma::api::{ - client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, +use ruma::{ + api::{ + client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + IncomingRequest, }, - IncomingRequest, + serde::Base64, }; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] use tikv_jemallocator::Jemalloc; @@ -73,7 +80,7 @@ async fn main() { } else if args.config.is_some() { Figment::new() .merge( - Toml::file(args.config.expect( + Toml::file(args.config.as_ref().expect( "conduwuit config commandline argument was specified, but appears to be invalid. This should be \ set to the path of a valid TOML file.", )) @@ -169,8 +176,16 @@ async fn main() { let config = &services().globals.config; - /* ad-hoc config validation/checks */ + /* homeserver signing keypair subcommand stuff */ + if let Some(subcommands) = &args.signing_key { + if signing_key_operations(subcommands).await.is_ok() { + return; + } + } + debug!("Ed25519KeyPair: {:?}", services().globals.keypair()); + + /* ad-hoc config validation/checks */ if config.unix_socket_path.is_some() && !cfg!(unix) { error!( "UNIX socket support is only available on *nix platforms. 
Please remove \"unix_socket_path\" from your \ @@ -912,3 +927,36 @@ fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { Ok(()) } + +/// Homeserver signing key commands/operations +async fn signing_key_operations(subcommands: &SigningKey) -> Result<()> { + match subcommands { + SigningKey::ExportPath { + path, + } => { + let mut file = tokio::fs::File::create(path).await?; + let mut content = String::new(); + + content.push_str("ed25519 "); + + let version = services().globals.keypair().version(); + + content.push_str(version); + content.push(' '); + + let keypair = services().globals.keypair(); + debug!("Ed25519KeyPair: {:?}", keypair); + + //let key_base64 = Base64::new(key); + + Ok(()) + }, + SigningKey::ImportPath { + path, + add_to_old_public_keys, + timestamp, + } => { + unimplemented!() + }, + } +}