feat: incoming invites over federation

This commit is contained in:
Timo Kösters 2021-04-11 21:01:27 +02:00
parent b0ea692706
commit 8773e5013d
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
10 changed files with 307 additions and 146 deletions

View file

@ -11,10 +11,10 @@ use ruma::{
events::{
ignored_user_list,
room::{create::CreateEventContent, member, message},
EventType,
AnyStrippedStateEvent, EventType,
},
serde::{to_canonical_value, CanonicalJsonObject, CanonicalJsonValue, Raw},
EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use sled::IVec;
use state_res::{Event, StateMap};
@ -51,8 +51,8 @@ pub struct Rooms {
pub(super) userroomid_joined: sled::Tree,
pub(super) roomuserid_joined: sled::Tree,
pub(super) roomuseroncejoinedids: sled::Tree,
pub(super) userroomid_invited: sled::Tree,
pub(super) roomuserid_invited: sled::Tree,
pub(super) userroomid_invitestate: sled::Tree,
pub(super) roomuserid_invitecount: sled::Tree,
pub(super) userroomid_left: sled::Tree,
/// Remember the current state hash of a room.
@ -145,12 +145,12 @@ impl Rooms {
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn state_get(
pub fn state_get_id(
&self,
shortstatehash: u64,
event_type: &EventType,
state_key: &str,
) -> Result<Option<PduEvent>> {
) -> Result<Option<EventId>> {
let mut key = event_type.as_ref().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&state_key.as_bytes());
@ -161,7 +161,8 @@ impl Rooms {
let mut stateid = shortstatehash.to_be_bytes().to_vec();
stateid.extend_from_slice(&shortstatekey);
self.stateid_shorteventid
Ok(self
.stateid_shorteventid
.get(&stateid)?
.map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten())
.flatten()
@ -178,13 +179,24 @@ impl Rooms {
)
})
.map(|r| r.ok())
.flatten()
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
.flatten())
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn state_get(
&self,
shortstatehash: u64,
event_type: &EventType,
state_key: &str,
) -> Result<Option<PduEvent>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
#[tracing::instrument(skip(self))]
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
@ -354,6 +366,21 @@ impl Rooms {
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn room_state_get_id(
&self,
room_id: &RoomId,
event_type: &EventType,
state_key: &str,
) -> Result<Option<EventId>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn room_state_get(
@ -395,7 +422,7 @@ impl Rooms {
}
/// Returns the json of a pdu.
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<serde_json::Value>> {
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map_or_else::<Result<_>, _, _>(
@ -666,29 +693,64 @@ impl Rooms {
// if the state_key fails
let target_user_id = UserId::try_from(state_key.clone())
.expect("This state_key was previously validated");
let membership = serde_json::from_value::<member::MembershipState>(
pdu.content
.get("membership")
.ok_or_else(|| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid member event content",
)
})?
.clone(),
)
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid membership state content.",
)
})?;
let invite_state = match membership {
member::MembershipState::Invite => {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
self.room_state_get(&pdu.room_id, &EventType::RoomJoinRules, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.room_state_get(
&pdu.room_id,
&EventType::RoomCanonicalAlias,
"",
)? {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&pdu.room_id, &EventType::RoomAvatar, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&pdu.room_id, &EventType::RoomName, "")?
{
state.push(e.to_stripped_state_event());
}
Some(state)
}
_ => None,
};
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
self.update_membership(
&pdu.room_id,
&target_user_id,
serde_json::from_value::<member::MembershipState>(
pdu.content
.get("membership")
.ok_or_else(|| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid member event content",
)
})?
.clone(),
)
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid membership state content.",
)
})?,
membership,
&pdu.sender,
invite_state,
&db.account_data,
&db.globals,
)?;
@ -1044,10 +1106,10 @@ impl Rooms {
// Our depth is the maximum depth of prev_events + 1
let depth = prev_events
.iter()
.filter_map(|event_id| Some(self.get_pdu_json(event_id).ok()??.get("depth")?.as_u64()?))
.filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth))
.max()
.unwrap_or(0_u64)
+ 1;
.unwrap_or(uint!(0))
+ uint!(1);
let mut unsigned = unsigned.unwrap_or_default();
if let Some(state_key) = &state_key {
@ -1071,9 +1133,7 @@ impl Rooms {
content,
state_key,
prev_events,
depth: depth
.try_into()
.map_err(|_| Error::bad_database("Depth is invalid"))?,
depth,
auth_events: auth_events
.iter()
.map(|(_, pdu)| pdu.event_id.clone())
@ -1384,6 +1444,7 @@ impl Rooms {
user_id: &UserId,
membership: member::MembershipState,
sender: &UserId,
invite_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
account_data: &super::account_data::AccountData,
globals: &super::globals::Globals,
) -> Result<()> {
@ -1487,8 +1548,8 @@ impl Rooms {
self.roomserverids.insert(&roomserver_id, &[])?;
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invited.remove(&userroom_id)?;
self.roomuserid_invited.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_left.remove(&userroom_id)?;
}
member::MembershipState::Invite => {
@ -1508,8 +1569,13 @@ impl Rooms {
}
self.roomserverids.insert(&roomserver_id, &[])?;
self.userroomid_invited.insert(&userroom_id, &[])?;
self.roomuserid_invited.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.insert(
&userroom_id,
serde_json::to_vec(&invite_state.unwrap_or_default())
.expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_left.remove(&userroom_id)?;
@ -1526,8 +1592,8 @@ impl Rooms {
self.userroomid_left.insert(&userroom_id, &[])?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invited.remove(&userroom_id)?;
self.roomuserid_invited.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
}
_ => {}
}
@ -1797,7 +1863,7 @@ impl Rooms {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.roomuserid_invited
self.roomuserid_invitecount
.scan_prefix(prefix)
.keys()
.map(|key| {
@ -1816,6 +1882,22 @@ impl Rooms {
})
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))]
pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(key)?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid invitecount in db.")
})?))
})
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))]
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
@ -1840,27 +1922,32 @@ impl Rooms {
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))]
pub fn rooms_invited(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
pub fn rooms_invited(
&self,
user_id: &UserId,
) -> impl Iterator<Item = Result<(RoomId, Vec<Raw<AnyStrippedStateEvent>>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
self.userroomid_invited
.scan_prefix(prefix)
.keys()
.map(|key| {
Ok(RoomId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
self.userroomid_invitestate.scan_prefix(prefix).map(|r| {
let (key, state) = r?;
let room_id = RoomId::try_from(
utils::string_from_bytes(
&key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?)
})
.map_err(|_| {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
})
}
/// Returns an iterator over all rooms a user left.
@ -1906,7 +1993,7 @@ impl Rooms {
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invited.get(userroom_id)?.is_some())
Ok(self.userroomid_invitestate.get(userroom_id)?.is_some())
}
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {