knocking implementation

Signed-off-by: strawberry <strawberry@puppygock.gay>

add sync bit of knocking

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2025-01-11 18:43:54 -05:00
parent fabd3cf567
commit 5a1c41e66b
14 changed files with 978 additions and 117 deletions

View file

@ -10,7 +10,7 @@ use conduwuit::{
warn, Result,
};
use database::{serialize_key, Deserialized, Ignore, Interfix, Json, Map};
use futures::{future::join4, pin_mut, stream::iter, Stream, StreamExt};
use futures::{future::join5, pin_mut, stream::iter, Stream, StreamExt};
use itertools::Itertools;
use ruma::{
events::{
@ -51,11 +51,13 @@ struct Data {
roomuserid_invitecount: Arc<Map>,
roomuserid_joined: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomuserid_knockedcount: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
serverroomids: Arc<Map>,
userroomid_invitestate: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_leftstate: Arc<Map>,
userroomid_knockedstate: Arc<Map>,
}
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
@ -81,11 +83,13 @@ impl crate::Service for Service {
roomuserid_invitecount: args.db["roomuserid_invitecount"].clone(),
roomuserid_joined: args.db["roomuserid_joined"].clone(),
roomuserid_leftcount: args.db["roomuserid_leftcount"].clone(),
roomuserid_knockedcount: args.db["roomuserid_knockedcount"].clone(),
roomuseroncejoinedids: args.db["roomuseroncejoinedids"].clone(),
serverroomids: args.db["serverroomids"].clone(),
userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
userroomid_joined: args.db["userroomid_joined"].clone(),
userroomid_leftstate: args.db["userroomid_leftstate"].clone(),
userroomid_knockedstate: args.db["userroomid_knockedstate"].clone(),
},
}))
}
@ -336,6 +340,9 @@ impl Service {
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
@ -352,12 +359,13 @@ impl Service {
// (timo) TODO
let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
let count = self.services.globals.next_count().unwrap();
self.db
.userroomid_leftstate
.raw_put(&userroom_id, Json(leftstate));
self.db.roomuserid_leftcount.raw_put(&roomuser_id, count);
self.db
.roomuserid_leftcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
@ -365,6 +373,44 @@ impl Service {
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
/// Direct DB function to directly mark a user as knocked. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_knocked(
&self,
user_id: &UserId,
room_id: &RoomId,
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
self.db
.userroomid_knockedstate
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
self.db
.roomuserid_knockedcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
@ -528,6 +574,20 @@ impl Service {
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
/// Returns an iterator over all knocked members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members_knocked<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_knockedcount
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
@ -538,6 +598,16 @@ impl Service {
.deserialized()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db
.roomuserid_knockedcount
.qry(&key)
.await
.deserialized()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
@ -576,6 +646,25 @@ impl Service {
.ignore_err()
}
/// Returns an iterator over all rooms a user is currently knocking.
#[tracing::instrument(skip(self), level = "trace")]
pub fn rooms_knocked<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
let prefix = (user_id, Interfix);
self.db
.userroomid_knockedstate
.stream_prefix(&prefix)
.ignore_err()
.map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state))
.map(|(room_id, state)| Ok((room_id, state.deserialize_as()?)))
.ignore_err()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn invite_state(
&self,
@ -593,6 +682,23 @@ impl Service {
})
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn knock_state(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.db
.userroomid_knockedstate
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
val.deserialize_as().map_err(Into::into)
})
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn left_state(
&self,
@ -641,6 +747,12 @@ impl Service {
self.db.userroomid_joined.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_knockedstate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
@ -659,9 +771,10 @@ impl Service {
user_id: &UserId,
room_id: &RoomId,
) -> Option<MembershipState> {
let states = join4(
let states = join5(
self.is_joined(user_id, room_id),
self.is_left(user_id, room_id),
self.is_knocked(user_id, room_id),
self.is_invited(user_id, room_id),
self.once_joined(user_id, room_id),
)
@ -670,8 +783,9 @@ impl Service {
match states {
| (true, ..) => Some(MembershipState::Join),
| (_, true, ..) => Some(MembershipState::Leave),
| (_, _, true, ..) => Some(MembershipState::Invite),
| (false, false, false, true) => Some(MembershipState::Ban),
| (_, _, true, ..) => Some(MembershipState::Knock),
| (_, _, _, true, ..) => Some(MembershipState::Invite),
| (false, false, false, false, true) => Some(MembershipState::Ban),
| _ => None,
}
}
@ -747,6 +861,7 @@ impl Service {
pub async fn update_joined_count(&self, room_id: &RoomId) {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut knockedcount = 0_u64;
let mut joined_servers = HashSet::new();
self.room_members(room_id)
@ -764,8 +879,19 @@ impl Service {
.unwrap_or(0),
);
knockedcount = knockedcount.saturating_add(
self.room_members_knocked(room_id)
.count()
.await
.try_into()
.unwrap_or(0),
);
self.db.roomid_joinedcount.raw_put(room_id, joinedcount);
self.db.roomid_invitedcount.raw_put(room_id, invitedcount);
self.db
.roomuserid_knockedcount
.raw_put(room_id, knockedcount);
self.room_servers(room_id)
.ready_for_each(|old_joined_server| {
@ -820,7 +946,6 @@ impl Service {
self.db
.userroomid_invitestate
.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
self.db
.roomuserid_invitecount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
@ -831,6 +956,9 @@ impl Service {
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
if let Some(servers) = invite_via.filter(is_not_empty!()) {
self.add_servers_invite_via(room_id, servers).await;
}

View file

@ -498,14 +498,15 @@ impl Service {
.expect("This state_key was previously validated");
let content: RoomMemberEventContent = pdu.get_content()?;
let invite_state = match content.membership {
| MembershipState::Invite =>
let stripped_state = match content.membership {
| MembershipState::Invite | MembershipState::Knock =>
self.services.state.summary_stripped(pdu).await.into(),
| _ => None,
};
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
// Update our membership info, we do this here incase a user is invited or
// knocked and immediately leaves we need the DB to record the invite or
// knock event for auth
self.services
.state_cache
.update_membership(
@ -513,7 +514,7 @@ impl Service {
target_user_id,
content,
&pdu.sender,
invite_state,
stripped_state,
None,
true,
)