From 5a1c41e66b4fec8ab76fd268fc9c9e282fd19428 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 11 Jan 2025 18:43:54 -0500 Subject: [PATCH] knocking implementation Signed-off-by: strawberry add sync bit of knocking Signed-off-by: strawberry --- src/api/client/membership.rs | 716 +++++++++++++++++++++++++-- src/api/client/sync/v3.rs | 39 +- src/api/client/sync/v4.rs | 9 + src/api/router.rs | 3 + src/api/server/invite.rs | 14 +- src/api/server/make_knock.rs | 38 +- src/api/server/make_leave.rs | 6 +- src/api/server/mod.rs | 4 + src/api/server/send_join.rs | 13 +- src/api/server/send_knock.rs | 75 ++- src/api/server/utils.rs | 17 +- src/database/maps.rs | 8 + src/service/rooms/state_cache/mod.rs | 142 +++++- src/service/rooms/timeline/mod.rs | 11 +- 14 files changed, 978 insertions(+), 117 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 0c493a37..d94fc3c7 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Borrow, collections::{BTreeMap, HashMap, HashSet}, net::IpAddr, sync::Arc, @@ -8,7 +9,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduwuit::{ debug, debug_info, debug_warn, err, info, - pdu::{self, gen_event_id_canonical_json, PduBuilder}, + pdu::{gen_event_id_canonical_json, PduBuilder}, result::FlatOk, trace, utils::{self, shuffle, IterStream, ReadyExt}, @@ -19,6 +20,7 @@ use ruma::{ api::{ client::{ error::ErrorKind, + knock::knock_room, membership::{ ban_user, forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, @@ -37,11 +39,12 @@ use ruma::{ }, StateEventType, }, - state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName, - OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, + state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, + OwnedServerName, OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; use service::{ appservice::RegistrationInfo, + pdu::gen_event_id, rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent}, Services, }; @@ -348,6 +351,116 @@ pub(crate) async fn join_room_by_id_or_alias_route( Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id }) } +/// # `POST /_matrix/client/*/knock/{roomIdOrAlias}` +/// +/// Tries to knock the room to ask permission to join for the sender user. +#[tracing::instrument(skip_all, fields(%client), name = "knock")] +pub(crate) async fn knock_room_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let body = body.body; + + let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { + | Ok(room_id) => { + banned_room_check( + &services, + sender_user, + Some(&room_id), + room_id.server_name(), + client, + ) + .await?; + + let mut servers = body.via.clone(); + servers.extend( + services + .rooms + .state_cache + .servers_invite_via(&room_id) + .map(ToOwned::to_owned) + .collect::>() + .await, + ); + + servers.extend( + services + .rooms + .state_cache + .invite_state(sender_user, &room_id) + .await + .unwrap_or_default() + .iter() + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + + if let Some(server) = room_id.server_name() { + servers.push(server.to_owned()); + } + + servers.sort_unstable(); + servers.dedup(); + shuffle(&mut servers); + + (servers, room_id) + }, + | Err(room_alias) => { + let (room_id, mut servers) = services + .rooms + .alias + .resolve_alias(&room_alias, Some(body.via.clone())) + .await?; + + banned_room_check( + &services, + sender_user, + Some(&room_id), + Some(room_alias.server_name()), + client, + ) + .await?; + + let addl_via_servers = services + .rooms + .state_cache + .servers_invite_via(&room_id) + .map(ToOwned::to_owned); + + let addl_state_servers = services + .rooms + .state_cache + .invite_state(sender_user, &room_id) + .await + .unwrap_or_default(); + + let mut addl_servers: Vec<_> = addl_state_servers + .iter() + .map(|event| event.get_field("sender")) + .filter_map(FlatOk::flat_ok) + .map(|user: &UserId| user.server_name().to_owned()) + .stream() + .chain(addl_via_servers) + .collect() + .await; + + addl_servers.sort_unstable(); + addl_servers.dedup(); + shuffle(&mut addl_servers); + servers.append(&mut addl_servers); + + (servers, room_id) + }, + }; + + knock_room_by_id_helper(&services, sender_user, &room_id, body.reason.clone(), &servers) + .boxed() + .await +} + /// # `POST /_matrix/client/v3/rooms/{roomId}/leave` /// /// Tries to leave the sender user from a room. @@ -403,6 +516,17 @@ pub(crate) async fn invite_user_route( ))); } + if let Ok(target_user_membership) = services + .rooms + .state_accessor + .get_member(&body.room_id, user_id) + .await + { + if target_user_membership.membership == MembershipState::Ban { + return Err!(Request(Forbidden("User is banned from this room."))); + } + } + if recipient_ignored_by_sender { // silently drop the invite to the recipient if they've been ignored by the // sender, pretend it worked @@ -862,7 +986,7 @@ async fn join_room_by_id_helper_remote( .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; // Generate event id - let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?; + let event_id = gen_event_id(&join_event_stub, &room_version_id)?; // Add event_id back join_event_stub @@ -1030,7 +1154,7 @@ async fn join_room_by_id_helper_remote( }; let auth_check = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), + &state_res::RoomVersion::new(&room_version_id)?, &parsed_join_pdu, None, // TODO: third party invite |k, s| state_fetch(k, s.to_owned()), @@ -1043,10 +1167,10 @@ async fn join_room_by_id_helper_remote( } info!("Compressing state from send_join"); - let compressed = state - .iter() - .stream() - .then(|(&k, id)| services.rooms.state_compressor.compress_state_event(k, id)) + let compressed: HashSet<_> = services + .rooms + .state_compressor + .compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) .collect() .await; @@ -1282,7 +1406,7 @@ async fn join_room_by_id_helper_local( .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; // Generate event id - let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?; + let event_id = gen_event_id(&join_event_stub, &room_version_id)?; // Add event_id back join_event_stub @@ -1392,6 +1516,7 @@ async fn make_join_request( ); make_join_response_and_server = Err!(BadServerResponse("No server available to assist in joining.")); + return make_join_response_and_server; } } @@ -1569,7 +1694,7 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { for room_id in all_rooms { // ignore errors if let Err(e) = leave_room(services, user_id, &room_id, None).await { - warn!(%room_id, %user_id, %e, "Failed to leave room"); + warn!(%user_id, "Failed to leave {room_id} remotely: {e}"); } services.rooms.state_cache.forget(&room_id, user_id); @@ -1585,11 +1710,15 @@ pub async fn leave_room( //use conduwuit::utils::stream::OptionStream; use futures::TryFutureExt; - // Ask a remote server if we don't have this room + // Ask a remote server if we don't have this room and are not knocking on it if !services .rooms .state_cache .server_in_room(services.globals.server_name(), room_id) + .await && !services + .rooms + .state_cache + .is_knocked(user_id, room_id) .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { @@ -1601,7 +1730,8 @@ pub async fn leave_room( .rooms .state_cache .invite_state(user_id, room_id) - .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id)) + .or_else(|_| services.rooms.state_cache.knock_state(user_id, room_id)) + .or_else(|_| services.rooms.state_cache.left_state(user_id, room_id)) .await .ok(); @@ -1683,13 +1813,6 @@ async fn remote_leave_room( let mut make_leave_response_and_server = Err!(BadServerResponse("No server available to assist in leaving.")); - let invite_state = services - .rooms - .state_cache - .invite_state(user_id, room_id) - .await - .map_err(|_| err!(Request(BadState("User is not invited."))))?; - let mut servers: HashSet = services .rooms .state_cache @@ -1698,13 +1821,39 @@ async fn remote_leave_room( .collect() .await; - servers.extend( - invite_state - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + if let Ok(invite_state) = services + .rooms + .state_cache + .invite_state(user_id, room_id) + .await + { + servers.extend( + invite_state + .iter() + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + } else if let Ok(knock_state) = services + .rooms + .state_cache + .knock_state(user_id, room_id) + .await + { + servers.extend( + knock_state + .iter() + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) + .filter_map(|sender| { + if !services.globals.user_is_local(sender) { + Some(sender.server_name().to_owned()) + } else { + None + } + }), + ); + } if let Some(room_id_server_name) = room_id.server_name() { servers.insert(room_id_server_name.to_owned()); @@ -1779,7 +1928,7 @@ async fn remote_leave_room( .hash_and_sign_event(&mut leave_event_stub, &room_version_id)?; // Generate event id - let event_id = pdu::gen_event_id(&leave_event_stub, &room_version_id)?; + let event_id = gen_event_id(&leave_event_stub, &room_version_id)?; // Add event_id back leave_event_stub @@ -1805,3 +1954,514 @@ async fn remote_leave_room( Ok(()) } + +async fn knock_room_by_id_helper( + services: &Services, + sender_user: &UserId, + room_id: &RoomId, + reason: Option, + servers: &[OwnedServerName], +) -> Result { + let state_lock = services.rooms.state.mutex.lock(room_id).await; + + if services + .rooms + .state_cache + .is_invited(sender_user, room_id) + .await + { + debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock"); + return Err!(Request(Forbidden( + "You cannot knock on a room you are already invited/accepted to." + ))); + } + + if services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { + debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock"); + return Err!(Request(Forbidden("You cannot knock on a room you are already joined in."))); + } + + if services + .rooms + .state_cache + .is_knocked(sender_user, room_id) + .await + { + debug_warn!("{sender_user} is already knocked in {room_id}"); + return Ok(knock_room::v3::Response { room_id: room_id.into() }); + } + + if let Ok(membership) = services + .rooms + .state_accessor + .get_member(room_id, sender_user) + .await + { + if membership.membership == MembershipState::Ban { + debug_warn!("{sender_user} is banned from {room_id} but attempted to knock"); + return Err!(Request(Forbidden("You cannot knock on a room you are banned from."))); + } + } + + let server_in_room = services + .rooms + .state_cache + .server_in_room(services.globals.server_name(), room_id) + .await; + + let local_knock = server_in_room + || servers.is_empty() + || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); + + if local_knock { + knock_room_helper_local(services, sender_user, room_id, reason, servers, state_lock) + .boxed() + .await?; + } else { + knock_room_helper_remote(services, sender_user, room_id, reason, servers, state_lock) + .boxed() + .await?; + } + + Ok(knock_room::v3::Response::new(room_id.to_owned())) +} + +async fn knock_room_helper_local( + services: &Services, + sender_user: &UserId, + room_id: &RoomId, + reason: Option, + servers: &[OwnedServerName], + state_lock: RoomMutexGuard, +) -> Result { + debug_info!("We can knock locally"); + + let room_version_id = services.rooms.state.get_room_version(room_id).await?; + + if matches!( + room_version_id, + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + ) { + return Err!(Request(Forbidden("This room does not support knocking."))); + } + + let content = RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), + reason: reason.clone(), + ..RoomMemberEventContent::new(MembershipState::Knock) + }; + + // Try normal knock first + let Err(error) = services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder::state(sender_user.to_string(), &content), + sender_user, + room_id, + &state_lock, + ) + .await + else { + return Ok(()); + }; + + if servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) + { + return Err(error); + } + + warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock"); + + let (make_knock_response, remote_server) = + make_knock_request(services, sender_user, room_id, servers).await?; + + info!("make_knock finished"); + + let room_version_id = make_knock_response.room_version; + + if !services.server.supported_room_version(&room_version_id) { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + + let mut knock_event_stub = serde_json::from_str::( + make_knock_response.event.get(), + ) + .map_err(|e| { + err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}")) + })?; + + knock_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), + ); + knock_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + knock_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), + reason, + ..RoomMemberEventContent::new(MembershipState::Knock) + }) + .expect("event is valid, we just created it"), + ); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + services + .server_keys + .hash_and_sign_event(&mut knock_event_stub, &room_version_id)?; + + // Generate event id + let event_id = gen_event_id(&knock_event_stub, &room_version_id)?; + + // Add event_id + knock_event_stub + .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); + + // It has enough fields to be called a proper event now + let knock_event = knock_event_stub; + + info!("Asking {remote_server} for send_knock in room {room_id}"); + let send_knock_request = federation::knock::send_knock::v1::Request { + room_id: room_id.to_owned(), + event_id: event_id.clone(), + pdu: services + .sending + .convert_to_outgoing_federation_event(knock_event.clone()) + .await, + }; + + let send_knock_response = services + .sending + .send_federation_request(&remote_server, send_knock_request) + .await?; + + info!("send_knock finished"); + + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; + + info!("Parsing knock event"); + + let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) + .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; + + info!("Updating membership locally to knock state with provided stripped state events"); + services + .rooms + .state_cache + .update_membership( + room_id, + sender_user, + parsed_knock_pdu + .get_content::() + .expect("we just created this"), + sender_user, + Some(send_knock_response.knock_room_state), + None, + false, + ) + .await?; + + info!("Appending room knock event locally"); + services + .rooms + .timeline + .append_pdu( + &parsed_knock_pdu, + knock_event, + vec![(*parsed_knock_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + + Ok(()) +} + +async fn knock_room_helper_remote( + services: &Services, + sender_user: &UserId, + room_id: &RoomId, + reason: Option, + servers: &[OwnedServerName], + state_lock: RoomMutexGuard, +) -> Result { + info!("Knocking {room_id} over federation."); + + let (make_knock_response, remote_server) = + make_knock_request(services, sender_user, room_id, servers).await?; + + info!("make_knock finished"); + + let room_version_id = make_knock_response.room_version; + + if !services.server.supported_room_version(&room_version_id) { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + + let mut knock_event_stub: CanonicalJsonObject = + serde_json::from_str(make_knock_response.event.get()).map_err(|e| { + err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}")) + })?; + + knock_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), + ); + knock_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + knock_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), + reason, + ..RoomMemberEventContent::new(MembershipState::Knock) + }) + .expect("event is valid, we just created it"), + ); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + services + .server_keys + .hash_and_sign_event(&mut knock_event_stub, &room_version_id)?; + + // Generate event id + let event_id = gen_event_id(&knock_event_stub, &room_version_id)?; + + // Add event_id + knock_event_stub + .insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); + + // It has enough fields to be called a proper event now + let knock_event = knock_event_stub; + + info!("Asking {remote_server} for send_knock in room {room_id}"); + let send_knock_request = federation::knock::send_knock::v1::Request { + room_id: room_id.to_owned(), + event_id: event_id.clone(), + pdu: services + .sending + .convert_to_outgoing_federation_event(knock_event.clone()) + .await, + }; + + let send_knock_response = services + .sending + .send_federation_request(&remote_server, send_knock_request) + .await?; + + info!("send_knock finished"); + + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; + + info!("Parsing knock event"); + let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) + .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; + + info!("Going through send_knock response knock state events"); + let state = send_knock_response + .knock_room_state + .iter() + .map(|event| serde_json::from_str::(event.clone().into_json().get())) + .filter_map(Result::ok); + + let mut state_map: HashMap = HashMap::new(); + + for event in state { + let Some(state_key) = event.get("state_key") else { + debug_warn!("send_knock stripped state event missing state_key: {event:?}"); + continue; + }; + let Some(event_type) = event.get("type") else { + debug_warn!("send_knock stripped state event missing event type: {event:?}"); + continue; + }; + + let Ok(state_key) = serde_json::from_value::(state_key.clone().into()) else { + debug_warn!("send_knock stripped state event has invalid state_key: {event:?}"); + continue; + }; + let Ok(event_type) = serde_json::from_value::(event_type.clone().into()) + else { + debug_warn!("send_knock stripped state event has invalid event type: {event:?}"); + continue; + }; + + let event_id = gen_event_id(&event, &room_version_id)?; + let shortstatekey = services + .rooms + .short + .get_or_create_shortstatekey(&event_type, &state_key) + .await; + + services.rooms.outlier.add_pdu_outlier(&event_id, &event); + state_map.insert(shortstatekey, event_id.clone()); + } + + info!("Compressing state from send_knock"); + let compressed: HashSet<_> = services + .rooms + .state_compressor + .compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) + .collect() + .await; + + debug!("Saving compressed state"); + let HashSetCompressStateEvent { + shortstatehash: statehash_before_knock, + added, + removed, + } = services + .rooms + .state_compressor + .save_state(room_id, Arc::new(compressed)) + .await?; + + debug!("Forcing state for new room"); + services + .rooms + .state + .force_state(room_id, statehash_before_knock, added, removed, &state_lock) + .await?; + + let statehash_after_knock = services + .rooms + .state + .append_to_state(&parsed_knock_pdu) + .await?; + + info!("Updating membership locally to knock state with provided stripped state events"); + services + .rooms + .state_cache + .update_membership( + room_id, + sender_user, + parsed_knock_pdu + .get_content::() + .expect("we just created this"), + sender_user, + Some(send_knock_response.knock_room_state), + None, + false, + ) + .await?; + + info!("Appending room knock event locally"); + services + .rooms + .timeline + .append_pdu( + &parsed_knock_pdu, + knock_event, + vec![(*parsed_knock_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + + info!("Setting final room state for new room"); + // We set the room state after inserting the pdu, so that we never have a moment + // in time where events in the current room state do not exist + services + .rooms + .state + .set_room_state(room_id, statehash_after_knock, &state_lock); + + Ok(()) +} + +async fn make_knock_request( + services: &Services, + sender_user: &UserId, + room_id: &RoomId, + servers: &[OwnedServerName], +) -> Result<(federation::knock::create_knock_event_template::v1::Response, OwnedServerName)> { + let mut make_knock_response_and_server = + Err!(BadServerResponse("No server available to assist in knocking.")); + + let mut make_knock_counter: usize = 0; + + for remote_server in servers { + if services.globals.server_is_ours(remote_server) { + continue; + } + + info!("Asking {remote_server} for make_knock ({make_knock_counter})"); + + let make_knock_response = services + .sending + .send_federation_request( + remote_server, + federation::knock::create_knock_event_template::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services.server.supported_room_versions().collect(), + }, + ) + .await; + + trace!("make_knock response: {make_knock_response:?}"); + make_knock_counter = make_knock_counter.saturating_add(1); + + make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone())); + + if make_knock_response_and_server.is_ok() { + break; + } + + if make_knock_counter > 40 { + warn!( + "50 servers failed to provide valid make_knock response, assuming no server can \ + assist in knocking." + ); + make_knock_response_and_server = + Err!(BadServerResponse("No server available to assist in knocking.")); + + return make_knock_response_and_server; + } + } + + make_knock_response_and_server +} diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 95c8c2d4..a4dc0205 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -33,8 +33,8 @@ use ruma::{ self, v3::{ Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, - LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State as RoomState, - Timeline, ToDevice, + KnockState, KnockedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, + State as RoomState, Timeline, ToDevice, }, DeviceLists, UnreadNotificationsCount, }, @@ -266,6 +266,35 @@ pub(crate) async fn build_sync_events( invited_rooms }); + let knocked_rooms = services + .rooms + .state_cache + .rooms_knocked(sender_user) + .fold_default(|mut knocked_rooms: BTreeMap<_, _>, (room_id, knock_state)| async move { + // Get and drop the lock to wait for remaining operations to finish + let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; + drop(insert_lock); + + let knock_count = services + .rooms + .state_cache + .get_knock_count(&room_id, sender_user) + .await + .ok(); + + // Knocked before last sync + if Some(since) >= knock_count { + return knocked_rooms; + } + + let knocked_room = KnockedRoom { + knock_state: KnockState { events: knock_state }, + }; + + knocked_rooms.insert(room_id, knocked_room); + knocked_rooms + }); + let presence_updates: OptionFuture<_> = services .globals .allow_local_presence() @@ -300,7 +329,7 @@ pub(crate) async fn build_sync_events( .users .remove_to_device_events(sender_user, sender_device, since); - let rooms = join3(joined_rooms, left_rooms, invited_rooms); + let rooms = join4(joined_rooms, left_rooms, invited_rooms, knocked_rooms); let ephemeral = join3(remove_to_device_events, to_device_events, presence_updates); let top = join5(account_data, ephemeral, device_one_time_keys_count, keys_changed, rooms) .boxed() @@ -308,7 +337,7 @@ pub(crate) async fn build_sync_events( let (account_data, ephemeral, device_one_time_keys_count, keys_changed, rooms) = top; let ((), to_device_events, presence_updates) = ephemeral; - let (joined_rooms, left_rooms, invited_rooms) = rooms; + let (joined_rooms, left_rooms, invited_rooms, knocked_rooms) = rooms; let (joined_rooms, mut device_list_updates, left_encrypted_users) = joined_rooms; device_list_updates.extend(keys_changed); @@ -349,7 +378,7 @@ pub(crate) async fn build_sync_events( leave: left_rooms, join: joined_rooms, invite: invited_rooms, - knock: BTreeMap::new(), // TODO + knock: knocked_rooms, }, to_device: ToDevice { events: to_device_events }, }; diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 9915752e..24c7e286 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -113,9 +113,18 @@ pub(crate) async fn sync_events_v4_route( .collect() .await; + let all_knocked_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_knocked(sender_user) + .map(|r| r.0) + .collect() + .await; + let all_rooms = all_joined_rooms .iter() .chain(all_invited_rooms.iter()) + .chain(all_knocked_rooms.iter()) .map(Clone::clone) .collect(); diff --git a/src/api/router.rs b/src/api/router.rs index 1b38670d..1d42fc5e 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -99,6 +99,7 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::join_room_by_id_route) .ruma_route(&client::join_room_by_id_or_alias_route) .ruma_route(&client::joined_members_route) + .ruma_route(&client::knock_room_route) .ruma_route(&client::leave_room_route) .ruma_route(&client::forget_room_route) .ruma_route(&client::joined_rooms_route) @@ -204,8 +205,10 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&server::get_room_state_route) .ruma_route(&server::get_room_state_ids_route) .ruma_route(&server::create_leave_event_template_route) + .ruma_route(&server::create_knock_event_template_route) .ruma_route(&server::create_leave_event_v1_route) .ruma_route(&server::create_leave_event_v2_route) + .ruma_route(&server::create_knock_event_v1_route) .ruma_route(&server::create_join_event_template_route) .ruma_route(&server::create_join_event_v1_route) .ruma_route(&server::create_join_event_v2_route) diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 6d3be04c..1fea268b 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -6,8 +6,9 @@ use ruma::{ api::{client::error::ErrorKind, federation::membership::create_invite}, events::room::member::{MembershipState, RoomMemberEventContent}, serde::JsonObject, - CanonicalJsonValue, OwnedEventId, OwnedUserId, UserId, + CanonicalJsonValue, OwnedUserId, UserId, }; +use service::pdu::gen_event_id; use crate::Ruma; @@ -86,12 +87,7 @@ pub(crate) async fn create_invite_route( .map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?; // Generate event id - let event_id = OwnedEventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&signed_event, &body.room_version) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); + let event_id = gen_event_id(&signed_event, &body.room_version)?; // Add event_id back signed_event.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.to_string())); @@ -115,12 +111,12 @@ pub(crate) async fn create_invite_route( let mut invite_state = body.invite_room_state.clone(); let mut event: JsonObject = serde_json::from_str(body.event.get()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; + .map_err(|e| err!(Request(BadJson("Invalid invite event PDU: {e}"))))?; event.insert("event_id".to_owned(), "$placeholder".into()); let pdu: PduEvent = serde_json::from_value(event.into()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event."))?; + .map_err(|e| err!(Request(BadJson("Invalid invite event PDU: {e}"))))?; invite_state.push(pdu.to_stripped_state_event()); diff --git a/src/api/server/make_knock.rs b/src/api/server/make_knock.rs index 6d9d6d55..90b9b629 100644 --- a/src/api/server/make_knock.rs +++ b/src/api/server/make_knock.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use conduwuit::Err; +use conduwuit::{debug_warn, Err}; use ruma::{ api::{client::error::ErrorKind, federation::knock::create_knock_event_template}, events::room::member::{MembershipState, RoomMemberEventContent}, @@ -15,7 +15,8 @@ use crate::{service::pdu::PduBuilder, Error, Result, Ruma}; /// /// Creates a knock template. pub(crate) async fn create_knock_event_template_route( - State(services): State, body: Ruma, + State(services): State, + body: Ruma, ) -> Result { if !services.rooms.metadata.exists(&body.room_id).await { return Err!(Request(NotFound("Room is unknown to this server."))); @@ -39,8 +40,8 @@ pub(crate) async fn create_knock_event_template_route( .contains(body.origin()) { warn!( - "Server {} for remote user {} tried knocking room ID {} which has a server name that is globally \ - forbidden. Rejecting.", + "Server {} for remote user {} tried knocking room ID {} which has a server name \ + that is globally forbidden. Rejecting.", body.origin(), &body.user_id, &body.room_id, @@ -63,29 +64,44 @@ pub(crate) async fn create_knock_event_template_route( if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) { return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: room_version_id, - }, + ErrorKind::IncompatibleRoomVersion { room_version: room_version_id }, "Room version does not support knocking.", )); } if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: room_version_id, - }, + ErrorKind::IncompatibleRoomVersion { room_version: room_version_id }, "Your homeserver does not support the features required to knock on this room.", )); } let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + if let Ok(membership) = services + .rooms + .state_accessor + .get_member(&body.room_id, &body.user_id) + .await + { + if membership.membership == MembershipState::Ban { + debug_warn!( + "Remote user {} is banned from {} but attempted to knock", + &body.user_id, + &body.room_id + ); + return Err!(Request(Forbidden("You cannot knock on a room you are banned from."))); + } + } + let (_pdu, mut pdu_json) = services .rooms .timeline .create_hash_and_sign_event( - PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Knock)), + PduBuilder::state( + body.user_id.to_string(), + &RoomMemberEventContent::new(MembershipState::Knock), + ), &body.user_id, &body.room_id, &state_lock, diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 746a4858..936e0fbb 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -9,7 +9,7 @@ use serde_json::value::to_raw_value; use super::make_join::maybe_strip_event_id; use crate::{service::pdu::PduBuilder, Ruma}; -/// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}` +/// # `GET /_matrix/federation/v1/make_leave/{roomId}/{eventId}` /// /// Creates a leave template. pub(crate) async fn create_leave_event_template_route( @@ -21,7 +21,9 @@ pub(crate) async fn create_leave_event_template_route( } if body.user_id.server_name() != body.origin() { - return Err!(Request(BadJson("Not allowed to leave on behalf of another server/user."))); + return Err!(Request(Forbidden( + "Not allowed to leave on behalf of another server/user." + ))); } // ACL check origin diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs index 9b7d91cb..5c1ff3f7 100644 --- a/src/api/server/mod.rs +++ b/src/api/server/mod.rs @@ -6,6 +6,7 @@ pub(super) mod hierarchy; pub(super) mod invite; pub(super) mod key; pub(super) mod make_join; +pub(super) mod make_knock; pub(super) mod make_leave; pub(super) mod media; pub(super) mod openid; @@ -13,6 +14,7 @@ pub(super) mod publicrooms; pub(super) mod query; pub(super) mod send; pub(super) mod send_join; +pub(super) mod send_knock; pub(super) mod send_leave; pub(super) mod state; pub(super) mod state_ids; @@ -28,6 +30,7 @@ pub(super) use hierarchy::*; pub(super) use invite::*; pub(super) use key::*; pub(super) use make_join::*; +pub(super) use make_knock::*; pub(super) use make_leave::*; pub(super) use media::*; pub(super) use openid::*; @@ -35,6 +38,7 @@ pub(super) use publicrooms::*; pub(super) use query::*; pub(super) use send::*; pub(super) use send_join::*; +pub(super) use send_knock::*; pub(super) use send_leave::*; pub(super) use state::*; pub(super) use state_ids::*; diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 6cbe5143..97a65bf8 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -186,14 +186,13 @@ async fn create_join_event( .map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?; let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value( - value - .get("origin") - .ok_or_else(|| err!(Request(BadJson("Event missing origin property."))))?, - ) - .expect("CanonicalJson is valid json value"), + value + .get("origin") + .ok_or_else(|| err!(Request(BadJson("Event does not have an origin server name."))))? + .clone() + .into(), ) - .map_err(|e| err!(Request(BadJson(warn!("origin field is not a valid server name: {e}")))))?; + .map_err(|e| err!(Request(BadJson("Event has an invalid origin server name: {e}"))))?; let mutex_lock = services .rooms diff --git a/src/api/server/send_knock.rs b/src/api/server/send_knock.rs index 49ec4bf8..95478081 100644 --- a/src/api/server/send_knock.rs +++ b/src/api/server/send_knock.rs @@ -1,7 +1,8 @@ use axum::extract::State; -use conduwuit::{err, pdu::gen_event_id_canonical_json, warn, Err, Error, PduEvent, Result}; +use conduwuit::{err, pdu::gen_event_id_canonical_json, warn, Err, PduEvent, Result}; +use futures::FutureExt; use ruma::{ - api::{client::error::ErrorKind, federation::knock::send_knock}, + api::federation::knock::send_knock, events::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, @@ -17,7 +18,8 @@ use crate::Ruma; /// /// Submits a signed knock event. pub(crate) async fn create_knock_event_v1_route( - State(services): State, body: Ruma, + State(services): State, + body: Ruma, ) -> Result { if services .globals @@ -26,7 +28,8 @@ pub(crate) async fn create_knock_event_v1_route( .contains(body.origin()) { warn!( - "Server {} tried knocking room ID {} who has a server name that is globally forbidden. Rejecting.", + "Server {} tried knocking room ID {} who has a server name that is globally \ + forbidden. Rejecting.", body.origin(), &body.room_id, ); @@ -41,7 +44,8 @@ pub(crate) async fn create_knock_event_v1_route( .contains(&server.to_owned()) { warn!( - "Server {} tried knocking room ID {} which has a server name that is globally forbidden. Rejecting.", + "Server {} tried knocking room ID {} which has a server name that is globally \ + forbidden. Rejecting.", body.origin(), &body.room_id, ); @@ -50,7 +54,7 @@ pub(crate) async fn create_knock_event_v1_route( } if !services.rooms.metadata.exists(&body.room_id).await { - return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + return Err!(Request(NotFound("Room is unknown to this server."))); } // ACL check origin server @@ -74,44 +78,42 @@ pub(crate) async fn create_knock_event_v1_route( let event_type: StateEventType = serde_json::from_value( value .get("type") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing type property."))? + .ok_or_else(|| err!(Request(InvalidParam("Event has no event type."))))? .clone() .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event has invalid event type."))?; + .map_err(|e| err!(Request(InvalidParam("Event has invalid event type: {e}"))))?; if event_type != StateEventType::RoomMember { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, + return Err!(Request(InvalidParam( "Not allowed to send non-membership state event to knock endpoint.", - )); + ))); } let content: RoomMemberEventContent = serde_json::from_value( value .get("content") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing content property"))? + .ok_or_else(|| err!(Request(InvalidParam("Membership event has no content"))))? .clone() .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event content is empty or invalid"))?; + .map_err(|e| err!(Request(InvalidParam("Event has invalid membership content: {e}"))))?; if content.membership != MembershipState::Knock { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Not allowed to send a non-knock membership event to knock endpoint.", - )); + return Err!(Request(InvalidParam( + "Not allowed to send a non-knock membership event to knock endpoint." + ))); } // ACL check sender server name let sender: OwnedUserId = serde_json::from_value( value .get("sender") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing sender property."))? + .ok_or_else(|| err!(Request(InvalidParam("Event has no sender user ID."))))? .clone() .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?; + .map_err(|e| err!(Request(BadJson("Event sender is not a valid user ID: {e}"))))?; services .rooms @@ -127,36 +129,32 @@ pub(crate) async fn create_knock_event_v1_route( let state_key: OwnedUserId = serde_json::from_value( value .get("state_key") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing state_key property."))? + .ok_or_else(|| err!(Request(InvalidParam("Event does not have a state_key"))))? .clone() .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "state_key is invalid or not a user ID."))?; + .map_err(|e| err!(Request(BadJson("Event does not have a valid state_key: {e}"))))?; if state_key != sender { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "State key does not match sender user", - )); + return Err!(Request(InvalidParam("state_key does not match sender user of event."))); }; let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value( - value - .get("origin") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing origin property."))?, - ) - .expect("CanonicalJson is valid json value"), + value + .get("origin") + .ok_or_else(|| err!(Request(BadJson("Event does not have an origin server name."))))? + .clone() + .into(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; + .map_err(|e| err!(Request(BadJson("Event has an invalid origin server name: {e}"))))?; let mut event: JsonObject = serde_json::from_str(body.pdu.get()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?; + .map_err(|e| err!(Request(InvalidParam("Invalid knock event PDU: {e}"))))?; event.insert("event_id".to_owned(), "$placeholder".into()); let pdu: PduEvent = serde_json::from_value(event.into()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?; + .map_err(|e| err!(Request(InvalidParam("Invalid knock event PDU: {e}"))))?; let mutex_lock = services .rooms @@ -169,19 +167,18 @@ pub(crate) async fn create_knock_event_v1_route( .rooms .event_handler .handle_incoming_pdu(&origin, &body.room_id, &event_id, value.clone(), true) + .boxed() .await? .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; drop(mutex_lock); - let knock_room_state = services.rooms.state.summary_stripped(&pdu).await; - services .sending .send_pdu_room(&body.room_id, &pdu_id) .await?; - Ok(send_knock::v1::Response { - knock_room_state, - }) + let knock_room_state = services.rooms.state.summary_stripped(&pdu).await; + + Ok(send_knock::v1::Response { knock_room_state }) } diff --git a/src/api/server/utils.rs b/src/api/server/utils.rs index 112cf858..4f3fa245 100644 --- a/src/api/server/utils.rs +++ b/src/api/server/utils.rs @@ -1,6 +1,6 @@ use conduwuit::{implement, is_false, Err, Result}; use conduwuit_service::Services; -use futures::{future::OptionFuture, join, FutureExt}; +use futures::{future::OptionFuture, join, FutureExt, StreamExt}; use ruma::{EventId, RoomId, ServerName}; pub(super) struct AccessCheck<'a> { @@ -31,6 +31,15 @@ pub(super) async fn check(&self) -> Result { .state_cache .server_in_room(self.origin, self.room_id); + // if any user on our homeserver is trying to knock this room, we'll need to + // acknowledge bans or leaves + let user_is_knocking = self + .services + .rooms + .state_cache + .room_members_knocked(self.room_id) + .count(); + let server_can_see: OptionFuture<_> = self .event_id .map(|event_id| { @@ -42,14 +51,14 @@ pub(super) async fn check(&self) -> Result { }) .into(); - let (world_readable, server_in_room, server_can_see, acl_check) = - join!(world_readable, server_in_room, server_can_see, acl_check); + let (world_readable, server_in_room, server_can_see, acl_check, user_is_knocking) = + join!(world_readable, server_in_room, server_can_see, acl_check, user_is_knocking); if !acl_check { return Err!(Request(Forbidden("Server access denied."))); } - if !world_readable && !server_in_room { + if !world_readable && !server_in_room && user_is_knocking == 0 { return Err!(Request(Forbidden("Server is not in room."))); } diff --git a/src/database/maps.rs b/src/database/maps.rs index e9b26818..bc409919 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -184,6 +184,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "roomuserid_leftcount", ..descriptor::RANDOM }, + Descriptor { + name: "roomuserid_knockedcount", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "roomuserid_privateread", ..descriptor::RANDOM_SMALL @@ -377,6 +381,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "userroomid_leftstate", ..descriptor::RANDOM }, + Descriptor { + name: "userroomid_knockedstate", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "userroomid_notificationcount", ..descriptor::RANDOM diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 89421dfd..0d25142d 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -10,7 +10,7 @@ use conduwuit::{ warn, Result, }; use database::{serialize_key, Deserialized, Ignore, Interfix, Json, Map}; -use futures::{future::join4, pin_mut, stream::iter, Stream, StreamExt}; +use futures::{future::join5, pin_mut, stream::iter, Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -51,11 +51,13 @@ struct Data { roomuserid_invitecount: Arc, roomuserid_joined: Arc, roomuserid_leftcount: Arc, + roomuserid_knockedcount: Arc, roomuseroncejoinedids: Arc, serverroomids: Arc, userroomid_invitestate: Arc, userroomid_joined: Arc, userroomid_leftstate: Arc, + userroomid_knockedstate: Arc, } type AppServiceInRoomCache = RwLock>>; @@ -81,11 +83,13 @@ impl crate::Service for Service { roomuserid_invitecount: args.db["roomuserid_invitecount"].clone(), roomuserid_joined: args.db["roomuserid_joined"].clone(), roomuserid_leftcount: args.db["roomuserid_leftcount"].clone(), + roomuserid_knockedcount: args.db["roomuserid_knockedcount"].clone(), roomuseroncejoinedids: args.db["roomuseroncejoinedids"].clone(), serverroomids: args.db["serverroomids"].clone(), userroomid_invitestate: args.db["userroomid_invitestate"].clone(), userroomid_joined: args.db["userroomid_joined"].clone(), userroomid_leftstate: args.db["userroomid_leftstate"].clone(), + userroomid_knockedstate: args.db["userroomid_knockedstate"].clone(), }, })) } @@ -336,6 +340,9 @@ impl Service { self.db.userroomid_leftstate.remove(&userroom_id); self.db.roomuserid_leftcount.remove(&roomuser_id); + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + self.db.roomid_inviteviaservers.remove(room_id); } @@ -352,12 +359,13 @@ impl Service { // (timo) TODO let leftstate = Vec::>::new(); - let count = self.services.globals.next_count().unwrap(); self.db .userroomid_leftstate .raw_put(&userroom_id, Json(leftstate)); - self.db.roomuserid_leftcount.raw_put(&roomuser_id, count); + self.db + .roomuserid_leftcount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); self.db.userroomid_joined.remove(&userroom_id); self.db.roomuserid_joined.remove(&roomuser_id); @@ -365,6 +373,44 @@ impl Service { self.db.userroomid_invitestate.remove(&userroom_id); self.db.roomuserid_invitecount.remove(&roomuser_id); + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } + + /// Direct DB function to directly mark a user as knocked. It is not + /// recommended to use this directly. You most likely should use + /// `update_membership` instead + #[tracing::instrument(skip(self), level = "debug")] + pub fn mark_as_knocked( + &self, + user_id: &UserId, + room_id: &RoomId, + knocked_state: Option>>, + ) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); + + self.db + .userroomid_knockedstate + .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default())); + self.db + .roomuserid_knockedcount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + self.db.roomid_inviteviaservers.remove(room_id); } @@ -528,6 +574,20 @@ impl Service { .map(|(_, user_id): (Ignore, &UserId)| user_id) } + /// Returns an iterator over all knocked members of a room. + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_members_knocked<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_knockedcount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) + } + #[tracing::instrument(skip(self), level = "trace")] pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { let key = (room_id, user_id); @@ -538,6 +598,16 @@ impl Service { .deserialized() } + #[tracing::instrument(skip(self), level = "trace")] + pub async fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_knockedcount + .qry(&key) + .await + .deserialized() + } + #[tracing::instrument(skip(self), level = "trace")] pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { let key = (room_id, user_id); @@ -576,6 +646,25 @@ impl Service { .ignore_err() } + /// Returns an iterator over all rooms a user is currently knocking. + #[tracing::instrument(skip(self), level = "trace")] + pub fn rooms_knocked<'a>( + &'a self, + user_id: &'a UserId, + ) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_knockedstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() + } + #[tracing::instrument(skip(self), level = "trace")] pub async fn invite_state( &self, @@ -593,6 +682,23 @@ impl Service { }) } + #[tracing::instrument(skip(self), level = "trace")] + pub async fn knock_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result>> { + let key = (user_id, room_id); + self.db + .userroomid_knockedstate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| { + val.deserialize_as().map_err(Into::into) + }) + } + #[tracing::instrument(skip(self), level = "trace")] pub async fn left_state( &self, @@ -641,6 +747,12 @@ impl Service { self.db.userroomid_joined.qry(&key).await.is_ok() } + #[tracing::instrument(skip(self), level = "trace")] + pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_knockedstate.qry(&key).await.is_ok() + } + #[tracing::instrument(skip(self), level = "trace")] pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { let key = (user_id, room_id); @@ -659,9 +771,10 @@ impl Service { user_id: &UserId, room_id: &RoomId, ) -> Option { - let states = join4( + let states = join5( self.is_joined(user_id, room_id), self.is_left(user_id, room_id), + self.is_knocked(user_id, room_id), self.is_invited(user_id, room_id), self.once_joined(user_id, room_id), ) @@ -670,8 +783,9 @@ impl Service { match states { | (true, ..) => Some(MembershipState::Join), | (_, true, ..) => Some(MembershipState::Leave), - | (_, _, true, ..) => Some(MembershipState::Invite), - | (false, false, false, true) => Some(MembershipState::Ban), + | (_, _, true, ..) => Some(MembershipState::Knock), + | (_, _, _, true, ..) => Some(MembershipState::Invite), + | (false, false, false, false, true) => Some(MembershipState::Ban), | _ => None, } } @@ -747,6 +861,7 @@ impl Service { pub async fn update_joined_count(&self, room_id: &RoomId) { let mut joinedcount = 0_u64; let mut invitedcount = 0_u64; + let mut knockedcount = 0_u64; let mut joined_servers = HashSet::new(); self.room_members(room_id) @@ -764,8 +879,19 @@ impl Service { .unwrap_or(0), ); + knockedcount = knockedcount.saturating_add( + self.room_members_knocked(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); + self.db.roomid_joinedcount.raw_put(room_id, joinedcount); self.db.roomid_invitedcount.raw_put(room_id, invitedcount); + self.db + .roomuserid_knockedcount + .raw_put(room_id, knockedcount); self.room_servers(room_id) .ready_for_each(|old_joined_server| { @@ -820,7 +946,6 @@ impl Service { self.db .userroomid_invitestate .raw_put(&userroom_id, Json(last_state.unwrap_or_default())); - self.db .roomuserid_invitecount .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); @@ -831,6 +956,9 @@ impl Service { self.db.userroomid_leftstate.remove(&userroom_id); self.db.roomuserid_leftcount.remove(&roomuser_id); + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + if let Some(servers) = invite_via.filter(is_not_empty!()) { self.add_servers_invite_via(room_id, servers).await; } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index fe7f885a..3ebc432f 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -498,14 +498,15 @@ impl Service { .expect("This state_key was previously validated"); let content: RoomMemberEventContent = pdu.get_content()?; - let invite_state = match content.membership { - | MembershipState::Invite => + let stripped_state = match content.membership { + | MembershipState::Invite | MembershipState::Knock => self.services.state.summary_stripped(pdu).await.into(), | _ => None, }; - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth + // Update our membership info, we do this here incase a user is invited or + // knocked and immediately leaves we need the DB to record the invite or + // knock event for auth self.services .state_cache .update_membership( @@ -513,7 +514,7 @@ impl Service { target_user_id, content, &pdu.sender, - invite_state, + stripped_state, None, true, )