apply new rustfmt.toml changes, fix some clippy lints

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-12-15 00:05:47 -05:00
parent 0317cc8cc5
commit 77e0b76408
No known key found for this signature in database
296 changed files with 7147 additions and 4300 deletions

View file

@ -52,7 +52,8 @@ impl crate::Service for Service {
appservice: args.depend::<appservice::Service>("appservice"),
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
},
}))
}
@ -62,8 +63,15 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
pub fn set_alias(
&self,
alias: &RoomAliasId,
room_id: &RoomId,
user_id: &UserId,
) -> Result<()> {
if alias == self.services.globals.admin_alias
&& user_id != self.services.globals.server_user
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Only the server user can set this alias",
@ -120,7 +128,9 @@ impl Service {
}
pub async fn resolve_with_servers(
&self, room: &RoomOrAliasId, servers: Option<Vec<OwnedServerName>>,
&self,
room: &RoomOrAliasId,
servers: Option<Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
if room.is_room_id() {
let room_id = RoomId::parse(room).expect("valid RoomId");
@ -133,14 +143,16 @@ impl Service {
#[tracing::instrument(skip(self), name = "resolve")]
pub async fn resolve_alias(
&self, room_alias: &RoomAliasId, servers: Option<Vec<OwnedServerName>>,
&self,
room_alias: &RoomAliasId,
servers: Option<Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
let server_name = room_alias.server_name();
let server_is_ours = self.services.globals.server_is_ours(server_name);
let servers_contains_ours = || {
servers
.as_ref()
.is_some_and(|servers| servers.contains(&self.services.globals.config.server_name))
servers.as_ref().is_some_and(|servers| {
servers.contains(&self.services.globals.config.server_name)
})
};
if !server_is_ours && !servers_contains_ours() {
@ -150,8 +162,8 @@ impl Service {
}
let room_id = match self.resolve_local_alias(room_alias).await {
Ok(r) => Some(r),
Err(_) => self.resolve_appservice_alias(room_alias).await?,
| Ok(r) => Some(r),
| Err(_) => self.resolve_appservice_alias(room_alias).await?,
};
room_id.map_or_else(
@ -166,7 +178,10 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
pub fn local_aliases_for_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.aliasid_alias
@ -208,10 +223,15 @@ impl Service {
if let Ok(content) = self
.services
.state_accessor
.room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "")
.room_state_get_content::<RoomPowerLevelsEventContent>(
&room_id,
&StateEventType::RoomPowerLevels,
"",
)
.await
{
return Ok(RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias));
return Ok(RoomPowerLevels::from(content)
.user_can_send_state(user_id, StateEventType::RoomCanonicalAlias));
}
// If there is no power levels event, only the room creator can change
@ -232,7 +252,10 @@ impl Service {
self.db.alias_userid.get(alias.alias()).await.deserialized()
}
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
async fn resolve_appservice_alias(
&self,
room_alias: &RoomAliasId,
) -> Result<Option<OwnedRoomId>> {
use ruma::api::appservice::query::query_room_alias;
for appservice in self.services.appservice.read().await.values() {
@ -242,9 +265,7 @@ impl Service {
.sending
.send_appservice_request(
appservice.registration.clone(),
query_room_alias::v1::Request {
room_alias: room_alias.to_owned(),
},
query_room_alias::v1::Request { room_alias: room_alias.to_owned() },
)
.await,
Ok(Some(_opt_result))
@ -261,19 +282,27 @@ impl Service {
}
pub async fn appservice_checks(
&self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>,
&self,
room_alias: &RoomAliasId,
appservice_info: &Option<RegistrationInfo>,
) -> Result<()> {
if !self
.services
.globals
.server_is_ours(room_alias.server_name())
{
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Alias is from another server.",
));
}
if let Some(ref info) = appservice_info {
if !info.aliases.is_match(room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Room alias is not in namespace.",
));
}
} else if self
.services
@ -281,7 +310,10 @@ impl Service {
.is_exclusive_alias(room_alias)
.await
{
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"Room alias reserved by appservice.",
));
}
Ok(())

View file

@ -6,7 +6,9 @@ use ruma::{api::federation, OwnedRoomId, OwnedServerName, RoomAliasId, ServerNam
#[implement(super::Service)]
pub(super) async fn remote_resolve(
&self, room_alias: &RoomAliasId, servers: Vec<OwnedServerName>,
&self,
room_alias: &RoomAliasId,
servers: Vec<OwnedServerName>,
) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> {
debug!(?room_alias, servers = ?servers, "resolve");
let servers = once(room_alias.server_name())
@ -17,12 +19,12 @@ pub(super) async fn remote_resolve(
let mut resolved_room_id: Option<OwnedRoomId> = None;
for server in servers {
match self.remote_request(room_alias, &server).await {
Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"),
Ok(Response {
room_id,
servers,
}) => {
debug!("Server {server} answered with {room_id:?} for {room_alias:?} servers: {servers:?}");
| Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"),
| Ok(Response { room_id, servers }) => {
debug!(
"Server {server} answered with {room_id:?} for {room_alias:?} servers: \
{servers:?}"
);
resolved_room_id.get_or_insert(room_id);
add_server(&mut resolved_servers, server);
@ -37,16 +39,20 @@ pub(super) async fn remote_resolve(
resolved_room_id
.map(|room_id| (room_id, resolved_servers))
.ok_or_else(|| err!(Request(NotFound("No servers could assist in resolving the room alias"))))
.ok_or_else(|| {
err!(Request(NotFound("No servers could assist in resolving the room alias")))
})
}
#[implement(super::Service)]
async fn remote_request(&self, room_alias: &RoomAliasId, server: &ServerName) -> Result<Response> {
async fn remote_request(
&self,
room_alias: &RoomAliasId,
server: &ServerName,
) -> Result<Response> {
use federation::query::get_room_information::v1::Request;
let request = Request {
room_alias: room_alias.to_owned(),
};
let request = Request { room_alias: room_alias.to_owned() };
self.services
.sending

View file

@ -19,14 +19,18 @@ impl Data {
let db = &args.db;
let config = &args.server.config;
let cache_size = f64::from(config.auth_chain_cache_capacity);
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size");
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier)
.expect("valid cache size");
Self {
shorteventid_authchain: db["shorteventid_authchain"].clone(),
auth_chain_cache: Mutex::new(LruCache::new(cache_size)),
}
}
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
pub(super) async fn get_cached_eventid_authchain(
&self,
key: &[u64],
) -> Result<Arc<[ShortEventId]>> {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Check RAM cache

View file

@ -43,7 +43,9 @@ impl crate::Service for Service {
impl Service {
pub async fn event_ids_iter<'a, I>(
&'a self, room_id: &RoomId, starting_events: I,
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
@ -57,7 +59,11 @@ impl Service {
Ok(stream)
}
pub async fn get_event_ids<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<Arc<EventId>>>
pub async fn get_event_ids<'a, I>(
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<Vec<Arc<EventId>>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
@ -74,7 +80,11 @@ impl Service {
}
#[tracing::instrument(skip_all, name = "auth_chain")]
pub async fn get_auth_chain<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result<Vec<ShortEventId>>
pub async fn get_auth_chain<'a, I>(
&'a self,
room_id: &RoomId,
starting_events: I,
) -> Result<Vec<ShortEventId>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{
@ -110,7 +120,8 @@ impl Service {
continue;
}
let chunk_key: Vec<ShortEventId> = chunk.iter().map(|(short, _)| short).copied().collect();
let chunk_key: Vec<ShortEventId> =
chunk.iter().map(|(short, _)| short).copied().collect();
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
trace!("Found cache entry for whole chunk");
full_auth_chain.extend(cached.iter().copied());
@ -169,7 +180,11 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<ShortEventId>> {
async fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<ShortEventId>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
@ -177,8 +192,10 @@ impl Service {
trace!(?event_id, "processing auth event");
match self.services.timeline.get_pdu(&event_id).await {
Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"),
Ok(pdu) => {
| Err(e) => {
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events");
},
| Ok(pdu) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
@ -196,7 +213,11 @@ impl Service {
.await;
if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
trace!(
?event_id,
?auth_event,
"adding auth event to processing queue"
);
todo.push(auth_event.clone());
}
}

View file

@ -32,10 +32,14 @@ pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_i
pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); }
#[implement(Service)]
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send { self.db.publicroomids.keys().ignore_err() }
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
self.db.publicroomids.keys().ignore_err()
}
#[implement(Service)]
pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.visibility(room_id).await == Visibility::Public }
pub async fn is_public_room(&self, room_id: &RoomId) -> bool {
self.visibility(room_id).await == Visibility::Public
}
#[implement(Service)]
pub async fn visibility(&self, room_id: &RoomId) -> Visibility {

View file

@ -5,10 +5,14 @@ use std::{
};
use conduwuit::{
debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent,
debug, debug_error, implement, info, pdu, trace,
utils::math::continue_exponential_backoff_secs, warn, PduEvent,
};
use futures::TryFutureExt;
use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName};
use ruma::{
api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId,
ServerName,
};
/// Find the event and auth it. Once the event is validated (steps 1 - 8)
/// it is appended to the outliers Tree.
@ -21,7 +25,11 @@ use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomI
/// d. TODO: Ask other servers over federation?
#[implement(super::Service)]
pub(super) async fn fetch_and_handle_outliers<'a>(
&self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId,
&self,
origin: &'a ServerName,
events: &'a [Arc<EventId>],
create_event: &'a PduEvent,
room_id: &'a RoomId,
room_version_id: &'a RoomVersionId,
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| match self
@ -32,10 +40,12 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
.expect("locked")
.entry(id)
{
hash_map::Entry::Vacant(e) => {
| hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)),
| hash_map::Entry::Occupied(mut e) => {
*e.get_mut() = (Instant::now(), e.get().1.saturating_add(1));
},
};
let mut events_with_auth_events = Vec::with_capacity(events.len());
@ -67,7 +77,12 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
time.elapsed(),
*tries,
) {
info!("Backing off from {next_id}");
continue;
}
@ -86,18 +101,16 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
match self
.services
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
include_unredacted_content: None,
},
)
.send_federation_request(origin, get_event::v1::Request {
event_id: (*next_id).to_owned(),
include_unredacted_content: None,
})
.await
{
Ok(res) => {
| Ok(res) => {
debug!("Got {next_id} over federation");
let Ok((calculated_event_id, value)) = pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
let Ok((calculated_event_id, value)) =
pdu::gen_event_id_canonical_json(&res.pdu, room_version_id)
else {
back_off((*next_id).to_owned());
continue;
@ -105,15 +118,18 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
if calculated_event_id != *next_id {
warn!(
"Server didn't return event id we requested: requested: {next_id}, we got \
{calculated_event_id}. Event: {:?}",
"Server didn't return event id we requested: requested: {next_id}, \
we got {calculated_event_id}. Event: {:?}",
&res.pdu
);
}
if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) {
if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
@ -127,7 +143,7 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
},
Err(e) => {
| Err(e) => {
debug_error!("Failed to fetch event {next_id}: {e}");
back_off((*next_id).to_owned());
},
@ -158,20 +174,32 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
// Exponential backoff
const MIN_DURATION: u64 = 5 * 60;
const MAX_DURATION: u64 = 60 * 60 * 24;
if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) {
if continue_exponential_backoff_secs(
MIN_DURATION,
MAX_DURATION,
time.elapsed(),
*tries,
) {
debug!("Backing off from {next_id}");
continue;
}
}
match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)).await
match Box::pin(self.handle_outlier_pdu(
origin,
create_event,
&next_id,
room_id,
value.clone(),
true,
))
.await
{
Ok((pdu, json)) => {
| Ok((pdu, json)) =>
if next_id == *id {
pdus.push((pdu, Some(json)));
}
},
Err(e) => {
},
| Err(e) => {
warn!("Authentication of event {next_id} failed: {e:?}");
back_off(next_id.into());
},

View file

@ -8,7 +8,8 @@ use futures::{future, FutureExt};
use ruma::{
int,
state_res::{self},
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName,
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId,
ServerName,
};
use super::check_room_id;
@ -17,7 +18,11 @@ use super::check_room_id;
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)]
pub(super) async fn fetch_prev(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
&self,
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
room_version_id: &RoomVersionId,
initial_set: Vec<Arc<EventId>>,
) -> Result<(
Vec<Arc<EventId>>,
@ -35,7 +40,13 @@ pub(super) async fn fetch_prev(
self.services.server.check_running()?;
if let Some((pdu, mut json_opt)) = self
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id)
.fetch_and_handle_outliers(
origin,
&[prev_event_id.clone()],
create_event,
room_id,
room_version_id,
)
.boxed()
.await
.pop()
@ -67,7 +78,8 @@ pub(super) async fn fetch_prev(
}
}
graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect());
graph
.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect());
} else {
// Time based check failed
graph.insert(prev_event_id.clone(), HashSet::new());

View file

@ -6,7 +6,8 @@ use std::{
use conduwuit::{debug, implement, warn, Err, Error, PduEvent, Result};
use futures::FutureExt;
use ruma::{
api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId, RoomVersionId, ServerName,
api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId,
RoomVersionId, ServerName,
};
/// Call /state_ids to find out what the state at this pdu is. We trust the
@ -15,20 +16,21 @@ use ruma::{
#[implement(super::Service)]
#[tracing::instrument(skip(self, create_event, room_version_id))]
pub(super) async fn fetch_state(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
&self,
origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
room_version_id: &RoomVersionId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Fetching state ids");
let res = self
.services
.sending
.send_federation_request(
origin,
get_room_state_ids::v1::Request {
room_id: room_id.to_owned(),
event_id: (*event_id).to_owned(),
},
)
.send_federation_request(origin, get_room_state_ids::v1::Request {
room_id: room_id.to_owned(),
event_id: (*event_id).to_owned(),
})
.await
.inspect_err(|e| warn!("Fetching state for event failed: {e}"))?;
@ -58,14 +60,13 @@ pub(super) async fn fetch_state(
.await;
match state.entry(shortstatekey) {
hash_map::Entry::Vacant(v) => {
| hash_map::Entry::Vacant(v) => {
v.insert(Arc::from(&*pdu.event_id));
},
hash_map::Entry::Occupied(_) => {
| hash_map::Entry::Occupied(_) =>
return Err(Error::bad_database(
"State event's type and state_key combination exists multiple times.",
))
},
)),
}
}

View file

@ -7,7 +7,8 @@ use std::{
use conduwuit::{debug, err, implement, warn, Error, Result};
use futures::{FutureExt, TryFutureExt};
use ruma::{
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId,
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId,
ServerName, UserId,
};
use super::{check_room_id, get_room_version_id};
@ -43,8 +44,12 @@ use crate::rooms::timeline::RawPduId;
#[implement(super::Service)]
#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")]
pub async fn handle_incoming_pdu<'a>(
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
&self,
origin: &'a ServerName,
room_id: &'a RoomId,
event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>,
is_timeline_event: bool,
) -> Result<Option<RawPduId>> {
// 1. Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await {
@ -144,10 +149,10 @@ pub async fn handle_incoming_pdu<'a>(
.expect("locked")
.entry(prev_id.into())
{
Entry::Vacant(e) => {
| Entry::Vacant(e) => {
e.insert((now, 1));
},
Entry::Occupied(mut e) => {
| Entry::Occupied(mut e) => {
*e.get_mut() = (now, e.get().1.saturating_add(1));
},
};

View file

@ -17,8 +17,13 @@ use super::{check_room_id, get_room_version_id, to_room_version};
#[implement(super::Service)]
#[allow(clippy::too_many_arguments)]
pub(super) async fn handle_outlier_pdu<'a>(
&self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId,
mut value: CanonicalJsonObject, auth_events_known: bool,
&self,
origin: &'a ServerName,
create_event: &'a PduEvent,
event_id: &'a EventId,
room_id: &'a RoomId,
mut value: CanonicalJsonObject,
auth_events_known: bool,
) -> Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)> {
// 1. Remove unsigned field
value.remove("unsigned");
@ -34,8 +39,8 @@ pub(super) async fn handle_outlier_pdu<'a>(
.verify_event(&value, Some(&room_version_id))
.await
{
Ok(ruma::signatures::Verified::All) => value,
Ok(ruma::signatures::Verified::Signatures) => {
| Ok(ruma::signatures::Verified::All) => value,
| Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else {
@ -44,24 +49,26 @@ pub(super) async fn handle_outlier_pdu<'a>(
// Skip the PDU if it is redacted and we already have it as an outlier event
if self.services.timeline.pdu_exists(event_id).await {
return Err!(Request(InvalidParam("Event was redacted and we already knew about it")));
return Err!(Request(InvalidParam(
"Event was redacted and we already knew about it"
)));
}
obj
},
Err(e) => {
| Err(e) =>
return Err!(Request(InvalidParam(debug_error!(
"Signature verification failed for {event_id}: {e}"
))))
},
)))),
};
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
let incoming_pdu =
serde_json::from_value::<PduEvent>(serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"))
.map_err(|_| Error::bad_database("Event is not a valid PDU."))?;
let incoming_pdu = serde_json::from_value::<PduEvent>(
serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"),
)
.map_err(|_| Error::bad_database("Event is not a valid PDU."))?;
check_room_id(room_id, &incoming_pdu)?;
@ -108,10 +115,10 @@ pub(super) async fn handle_outlier_pdu<'a>(
.clone()
.expect("all auth events have state keys"),
)) {
hash_map::Entry::Vacant(v) => {
| hash_map::Entry::Vacant(v) => {
v.insert(auth_event);
},
hash_map::Entry::Occupied(_) => {
| hash_map::Entry::Occupied(_) => {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Auth event's type and state_key combination exists multiple times.",

View file

@ -4,7 +4,9 @@ use std::{
time::Instant,
};
use conduwuit::{debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result};
use conduwuit::{
debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result,
};
use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, ServerName};
#[implement(super::Service)]
@ -15,15 +17,23 @@ use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, S
name = "prev"
)]
pub(super) async fn handle_prev_pdu<'a>(
&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
&self,
origin: &'a ServerName,
event_id: &'a EventId,
room_id: &'a RoomId,
eventid_info: &mut HashMap<
Arc<EventId>,
(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>),
>,
create_event: &Arc<PduEvent>,
first_pdu_in_room: &Arc<PduEvent>,
prev_id: &EventId,
) -> Result {
// Check for disabled again because it might have changed
if self.services.metadata.is_disabled(room_id).await {
debug!(
"Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and event \
ID {event_id}"
"Federaton of room {room_id} is currently disabled on this server. Request by \
origin {origin} and event ID {event_id}"
);
return Err(Error::BadRequest(
ErrorKind::forbidden(),

View file

@ -23,8 +23,8 @@ use conduwuit::{
};
use futures::TryFutureExt;
use ruma::{
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId,
RoomVersionId,
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId,
OwnedRoomId, RoomId, RoomVersionId,
};
use crate::{globals, rooms, sending, server_keys, Dep};
@ -69,8 +69,10 @@ impl crate::Service for Service {
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
server: args.server.clone(),
},
@ -95,7 +97,9 @@ impl crate::Service for Service {
}
impl Service {
async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await }
async fn event_exists(&self, event_id: Arc<EventId>) -> bool {
self.services.timeline.pdu_exists(&event_id).await
}
async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> {
self.services

View file

@ -3,9 +3,13 @@ use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, R
use serde_json::value::RawValue as RawJsonValue;
#[implement(super::Service)]
pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get())
.map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?;
pub async fn parse_incoming_pdu(
&self,
pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()).map_err(|e| {
err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}")))
})?;
let room_id: OwnedRoomId = value
.get("room_id")
@ -20,8 +24,9 @@ pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEvent
.await
.map_err(|_| err!("Server is not in room {room_id}"))?;
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id)
.map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?;
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id).map_err(|e| {
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
})?;
Ok((event_id, value, room_id))
}

View file

@ -20,7 +20,10 @@ use crate::rooms::state_compressor::CompressedStateEvent;
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "resolve")]
pub async fn resolve_state(
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
&self,
room_id: &RoomId,
room_version_id: &RoomVersionId,
incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids");
let current_sstatehash = self
@ -76,10 +79,16 @@ pub async fn resolve_state(
let event_fetch = |event_id| self.event_fetch(event_id);
let event_exists = |event_id| self.event_exists(event_id);
let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
.boxed()
.await
.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
let state = state_res::resolve(
room_version_id,
&fork_states,
&auth_chain_sets,
&event_fetch,
&event_exists,
)
.boxed()
.await
.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
drop(lock);

View file

@ -21,7 +21,8 @@ use ruma::{
// request and build the state from a known point and resolve if > 1 prev_event
#[tracing::instrument(skip_all, name = "state")]
pub(super) async fn state_at_incoming_degree_one(
&self, incoming_pdu: &Arc<PduEvent>,
&self,
incoming_pdu: &Arc<PduEvent>,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
let prev_event = &*incoming_pdu.prev_events[0];
let Ok(prev_event_sstatehash) = self
@ -70,7 +71,10 @@ pub(super) async fn state_at_incoming_degree_one(
#[implement(super::Service)]
#[tracing::instrument(skip_all, name = "state")]
pub(super) async fn state_at_incoming_resolved(
&self, incoming_pdu: &Arc<PduEvent>, room_id: &RoomId, room_version_id: &RoomVersionId,
&self,
incoming_pdu: &Arc<PduEvent>,
room_id: &RoomId,
room_version_id: &RoomVersionId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Calculating state at event using state res");
let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len());
@ -157,10 +161,16 @@ pub(super) async fn state_at_incoming_resolved(
let event_fetch = |event_id| self.event_fetch(event_id);
let event_exists = |event_id| self.event_exists(event_id);
let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
.boxed()
.await
.map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed."))));
let result = state_res::resolve(
room_version_id,
&fork_states,
&auth_chain_sets,
&event_fetch,
&event_exists,
)
.boxed()
.await
.map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed."))));
drop(lock);

View file

@ -19,8 +19,12 @@ use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPdu
#[implement(super::Service)]
pub(super) async fn upgrade_outlier_to_timeline_pdu(
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
origin: &ServerName, room_id: &RoomId,
&self,
incoming_pdu: Arc<PduEvent>,
val: BTreeMap<String, CanonicalJsonValue>,
create_event: &PduEvent,
origin: &ServerName,
room_id: &RoomId,
) -> Result<Option<RawPduId>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
@ -63,7 +67,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
.await?;
}
let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above");
let state_at_incoming_event =
state_at_incoming_event.expect("we always set this to some above");
let room_version = to_room_version(&room_version_id);
debug!("Performing auth check");
@ -124,24 +129,34 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
!auth_check
|| incoming_pdu.kind == TimelineEventType::RoomRedaction
&& match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &incoming_pdu.redacts {
!self
.services
.state_accessor
.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
.user_can_redact(
redact_id,
&incoming_pdu.sender,
&incoming_pdu.room_id,
true,
)
.await?
} else {
false
}
},
_ => {
| _ => {
let content: RoomRedactionEventContent = incoming_pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
!self
.services
.state_accessor
.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
.user_can_redact(
redact_id,
&incoming_pdu.sender,
&incoming_pdu.room_id,
true,
)
.await?
} else {
false
@ -229,11 +244,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
// Set the new room state to the resolved state
debug!("Forcing new room state");
let HashSetCompressStateEvent {
shortstatehash,
added,
removed,
} = self
let HashSetCompressStateEvent { shortstatehash, added, removed } = self
.services
.state_compressor
.save_state(room_id, new_room_state)

View file

@ -51,7 +51,11 @@ impl crate::Service for Service {
#[tracing::instrument(skip(self), level = "debug")]
#[inline]
pub async fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
ll_user: &UserId,
) -> bool {
let key = (user_id, device_id, room_id, ll_user);
self.db.lazyloadedids.qry(&key).await.is_ok()
@ -60,7 +64,12 @@ pub async fn lazy_load_was_sent_before(
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount,
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count);
@ -72,7 +81,13 @@ pub fn lazy_load_mark_sent(
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) {
pub fn lazy_load_confirm_delivery(
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
since: PduCount,
) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since);
let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else {

View file

@ -58,7 +58,9 @@ pub async fn exists(&self, room_id: &RoomId) -> bool {
}
#[implement(Service)]
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() }
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
self.db.roomid_shortroomid.keys().ignore_err()
}
#[implement(Service)]
#[inline]
@ -81,12 +83,18 @@ pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
}
#[implement(Service)]
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() }
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ {
self.db.bannedroomids.keys().ignore_err()
}
#[implement(Service)]
#[inline]
pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.get(room_id).await.is_ok() }
pub async fn is_disabled(&self, room_id: &RoomId) -> bool {
self.db.disabledroomids.get(room_id).await.is_ok()
}
#[implement(Service)]
#[inline]
pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.get(room_id).await.is_ok() }
pub async fn is_banned(&self, room_id: &RoomId) -> bool {
self.db.bannedroomids.get(room_id).await.is_ok()
}

View file

@ -56,26 +56,27 @@ impl Data {
}
pub(super) fn get_relations<'a>(
&'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction,
&'a self,
user_id: &'a UserId,
shortroomid: ShortRoomId,
target: ShortEventId,
from: PduCount,
dir: Direction,
) -> impl Stream<Item = PdusIterItem> + Send + '_ {
let mut current = ArrayVec::<u8, 16>::new();
current.extend(target.to_be_bytes());
current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes());
let current = current.as_slice();
match dir {
Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
| Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
| Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
}
.ignore_err()
.ready_take_while(move |key| key.starts_with(&target.to_be_bytes()))
.map(|to_from| u64_from_u8(&to_from[8..16]))
.map(PduCount::from_unsigned)
.wide_filter_map(move |shorteventid| async move {
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid,
}
.into();
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
@ -99,7 +100,9 @@ impl Data {
self.referencedevents.qry(&key).await.is_ok()
}
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { self.softfailedeventids.insert(event_id, []); }
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) {
self.softfailedeventids.insert(event_id, []);
}
pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.softfailedeventids.get(event_id).await.is_ok()

View file

@ -36,8 +36,8 @@ impl Service {
#[tracing::instrument(skip(self, from, to), level = "debug")]
pub fn add_relation(&self, from: PduCount, to: PduCount) {
match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
_ => {
| (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
| _ => {
// TODO: Relations with backfilled pdus
},
}
@ -45,15 +45,21 @@ impl Service {
#[allow(clippy::too_many_arguments)]
pub async fn get_relations(
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8,
&self,
user_id: &UserId,
room_id: &RoomId,
target: &EventId,
from: PduCount,
limit: usize,
max_depth: u8,
dir: Direction,
) -> Vec<PdusIterItem> {
let room_id = self.services.short.get_or_create_shortroomid(room_id).await;
let target = match self.services.timeline.get_pdu_count(target).await {
Ok(PduCount::Normal(c)) => c,
| Ok(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator
| _ => 0, // This will result in an empty iterator
};
let mut pdus: Vec<_> = self
@ -66,9 +72,9 @@ impl Service {
'limit: while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
| PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
| PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
let relations: Vec<_> = self
@ -106,7 +112,9 @@ impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) }
pub fn mark_event_soft_failed(&self, event_id: &EventId) {
self.db.mark_event_soft_failed(event_id);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]

View file

@ -40,7 +40,12 @@ impl Data {
}
}
pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
pub(super) async fn readreceipt_update(
&self,
user_id: &UserId,
room_id: &RoomId,
event: &ReceiptEvent,
) {
// Remove old entry
let last_possible_key = (room_id, u64::MAX);
self.readreceiptid_readreceipt
@ -57,7 +62,9 @@ impl Data {
}
pub(super) fn readreceipts_since<'a>(
&'a self, room_id: &'a RoomId, since: u64,
&'a self,
room_id: &'a RoomId,
since: u64,
) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
type Key<'a> = (&'a RoomId, u64, &'a UserId);
type KeyVal<'a> = (Key<'a>, CanonicalJsonObject);
@ -87,12 +94,20 @@ impl Data {
self.roomuserid_lastprivatereadupdate.put(key, next_count);
}
pub(super) async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
pub(super) async fn private_read_get_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<u64> {
let key = (room_id, user_id);
self.roomuserid_privateread.qry(&key).await.deserialized()
}
pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
pub(super) async fn last_privateread_update(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> u64 {
let key = (room_id, user_id);
self.roomuserid_lastprivatereadupdate
.qry(&key)

View file

@ -44,7 +44,12 @@ impl crate::Service for Service {
impl Service {
/// Replaces the previous read receipt.
pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
pub async fn readreceipt_update(
&self,
user_id: &UserId,
room_id: &RoomId,
event: &ReceiptEvent,
) {
self.db.readreceipt_update(user_id, room_id, event).await;
self.services
.sending
@ -54,23 +59,21 @@ impl Service {
}
/// Gets the latest private read receipt from the user in the room
pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
let pdu_count = self
.private_read_get_count(room_id, user_id)
.map_err(|e| err!(Database(warn!("No private read receipt was set in {room_id}: {e}"))));
let shortroomid = self
.services
.short
.get_shortroomid(room_id)
.map_err(|e| err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}"))));
pub async fn private_read_get(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
let pdu_count = self.private_read_get_count(room_id, user_id).map_err(|e| {
err!(Database(warn!("No private read receipt was set in {room_id}: {e}")))
});
let shortroomid = self.services.short.get_shortroomid(room_id).map_err(|e| {
err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}")))
});
let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?;
let shorteventid = PduCount::Normal(pdu_count);
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid,
}
.into();
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?;
@ -80,21 +83,17 @@ impl Service {
event_id,
BTreeMap::from_iter([(
ruma::events::receipt::ReceiptType::ReadPrivate,
BTreeMap::from_iter([(
user_id,
ruma::events::receipt::Receipt {
ts: None, // TODO: start storing the timestamp so we can return one
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
},
)]),
BTreeMap::from_iter([(user_id, ruma::events::receipt::Receipt {
ts: None, // TODO: start storing the timestamp so we can return one
thread: ruma::events::receipt::ReceiptThread::Unthreaded,
})]),
)]),
)]);
let receipt_event_content = ReceiptEventContent(content);
let receipt_sync_event = SyncEphemeralRoomEvent {
content: receipt_event_content,
};
let receipt_sync_event = SyncEphemeralRoomEvent { content: receipt_event_content };
let event = serde_json::value::to_raw_value(&receipt_sync_event).expect("receipt created manually");
let event = serde_json::value::to_raw_value(&receipt_sync_event)
.expect("receipt created manually");
Ok(Raw::from_json(event))
}
@ -104,7 +103,9 @@ impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn readreceipts_since<'a>(
&'a self, room_id: &'a RoomId, since: u64,
&'a self,
room_id: &'a RoomId,
since: u64,
) -> impl Stream<Item = ReceiptItem<'_>> + Send + 'a {
self.db.readreceipts_since(room_id, since)
}
@ -119,7 +120,11 @@ impl Service {
/// Returns the private read marker PDU count.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
pub async fn private_read_get_count(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<u64> {
self.db.private_read_get_count(room_id, user_id).await
}
@ -137,7 +142,9 @@ where
{
let mut json = BTreeMap::new();
for value in receipts {
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get());
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(
value.json().get(),
);
if let Ok(value) = receipt {
for (event, receipt) in value.content {
json.insert(event, receipt);
@ -149,9 +156,7 @@ where
let content = ReceiptEventContent::from_iter(json);
Raw::from_json(
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent {
content,
})
.expect("received valid json"),
serde_json::value::to_raw_value(&SyncEphemeralRoomEvent { content })
.expect("received valid json"),
)
}

View file

@ -49,18 +49,18 @@ pub struct RoomQuery<'a> {
type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
const TOKEN_ID_MAX_LEN: usize = size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
const TOKEN_ID_MAX_LEN: usize =
size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
const WORD_MAX_LEN: usize = 50;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
tokenids: args.db["tokenids"].clone(),
},
db: Data { tokenids: args.db["tokenids"].clone() },
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}))
@ -103,7 +103,8 @@ pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_b
#[implement(Service)]
pub async fn search_pdus<'a>(
&'a self, query: &'a RoomQuery<'a>,
&'a self,
query: &'a RoomQuery<'a>,
) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> {
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
@ -136,7 +137,10 @@ pub async fn search_pdus<'a>(
// result is modeled as a stream such that callers don't have to be refactored
// though an additional async/wrap still exists for now
#[implement(Service)]
pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
pub async fn search_pdu_ids(
&self,
query: &RoomQuery<'_>,
) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
let shortroomid = self.services.short.get_shortroomid(query.room_id).await?;
let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await;
@ -147,7 +151,11 @@ pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<
}
#[implement(Service)]
async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec<Vec<RawPduId>> {
async fn search_pdu_ids_query_room(
&self,
query: &RoomQuery<'_>,
shortroomid: ShortRoomId,
) -> Vec<Vec<RawPduId>> {
tokenize(&query.criteria.search_term)
.stream()
.wide_then(|word| async move {
@ -162,7 +170,9 @@ async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: Sh
/// Iterate over PduId's containing a word
#[implement(Service)]
fn search_pdu_ids_query_words<'a>(
&'a self, shortroomid: ShortRoomId, word: &'a str,
&'a self,
shortroomid: ShortRoomId,
word: &'a str,
) -> impl Stream<Item = RawPduId> + Send + '_ {
self.search_pdu_ids_query_word(shortroomid, word)
.map(move |key| -> RawPduId {
@ -173,7 +183,11 @@ fn search_pdu_ids_query_words<'a>(
/// Iterate over raw database results for a word
#[implement(Service)]
fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream<Item = Val<'_>> + Send + '_ {
fn search_pdu_ids_query_word(
&self,
shortroomid: ShortRoomId,
word: &str,
) -> impl Stream<Item = Val<'_>> + Send + '_ {
// rustc says const'ing this not yet stable
let end_id: RawPduId = PduId {
shortroomid,

View file

@ -60,7 +60,10 @@ pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEvent
}
#[implement(Service)]
pub fn multi_get_or_create_shorteventid<'a, I>(&'a self, event_ids: I) -> impl Stream<Item = ShortEventId> + Send + '_
pub fn multi_get_or_create_shorteventid<'a, I>(
&'a self,
event_ids: I,
) -> impl Stream<Item = ShortEventId> + Send + '_
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
<I as Iterator>::Item: AsRef<[u8]> + Send + Sync + 'a,
@ -72,8 +75,8 @@ where
event_ids.next().map(|event_id| (event_id, result))
})
.map(|(event_id, result)| match result {
Ok(ref short) => utils::u64_from_u8(short),
Err(_) => self.create_shorteventid(event_id),
| Ok(ref short) => utils::u64_from_u8(short),
| Err(_) => self.create_shorteventid(event_id),
})
}
@ -104,7 +107,11 @@ pub async fn get_shorteventid(&self, event_id: &EventId) -> Result<ShortEventId>
}
#[implement(Service)]
pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> ShortStateKey {
pub async fn get_or_create_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> ShortStateKey {
const BUFSIZE: usize = size_of::<ShortStateKey>();
if let Ok(shortstatekey) = self.get_shortstatekey(event_type, state_key).await {
@ -127,7 +134,11 @@ pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, sta
}
#[implement(Service)]
pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<ShortStateKey> {
pub async fn get_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<ShortStateKey> {
let key = (event_type, state_key);
self.db
.statekey_shortstatekey
@ -153,7 +164,10 @@ where
}
#[implement(Service)]
pub fn multi_get_eventid_from_short<'a, Id, I>(&'a self, shorteventid: I) -> impl Stream<Item = Result<Id>> + Send + 'a
pub fn multi_get_eventid_from_short<'a, Id, I>(
&'a self,
shorteventid: I,
) -> impl Stream<Item = Result<Id>> + Send + 'a
where
I: Iterator<Item = &'a ShortEventId> + Send + 'a,
Id: for<'de> Deserialize<'de> + Sized + ToOwned + 'a,
@ -168,7 +182,10 @@ where
}
#[implement(Service)]
pub async fn get_statekey_from_short(&self, shortstatekey: ShortStateKey) -> Result<(StateEventType, String)> {
pub async fn get_statekey_from_short(
&self,
shortstatekey: ShortStateKey,
) -> Result<(StateEventType, String)> {
const BUFSIZE: usize = size_of::<ShortStateKey>();
self.db

View file

@ -125,7 +125,8 @@ enum Identifier<'a> {
pub struct Service {
services: Services,
pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
pub roomid_spacehierarchy_cache:
Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
}
struct Services {
@ -145,11 +146,13 @@ impl crate::Service for Service {
let cache_size = cache_size * config.cache_capacity_modifier;
Ok(Arc::new(Self {
services: Services {
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state: args.depend::<rooms::state::Service>("rooms::state"),
short: args.depend::<rooms::short::Service>("rooms::short"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
event_handler: args
.depend::<rooms::event_handler::Service>("rooms::event_handler"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
sending: args.depend::<sending::Service>("sending"),
},
@ -166,28 +169,37 @@ impl Service {
/// Errors if the room does not exist, so a check if the room exists should
/// be done
pub async fn get_federation_hierarchy(
&self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool,
&self,
room_id: &RoomId,
server_name: &ServerName,
suggested_only: bool,
) -> Result<federation::space::get_hierarchy::v1::Response> {
match self
.get_summary_and_children_local(&room_id.to_owned(), Identifier::ServerName(server_name))
.get_summary_and_children_local(
&room_id.to_owned(),
Identifier::ServerName(server_name),
)
.await?
{
Some(SummaryAccessibility::Accessible(room)) => {
| Some(SummaryAccessibility::Accessible(room)) => {
let mut children = Vec::new();
let mut inaccessible_children = Vec::new();
for (child, _via) in get_parent_children_via(&room, suggested_only) {
match self
.get_summary_and_children_local(&child, Identifier::ServerName(server_name))
.get_summary_and_children_local(
&child,
Identifier::ServerName(server_name),
)
.await?
{
Some(SummaryAccessibility::Accessible(summary)) => {
| Some(SummaryAccessibility::Accessible(summary)) => {
children.push((*summary).into());
},
Some(SummaryAccessibility::Inaccessible) => {
| Some(SummaryAccessibility::Inaccessible) => {
inaccessible_children.push(child);
},
None => (),
| None => (),
}
}
@ -197,16 +209,18 @@ impl Service {
inaccessible_children,
})
},
Some(SummaryAccessibility::Inaccessible) => {
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible"))
},
None => Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")),
| Some(SummaryAccessibility::Inaccessible) =>
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible")),
| None =>
Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")),
}
}
/// Gets the summary of a space using solely local information
async fn get_summary_and_children_local(
&self, current_room: &OwnedRoomId, identifier: Identifier<'_>,
&self,
current_room: &OwnedRoomId,
identifier: Identifier<'_>,
) -> Result<Option<SummaryAccessibility>> {
if let Some(cached) = self
.roomid_spacehierarchy_cache
@ -241,9 +255,7 @@ impl Service {
if let Ok(summary) = summary {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceHierarchySummary {
summary: summary.clone(),
}),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
Ok(Some(SummaryAccessibility::Accessible(Box::new(summary))))
@ -258,20 +270,21 @@ impl Service {
/// Gets the summary of a space using solely federation
#[tracing::instrument(skip(self))]
async fn get_summary_and_children_federation(
&self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName],
&self,
current_room: &OwnedRoomId,
suggested_only: bool,
user_id: &UserId,
via: &[OwnedServerName],
) -> Result<Option<SummaryAccessibility>> {
for server in via {
debug_info!("Asking {server} for /hierarchy");
let Ok(response) = self
.services
.sending
.send_federation_request(
server,
federation::space::get_hierarchy::v1::Request {
room_id: current_room.to_owned(),
suggested_only,
},
)
.send_federation_request(server, federation::space::get_hierarchy::v1::Request {
room_id: current_room.to_owned(),
suggested_only,
})
.await
else {
continue;
@ -282,9 +295,7 @@ impl Service {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceHierarchySummary {
summary: summary.clone(),
}),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
for child in response.children {
@ -356,7 +367,11 @@ impl Service {
/// Gets the summary of a space using either local or remote (federation)
/// sources
async fn get_summary_and_children_client(
&self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName],
&self,
current_room: &OwnedRoomId,
suggested_only: bool,
user_id: &UserId,
via: &[OwnedServerName],
) -> Result<Option<SummaryAccessibility>> {
if let Ok(Some(response)) = self
.get_summary_and_children_local(current_room, Identifier::UserId(user_id))
@ -370,7 +385,9 @@ impl Service {
}
async fn get_room_summary(
&self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
&self,
current_room: &OwnedRoomId,
children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> {
let room_id: &RoomId = current_room;
@ -388,12 +405,20 @@ impl Service {
.allowed_room_ids(join_rule.clone());
if !self
.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids)
.is_accessible_child(
current_room,
&join_rule.clone().into(),
identifier,
&allowed_room_ids,
)
.await
{
debug_info!("User is not allowed to see room {room_id}");
// This error will be caught later
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"User is not allowed to see the room",
));
}
Ok(SpaceHierarchyParentSummary {
@ -446,7 +471,12 @@ impl Service {
}
pub async fn get_client_hierarchy(
&self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<ShortRoomId>, max_depth: u64,
&self,
sender_user: &UserId,
room_id: &RoomId,
limit: usize,
short_room_ids: Vec<ShortRoomId>,
max_depth: u64,
suggested_only: bool,
) -> Result<client::space::get_hierarchy::v1::Response> {
let mut parents = VecDeque::new();
@ -454,27 +484,30 @@ impl Service {
// Don't start populating the results if we have to start at a specific room.
let mut populate_results = short_room_ids.is_empty();
let mut stack = vec![vec![(
room_id.to_owned(),
match room_id.server_name() {
Some(server_name) => vec![server_name.into()],
None => vec![],
},
)]];
let mut stack = vec![vec![(room_id.to_owned(), match room_id.server_name() {
| Some(server_name) => vec![server_name.into()],
| None => vec![],
})]];
let mut results = Vec::with_capacity(limit);
while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } {
while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) }
{
if results.len() >= limit {
break;
}
match (
self.get_summary_and_children_client(&current_room, suggested_only, sender_user, &via)
.await?,
self.get_summary_and_children_client(
&current_room,
suggested_only,
sender_user,
&via,
)
.await?,
current_room == room_id,
) {
(Some(SummaryAccessibility::Accessible(summary)), _) => {
| (Some(SummaryAccessibility::Accessible(summary)), _) => {
let mut children: Vec<(OwnedRoomId, Vec<OwnedServerName>)> =
get_parent_children_via(&summary, suggested_only)
.into_iter()
@ -493,7 +526,9 @@ impl Service {
self.services
.short
.get_shortroomid(room)
.map_ok(|short| Some(&short) != short_room_ids.get(parents.len()))
.map_ok(|short| {
Some(&short) != short_room_ids.get(parents.len())
})
.unwrap_or_else(|_| false)
})
.map(Clone::clone)
@ -525,14 +560,20 @@ impl Service {
// Root room in the space hierarchy, we return an error
// if this one fails.
},
(Some(SummaryAccessibility::Inaccessible), true) => {
return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible"));
| (Some(SummaryAccessibility::Inaccessible), true) => {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"The requested room is inaccessible",
));
},
(None, true) => {
return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found"));
| (None, true) => {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"The requested room was not found",
));
},
// Just ignore other unavailable rooms
(None | Some(SummaryAccessibility::Inaccessible), false) => (),
| (None | Some(SummaryAccessibility::Inaccessible), false) => (),
}
}
@ -544,15 +585,19 @@ impl Service {
let short_room_ids: Vec<_> = parents
.iter()
.stream()
.filter_map(|room_id| async move { self.services.short.get_shortroomid(room_id).await.ok() })
.filter_map(|room_id| async move {
self.services.short.get_shortroomid(room_id).await.ok()
})
.collect()
.await;
Some(
PaginationToken {
short_room_ids,
limit: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"),
max_depth: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"),
limit: UInt::new(max_depth)
.expect("When sent in request it must have been valid UInt"),
max_depth: UInt::new(max_depth)
.expect("When sent in request it must have been valid UInt"),
suggested_only,
}
.to_string(),
@ -566,9 +611,12 @@ impl Service {
/// Simply returns the stripped m.space.child events of a room
async fn get_stripped_space_child_events(
&self, room_id: &RoomId,
&self,
room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else {
let Ok(current_shortstatehash) =
self.services.state.get_room_shortstatehash(room_id).await
else {
return Ok(None);
};
@ -581,18 +629,17 @@ impl Service {
let mut children_pdus = Vec::with_capacity(state.len());
for (key, id) in state {
let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?;
let (event_type, state_key) =
self.services.short.get_statekey_from_short(key).await?;
if event_type != StateEventType::SpaceChild {
continue;
}
let pdu = self
.services
.timeline
.get_pdu(&id)
.await
.map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?;
let pdu =
self.services.timeline.get_pdu(&id).await.map_err(|e| {
err!(Database("Event {id:?} in space state not found: {e:?}"))
})?;
if let Ok(content) = pdu.get_content::<SpaceChildEventContent>() {
if content.via.is_empty() {
@ -610,11 +657,14 @@ impl Service {
/// With the given identifier, checks if a room is accessable
async fn is_accessible_child(
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
&self,
current_room: &OwnedRoomId,
join_rule: &SpaceRoomJoinRule,
identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
match identifier {
Identifier::ServerName(server_name) => {
| Identifier::ServerName(server_name) => {
// Checks if ACLs allow for the server to participate
if self
.services
@ -626,7 +676,7 @@ impl Service {
return false;
}
},
Identifier::UserId(user_id) => {
| Identifier::UserId(user_id) => {
if self
.services
.state_cache
@ -642,16 +692,18 @@ impl Service {
},
}
match &join_rule {
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
SpaceRoomJoinRule::Restricted => {
| SpaceRoomJoinRule::Public
| SpaceRoomJoinRule::Knock
| SpaceRoomJoinRule::KnockRestricted => true,
| SpaceRoomJoinRule::Restricted => {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
| Identifier::UserId(user) => {
if self.services.state_cache.is_joined(user, room).await {
return true;
}
},
Identifier::ServerName(server) => {
| Identifier::ServerName(server) => {
if self.services.state_cache.server_in_room(server, room).await {
return true;
}
@ -661,7 +713,7 @@ impl Service {
false
},
// Invite only, Private, or Custom join rule
_ => false,
| _ => false,
}
}
}
@ -737,7 +789,8 @@ fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRooms
/// Returns the children of a SpaceHierarchyParentSummary, making use of the
/// children_state field
fn get_parent_children_via(
parent: &SpaceHierarchyParentSummary, suggested_only: bool,
parent: &SpaceHierarchyParentSummary,
suggested_only: bool,
) -> Vec<(OwnedRoomId, Vec<OwnedServerName>)> {
parent
.children_state
@ -755,7 +808,8 @@ fn get_parent_children_via(
}
fn next_room_to_traverse(
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>,
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>,
parents: &mut VecDeque<OwnedRoomId>,
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> {
while stack.last().is_some_and(Vec::is_empty) {
stack.pop();

View file

@ -69,18 +69,15 @@ fn get_summary_children() {
}
.into();
assert_eq!(
get_parent_children_via(&summary, false),
vec![
(owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")])
]
);
assert_eq!(
get_parent_children_via(&summary, true),
vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])]
);
assert_eq!(get_parent_children_via(&summary, false), vec![
(owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]),
(owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")])
]);
assert_eq!(get_parent_children_via(&summary, true), vec![(
owned_room_id!("!bar:example.org"),
vec![owned_server_name!("example.org")]
)]);
}
#[test]

View file

@ -16,7 +16,9 @@ use conduwuit::{
warn, PduEvent, Result,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use futures::{
future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
};
use ruma::{
events::{
room::{create::RoomCreateEventContent, member::RoomMemberEventContent},
@ -70,8 +72,10 @@ impl crate::Service for Service {
short: args.depend::<rooms::short::Service>("rooms::short"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data {
@ -100,7 +104,8 @@ impl Service {
shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result {
let event_ids = statediffnew
.iter()
@ -120,8 +125,9 @@ impl Service {
};
match pdu.kind {
TimelineEventType::RoomMember => {
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok() else {
| TimelineEventType::RoomMember => {
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok()
else {
continue;
};
@ -131,10 +137,18 @@ impl Service {
self.services
.state_cache
.update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false)
.update_membership(
room_id,
&user_id,
membership_event,
&pdu.sender,
None,
None,
false,
)
.await?;
},
TimelineEventType::SpaceChild => {
| TimelineEventType::SpaceChild => {
self.services
.spaces
.roomid_spacehierarchy_cache
@ -142,7 +156,7 @@ impl Service {
.await
.remove(&pdu.room_id);
},
_ => continue,
| _ => continue,
}
}
@ -159,7 +173,10 @@ impl Service {
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")]
pub async fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
&self,
event_id: &EventId,
room_id: &RoomId,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<ShortStateHash> {
const KEY_LEN: usize = size_of::<ShortEventId>();
const VAL_LEN: usize = size_of::<ShortStateHash>();
@ -190,22 +207,23 @@ impl Service {
Vec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&state_ids_compressed)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(state_ids_compressed, Arc::new(HashSet::new()))
};
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(state_ids_compressed, Arc::new(HashSet::new()))
};
self.services.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
@ -338,7 +356,8 @@ impl Service {
&self,
room_id: &RoomId,
shortstatehash: u64,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) {
const BUFSIZE: usize = size_of::<u64>();
@ -366,7 +385,10 @@ impl Service {
.deserialized()
}
pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &EventId> + Send + '_ {
pub fn get_forward_extremities<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &EventId> + Send + '_ {
let prefix = (room_id, Interfix);
self.db
@ -380,7 +402,8 @@ impl Service {
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
_state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) {
let prefix = (room_id, Interfix);
self.db
@ -399,26 +422,33 @@ impl Service {
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self, content), level = "debug")]
pub async fn get_auth_events(
&self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>,
&self,
room_id: &RoomId,
kind: &TimelineEventType,
sender: &UserId,
state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
return Ok(HashMap::new());
};
let mut sauthevents: HashMap<_, _> = state_res::auth_types_for_event(kind, sender, state_key, content)?
.iter()
.stream()
.broad_filter_map(|(event_type, state_key)| {
self.services
.short
.get_shortstatekey(event_type, state_key)
.map_ok(move |ssk| (ssk, (event_type, state_key)))
.map(Result::ok)
})
.map(|(ssk, (event_type, state_key))| (ssk, (event_type.to_owned(), state_key.to_owned())))
.collect()
.await;
let mut sauthevents: HashMap<_, _> =
state_res::auth_types_for_event(kind, sender, state_key, content)?
.iter()
.stream()
.broad_filter_map(|(event_type, state_key)| {
self.services
.short
.get_shortstatekey(event_type, state_key)
.map_ok(move |ssk| (ssk, (event_type, state_key)))
.map(Result::ok)
})
.map(|(ssk, (event_type, state_key))| {
(ssk, (event_type.to_owned(), state_key.to_owned()))
})
.collect()
.await;
let (state_keys, event_ids): (Vec<_>, Vec<_>) = self
.services

View file

@ -39,14 +39,16 @@ impl Data {
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
state_compressor: args
.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}
}
pub(super) async fn state_full(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
let state = self
.state_full_pdus(shortstatehash)
@ -58,7 +60,10 @@ impl Data {
Ok(state)
}
pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> {
pub(super) async fn state_full_pdus(
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<PduEvent>> {
let short_ids = self.state_full_shortids(shortstatehash).await?;
let full_pdus = self
@ -66,16 +71,19 @@ impl Data {
.short
.multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1)))
.ready_filter_map(Result::ok)
.broad_filter_map(
|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() },
)
.broad_filter_map(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await.ok()
})
.collect()
.await;
Ok(full_pdus)
}
pub(super) async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
pub(super) async fn state_full_ids<Id>(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Id>>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
@ -96,7 +104,8 @@ impl Data {
}
pub(super) async fn state_full_shortids(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
let shortids = self
.services
@ -118,7 +127,10 @@ impl Data {
/// Returns a single EventId from `room_id` with key
/// (`event_type`,`state_key`).
pub(super) async fn state_get_id<Id>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -155,10 +167,15 @@ impl Data {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn state_get(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await })
.and_then(|event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await
})
.await
}
@ -179,7 +196,8 @@ impl Data {
/// Returns the full room state.
pub(super) async fn room_state_full(
&self, room_id: &RoomId,
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.services
.state
@ -203,7 +221,10 @@ impl Data {
/// Returns a single EventId from `room_id` with key
/// (`event_type`,`state_key`).
pub(super) async fn room_state_get_id<Id>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -218,7 +239,10 @@ impl Data {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.services
.state

View file

@ -34,8 +34,8 @@ use ruma::{
},
room::RoomType,
space::SpaceRoomJoinRule,
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
ServerName, UserId,
EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName,
OwnedUserId, RoomId, ServerName, UserId,
};
use serde::Deserialize;
@ -75,8 +75,12 @@ impl crate::Service for Service {
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
server_visibility_cache_capacity,
)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
user_visibility_cache_capacity,
)?)),
}))
}
@ -102,7 +106,10 @@ impl Service {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
pub async fn state_full_ids<Id>(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Id>>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
@ -112,13 +119,15 @@ impl Service {
#[inline]
pub async fn state_full_shortids(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
self.db.state_full_shortids(shortstatehash).await
}
pub async fn state_full(
&self, shortstatehash: ShortStateHash,
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.state_full(shortstatehash).await
}
@ -127,7 +136,10 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_id<Id>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -142,7 +154,10 @@ impl Service {
/// `state_key`).
#[inline]
pub async fn state_get(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.db
.state_get(shortstatehash, event_type, state_key)
@ -151,7 +166,10 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn state_get_content<T>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
&self,
shortstatehash: ShortStateHash,
event_type: &StateEventType,
state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
@ -162,7 +180,11 @@ impl Service {
}
/// Get membership for given user in state
async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState {
async fn user_membership(
&self,
shortstatehash: ShortStateHash,
user_id: &UserId,
) -> MembershipState {
self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str())
.await
.map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership)
@ -185,7 +207,12 @@ impl Service {
/// Whether a server is allowed to see an event through federation, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip_all, level = "trace")]
pub async fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> bool {
pub async fn server_can_see_event(
&self,
origin: &ServerName,
room_id: &RoomId,
event_id: &EventId,
) -> bool {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return true;
};
@ -213,20 +240,20 @@ impl Service {
.ready_filter(|member| member.server_name() == origin);
let visibility = match history_visibility {
HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
HistoryVisibility::Invited => {
| HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
| HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
current_server_members
.any(|member| self.user_was_invited(shortstatehash, member))
.await
},
HistoryVisibility::Joined => {
| HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
current_server_members
.any(|member| self.user_was_joined(shortstatehash, member))
.await
},
_ => {
| _ => {
error!("Unknown history visibility {history_visibility}");
false
},
@ -243,7 +270,12 @@ impl Service {
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip_all, level = "trace")]
pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool {
pub async fn user_can_see_event(
&self,
user_id: &UserId,
room_id: &RoomId,
event_id: &EventId,
) -> bool {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return true;
};
@ -267,17 +299,17 @@ impl Service {
});
let visibility = match history_visibility {
HistoryVisibility::WorldReadable => true,
HistoryVisibility::Shared => currently_member,
HistoryVisibility::Invited => {
| HistoryVisibility::WorldReadable => true,
| HistoryVisibility::Shared => currently_member,
| HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
self.user_was_invited(shortstatehash, user_id).await
},
HistoryVisibility::Joined => {
| HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
self.user_was_joined(shortstatehash, user_id).await
},
_ => {
| _ => {
error!("Unknown history visibility {history_visibility}");
false
},
@ -307,9 +339,10 @@ impl Service {
});
match history_visibility {
HistoryVisibility::Invited => self.services.state_cache.is_invited(user_id, room_id).await,
HistoryVisibility::WorldReadable => true,
_ => false,
| HistoryVisibility::Invited =>
self.services.state_cache.is_invited(user_id, room_id).await,
| HistoryVisibility::WorldReadable => true,
| _ => false,
}
}
@ -320,7 +353,10 @@ impl Service {
/// Returns the full room state.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), PduEvent>> {
pub async fn room_state_full(
&self,
room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.room_state_full(room_id).await
}
@ -334,7 +370,10 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_get_id<Id>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Sized + ToOwned,
@ -349,14 +388,20 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<PduEvent> {
self.db.room_state_get(room_id, event_type, state_key).await
}
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn room_state_get_content<T>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
&self,
room_id: &RoomId,
event_type: &StateEventType,
state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
@ -381,18 +426,29 @@ impl Service {
JsOption::from_option(content)
}
pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<RoomMemberEventContent> {
pub async fn get_member(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<RoomMemberEventContent> {
self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
}
pub async fn user_can_invite(
&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard,
&self,
room_id: &RoomId,
sender: &UserId,
target_user: &UserId,
state_lock: &RoomMutexGuard,
) -> bool {
self.services
.timeline
.create_hash_and_sign_event(
PduBuilder::state(target_user.into(), &RoomMemberEventContent::new(MembershipState::Invite)),
PduBuilder::state(
target_user.into(),
&RoomMemberEventContent::new(MembershipState::Invite),
),
sender,
room_id,
state_lock,
@ -405,7 +461,9 @@ impl Service {
pub async fn is_world_readable(&self, room_id: &RoomId) -> bool {
self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "")
.await
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable)
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable
})
.unwrap_or(false)
}
@ -439,7 +497,11 @@ impl Service {
/// If federation is true, it allows redaction events from any user of the
/// same server as the original event sender
pub async fn user_can_redact(
&self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool,
&self,
redacts: &EventId,
sender: &UserId,
room_id: &RoomId,
federation: bool,
) -> Result<bool> {
let redacting_event = self.services.timeline.get_pdu(redacts).await;
@ -451,7 +513,11 @@ impl Service {
}
if let Ok(pl_event_content) = self
.room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "")
.room_state_get_content::<RoomPowerLevelsEventContent>(
room_id,
&StateEventType::RoomPowerLevels,
"",
)
.await
{
let pl_event: RoomPowerLevels = pl_event_content.into();
@ -485,10 +551,15 @@ impl Service {
}
/// Returns the join rule (`SpaceRoomJoinRule`) for a given room
pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
pub async fn get_join_rule(
&self,
room_id: &RoomId,
) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
.await
.map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)))
.map(|c: RoomJoinRulesEventContent| {
(c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))
})
.or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![])))
}
@ -497,10 +568,7 @@ impl Service {
let mut room_ids = Vec::with_capacity(1);
if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule {
for rule in r.allow {
if let AllowRule::RoomMembership(RoomMembership {
room_id: membership,
}) = rule
{
if let AllowRule::RoomMembership(RoomMembership { room_id: membership }) = rule {
room_ids.push(membership.clone());
}
}
@ -520,7 +588,10 @@ impl Service {
/// Gets the room's encryption algorithm if `m.room.encryption` state event
/// is found
pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result<EventEncryptionAlgorithm> {
pub async fn get_room_encryption(
&self,
room_id: &RoomId,
) -> Result<EventEncryptionAlgorithm> {
self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "")
.await
.map(|content: RoomEncryptionEventContent| content.algorithm)

View file

@ -20,7 +20,8 @@ use ruma::{
member::{MembershipState, RoomMemberEventContent},
power_levels::RoomPowerLevelsEventContent,
},
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType, StateEventType,
},
int,
serde::Raw,
@ -68,7 +69,8 @@ impl crate::Service for Service {
services: Services {
account_data: args.depend::<account_data::Service>("account_data"),
globals: args.depend::<globals::Service>("globals"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
users: args.depend::<users::Service>("users"),
},
db: Data {
@ -96,8 +98,13 @@ impl Service {
#[tracing::instrument(skip(self, last_state))]
#[allow(clippy::too_many_arguments)]
pub async fn update_membership(
&self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, invite_via: Option<Vec<OwnedServerName>>,
&self,
room_id: &RoomId,
user_id: &UserId,
membership_event: RoomMemberEventContent,
sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool,
) -> Result<()> {
let membership = membership_event.membership;
@ -138,7 +145,7 @@ impl Service {
}
match &membership {
MembershipState::Join => {
| MembershipState::Join => {
// Check if the user never joined this room
if !self.once_joined(user_id, room_id).await {
// Add the user ID to the join list then
@ -181,12 +188,21 @@ impl Service {
if let Ok(tag_event) = self
.services
.account_data
.get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag)
.get_room(
&predecessor.room_id,
user_id,
RoomAccountDataEventType::Tag,
)
.await
{
self.services
.account_data
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event)
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&tag_event,
)
.await
.ok();
};
@ -195,7 +211,10 @@ impl Service {
if let Ok(mut direct_event) = self
.services
.account_data
.get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
.get_global::<DirectEvent>(
user_id,
GlobalAccountDataEventType::Direct,
)
.await
{
let mut room_ids_updated = false;
@ -213,7 +232,8 @@ impl Service {
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event).expect("to json always works"),
&serde_json::to_value(&direct_event)
.expect("to json always works"),
)
.await?;
}
@ -223,7 +243,7 @@ impl Service {
self.mark_as_joined(user_id, room_id);
},
MembershipState::Invite => {
| MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
if self.services.users.user_is_ignored(sender, user_id).await {
return Ok(());
@ -232,10 +252,10 @@ impl Service {
self.mark_as_invited(user_id, room_id, last_state, invite_via)
.await;
},
MembershipState::Leave | MembershipState::Ban => {
| MembershipState::Leave | MembershipState::Ban => {
self.mark_as_left(user_id, room_id);
},
_ => {},
| _ => {},
}
if update_joined_count {
@ -246,7 +266,11 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool {
pub async fn appservice_in_room(
&self,
room_id: &RoomId,
appservice: &RegistrationInfo,
) -> bool {
if let Some(cached) = self
.appservice_in_room_cache
.read()
@ -347,7 +371,10 @@ impl Service {
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
pub fn room_servers<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &ServerName> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomserverids
@ -357,7 +384,11 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool {
pub async fn server_in_room<'a>(
&'a self,
server: &'a ServerName,
room_id: &'a RoomId,
) -> bool {
let key = (server, room_id);
self.db.serverroomids.qry(&key).await.is_ok()
}
@ -365,7 +396,10 @@ impl Service {
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream<Item = &RoomId> + Send + 'a {
pub fn server_rooms<'a>(
&'a self,
server: &'a ServerName,
) -> impl Stream<Item = &RoomId> + Send + 'a {
let prefix = (server, Interfix);
self.db
.serverroomids
@ -393,7 +427,9 @@ impl Service {
/// List the rooms common between two users
pub fn get_shared_rooms<'a>(
&'a self, user_a: &'a UserId, user_b: &'a UserId,
&'a self,
user_a: &'a UserId,
user_b: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
use conduwuit::utils::set;
@ -404,7 +440,10 @@ impl Service {
/// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn room_members<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_joined
@ -422,7 +461,10 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests
pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
self.room_members(room_id)
.ready_filter(|user| self.services.globals.user_is_local(user))
}
@ -430,7 +472,10 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn active_local_users_in_room<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
self.local_users_in_room(room_id)
.filter(|user| self.services.users.is_active(user))
}
@ -447,7 +492,10 @@ impl Service {
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn room_useroncejoined<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuseroncejoinedids
@ -458,7 +506,10 @@ impl Service {
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
pub fn room_members_invited<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_invitecount
@ -485,7 +536,10 @@ impl Service {
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &RoomId> + Send + 'a {
pub fn rooms_joined<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
self.db
.userroomid_joined
.keys_raw_prefix(user_id)
@ -495,7 +549,10 @@ impl Service {
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_invited<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
pub fn rooms_invited<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>);
type Key<'a> = (&'a UserId, &'a RoomId);
@ -510,30 +567,45 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
pub async fn invite_state(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.db
.userroomid_invitestate
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
val.deserialize_as().map_err(Into::into)
})
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
pub async fn left_state(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.db
.userroomid_leftstate
.qry(&key)
.await
.deserialized()
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into))
.and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| {
val.deserialize_as().map_err(Into::into)
})
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
pub fn rooms_left<'a>(
&'a self,
user_id: &'a UserId,
) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
type KeyVal<'a> = (Key<'a>, Raw<Vec<Raw<AnySyncStateEvent>>>);
type Key<'a> = (&'a UserId, &'a RoomId);
@ -571,7 +643,11 @@ impl Service {
self.db.userroomid_leftstate.qry(&key).await.is_ok()
}
pub async fn user_membership(&self, user_id: &UserId, room_id: &RoomId) -> Option<MembershipState> {
pub async fn user_membership(
&self,
user_id: &UserId,
room_id: &RoomId,
) -> Option<MembershipState> {
let states = join4(
self.is_joined(user_id, room_id),
self.is_left(user_id, room_id),
@ -581,16 +657,19 @@ impl Service {
.await;
match states {
(true, ..) => Some(MembershipState::Join),
(_, true, ..) => Some(MembershipState::Leave),
(_, _, true, ..) => Some(MembershipState::Invite),
(false, false, false, true) => Some(MembershipState::Ban),
_ => None,
| (true, ..) => Some(MembershipState::Join),
| (_, true, ..) => Some(MembershipState::Leave),
| (_, _, true, ..) => Some(MembershipState::Invite),
| (false, false, false, true) => Some(MembershipState::Ban),
| _ => None,
}
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
pub fn servers_invite_via<'a>(
&'a self,
room_id: &'a RoomId,
) -> impl Stream<Item = &ServerName> + Send + 'a {
type KeyVal<'a> = (Ignore, Vec<&'a ServerName>);
self.db
@ -711,7 +790,10 @@ impl Service {
}
pub async fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
&self,
user_id: &UserId,
room_id: &RoomId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) {
let roomuser_id = (room_id, user_id);

View file

@ -69,7 +69,8 @@ pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
let cache_capacity =
f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self {
stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(),
db: Data {
@ -85,17 +86,16 @@ impl crate::Service for Service {
fn memory_usage(&self, out: &mut dyn Write) -> Result {
let (cache_len, ents) = {
let cache = self.stateinfo_cache.lock().expect("locked");
let ents = cache
.iter()
.map(at!(1))
.flat_map(|vec| vec.iter())
.fold(HashMap::new(), |mut ents, ssi| {
let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
HashMap::new(),
|mut ents, ssi| {
for cs in &[&ssi.added, &ssi.removed, &ssi.full_state] {
ents.insert(Arc::as_ptr(cs), compressed_state_size(cs));
}
ents
});
},
);
(cache.len(), ents)
};
@ -117,7 +117,10 @@ impl crate::Service for Service {
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
pub async fn load_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> {
pub async fn load_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
if let Some(r) = self
.stateinfo_cache
.lock()
@ -143,12 +146,11 @@ impl Service {
Ok(stack)
}
async fn new_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> {
let StateDiff {
parent,
added,
removed,
} = self.get_statediff(shortstatehash).await?;
async fn new_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
let StateDiff { parent, added, removed } = self.get_statediff(shortstatehash).await?;
let Some(parent) = parent else {
return Ok(vec![ShortStateInfo {
@ -180,9 +182,17 @@ impl Service {
Ok(stack)
}
pub fn compress_state_events<'a, I>(&'a self, state: I) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
pub fn compress_state_events<'a, I>(
&'a self,
state: I,
) -> impl Stream<Item = CompressedStateEvent> + Send + 'a
where
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)> + Clone + Debug + ExactSizeIterator + Send + 'a,
I: Iterator<Item = (&'a ShortStateKey, &'a EventId)>
+ Clone
+ Debug
+ ExactSizeIterator
+ Send
+ 'a,
{
let event_ids = state.clone().map(at!(1));
@ -195,10 +205,16 @@ impl Service {
.stream()
.map(at!(0))
.zip(short_event_ids)
.map(|(shortstatekey, shorteventid)| compress_state_event(*shortstatekey, shorteventid))
.map(|(shortstatekey, shorteventid)| {
compress_state_event(*shortstatekey, shorteventid)
})
}
pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent {
pub async fn compress_state_event(
&self,
shortstatekey: ShortStateKey,
event_id: &EventId,
) -> CompressedStateEvent {
let shorteventid = self
.services
.short
@ -227,8 +243,11 @@ impl Service {
/// * `parent_states` - A stack with info on shortstatehash, full state,
/// added diff and removed diff for each parent layer
pub fn save_state_from_diff(
&self, shortstatehash: ShortStateHash, statediffnew: Arc<HashSet<CompressedStateEvent>>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>, diff_to_sibling: usize,
&self,
shortstatehash: ShortStateHash,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>,
diff_to_sibling: usize,
mut parent_states: ParentStatesVec,
) -> Result {
let statediffnew_len = statediffnew.len();
@ -274,14 +293,11 @@ impl Service {
if parent_states.is_empty() {
// There is no parent layer, create a new state
self.save_statediff(
shortstatehash,
&StateDiff {
parent: None,
added: statediffnew,
removed: statediffremoved,
},
);
self.save_statediff(shortstatehash, &StateDiff {
parent: None,
added: statediffnew,
removed: statediffremoved,
});
return Ok(());
};
@ -327,14 +343,11 @@ impl Service {
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
self.save_statediff(
shortstatehash,
&StateDiff {
parent: Some(parent.shortstatehash),
added: statediffnew,
removed: statediffremoved,
},
);
self.save_statediff(shortstatehash, &StateDiff {
parent: Some(parent.shortstatehash),
added: statediffnew,
removed: statediffremoved,
});
}
Ok(())
@ -344,7 +357,9 @@ impl Service {
/// room state
#[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")]
pub async fn save_state(
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
&self,
room_id: &RoomId,
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
@ -353,7 +368,8 @@ impl Service {
.await
.ok();
let state_hash = utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
let state_hash =
utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
let (new_shortstatehash, already_existed) = self
.services
@ -374,22 +390,23 @@ impl Service {
ShortStateInfoVec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed
.difference(&parent_stateinfo.full_state)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
.collect();
let statediffremoved: HashSet<_> = parent_stateinfo
.full_state
.difference(&new_state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(HashSet::new()))
};
(Arc::new(statediffnew), Arc::new(statediffremoved))
} else {
(new_state_ids_compressed, Arc::new(HashSet::new()))
};
if !already_existed {
self.save_state_from_diff(
@ -418,7 +435,9 @@ impl Service {
.shortstatehash_statediff
.aqry::<BUFSIZE, _>(&shortstatehash)
.await
.map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?;
.map_err(|e| {
err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}"))
})?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
.ok()
@ -484,7 +503,10 @@ impl Service {
#[inline]
#[must_use]
fn compress_state_event(shortstatekey: ShortStateKey, shorteventid: ShortEventId) -> CompressedStateEvent {
fn compress_state_event(
shortstatekey: ShortStateKey,
shorteventid: ShortEventId,
) -> CompressedStateEvent {
const SIZE: usize = size_of::<CompressedStateEvent>();
let mut v = ArrayVec::<u8, SIZE>::new();
@ -497,7 +519,9 @@ fn compress_state_event(shortstatekey: ShortStateKey, shorteventid: ShortEventId
#[inline]
#[must_use]
pub fn parse_compressed_state_event(compressed_event: CompressedStateEvent) -> (ShortStateKey, ShortEventId) {
pub fn parse_compressed_state_event(
compressed_event: CompressedStateEvent,
) -> (ShortStateKey, ShortEventId) {
use utils::u64_from_u8;
let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<ShortStateKey>()]);

View file

@ -11,8 +11,8 @@ use conduwuit::{
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue,
EventId, OwnedUserId, RoomId, UserId,
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint,
CanonicalJsonValue, EventId, OwnedUserId, RoomId, UserId,
};
use serde_json::json;
@ -55,7 +55,9 @@ impl Service {
.timeline
.get_pdu_id(root_event_id)
.await
.map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?;
.map_err(|e| {
err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}")))
})?;
let root_pdu = self
.services
@ -79,8 +81,9 @@ impl Service {
.get("m.relations")
.and_then(|r| r.as_object())
.and_then(|r| r.get("m.thread"))
.and_then(|relations| serde_json::from_value::<BundledThread>(relations.clone().into()).ok())
{
.and_then(|relations| {
serde_json::from_value::<BundledThread>(relations.clone().into()).ok()
}) {
// Thread already existed
relations.count = relations.count.saturating_add(uint!(1));
relations.latest_event = pdu.to_message_like_event();
@ -129,7 +132,11 @@ impl Service {
}
pub async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, shorteventid: PduCount, _inc: &'a IncludeThreads,
&'a self,
user_id: &'a UserId,
room_id: &'a RoomId,
shorteventid: PduCount,
_inc: &'a IncludeThreads,
) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
@ -160,7 +167,11 @@ impl Service {
Ok(stream)
}
pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result {
pub(super) fn update_participants(
&self,
root_id: &RawPduId,
participants: &[OwnedUserId],
) -> Result {
let users = participants
.iter()
.map(|user| user.as_bytes())

View file

@ -13,7 +13,9 @@ use conduwuit::{
};
use database::{Database, Deserialized, Json, KeyVal, Map};
use futures::{future::select_ok, FutureExt, Stream, StreamExt};
use ruma::{api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use ruma::{
api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
};
use tokio::sync::Mutex;
use super::{PduId, RawPduId};
@ -54,15 +56,19 @@ impl Data {
}
}
pub(super) async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> {
pub(super) async fn last_timeline_count(
&self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.await
.entry(room_id.into())
{
hash_map::Entry::Occupied(o) => Ok(*o.get()),
hash_map::Entry::Vacant(v) => Ok(self
| hash_map::Entry::Occupied(o) => Ok(*o.get()),
| hash_map::Entry::Vacant(v) => Ok(self
.pdus_rev(sender_user, room_id, PduCount::max())
.await?
.next()
@ -93,7 +99,10 @@ impl Data {
}
/// Returns the json of a pdu.
pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
pub(super) async fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<CanonicalJsonObject> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.get(&pduid).await.deserialized()
@ -160,12 +169,19 @@ impl Data {
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
pub(super) async fn get_pdu_json_from_id(
&self,
pdu_id: &RawPduId,
) -> Result<CanonicalJsonObject> {
self.pduid_pdu.get(pdu_id).await.deserialized()
}
pub(super) async fn append_pdu(
&self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount,
&self,
pdu_id: &RawPduId,
pdu: &PduEvent,
json: &CanonicalJsonObject,
count: PduCount,
) {
debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal");
@ -179,7 +195,12 @@ impl Data {
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
}
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) {
pub(super) fn prepend_backfill_pdu(
&self,
pdu_id: &RawPduId,
event_id: &EventId,
json: &CanonicalJsonObject,
) {
self.pduid_pdu.raw_put(pdu_id, Json(json));
self.eventid_pduid.insert(event_id, pdu_id);
self.eventid_outlierpdu.remove(event_id);
@ -187,7 +208,10 @@ impl Data {
/// Removes a pdu and creates a new one with the same id.
pub(super) async fn replace_pdu(
&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent,
&self,
pdu_id: &RawPduId,
pdu_json: &CanonicalJsonObject,
_pdu: &PduEvent,
) -> Result {
if self.pduid_pdu.get(pdu_id).await.is_not_found() {
return Err!(Request(NotFound("PDU does not exist.")));
@ -202,7 +226,10 @@ impl Data {
/// happened before the event with id `until` in reverse-chronological
/// order.
pub(super) async fn pdus_rev<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
until: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let current = self
.count_to_id(room_id, until, Direction::Backward)
@ -219,7 +246,10 @@ impl Data {
}
pub(super) async fn pdus<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
from: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + Unpin + 'a> {
let current = self.count_to_id(room_id, from, Direction::Forward).await?;
let prefix = current.shortroomid();
@ -236,8 +266,8 @@ impl Data {
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> PdusIterItem {
let pdu_id: RawPduId = pdu_id.into();
let mut pdu =
serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON");
let mut pdu = serde_json::from_slice::<PduEvent>(pdu)
.expect("PduEvent in pduid_pdu database column is invalid JSON");
if Some(pdu.sender.borrow()) != user_id {
pdu.remove_transaction_id().log_err().ok();
@ -249,7 +279,10 @@ impl Data {
}
pub(super) fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
&self,
room_id: &RoomId,
notifies: Vec<OwnedUserId>,
highlights: Vec<OwnedUserId>,
) {
let _cork = self.db.cork();
@ -268,7 +301,12 @@ impl Data {
}
}
async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount, dir: Direction) -> Result<RawPduId> {
async fn count_to_id(
&self,
room_id: &RoomId,
shorteventid: PduCount,
dir: Direction,
) -> Result<RawPduId> {
let shortroomid: ShortRoomId = self
.services
.short

View file

@ -15,7 +15,9 @@ use conduwuit::{
validated, warn, Err, Error, Result, Server,
};
pub use conduwuit::{PduId, RawPduId};
use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use futures::{
future, future::ready, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
};
use ruma::{
api::federation,
canonical_json::to_canonical_value,
@ -32,8 +34,8 @@ use ruma::{
},
push::{Action, Ruleset, Tweak},
state_res::{self, Event, RoomVersion},
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName,
RoomId, RoomVersionId, ServerName, UserId,
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId,
OwnedServerName, RoomId, RoomVersionId, ServerName, UserId,
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
@ -116,7 +118,8 @@ impl crate::Service for Service {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_accessor: args
.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
sending: args.depend::<sending::Service>("sending"),
@ -127,7 +130,8 @@ impl crate::Service for Service {
threads: args.depend::<rooms::threads::Service>("rooms::threads"),
search: args.depend::<rooms::search::Service>("rooms::search"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
event_handler: args
.depend::<rooms::event_handler::Service>("rooms::event_handler"),
},
db: Data::new(&args),
mutex_insert: RoomMutexMap::new(),
@ -185,12 +189,18 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> {
pub async fn last_timeline_count(
&self,
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
self.db.last_timeline_count(sender_user, room_id).await
}
/// Returns the `count` of this pdu's id.
pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> { self.db.get_pdu_count(event_id).await }
pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
self.db.get_pdu_count(event_id).await
}
// TODO Is this the same as the function above?
/*
@ -222,13 +232,18 @@ impl Service {
/// Returns the json of a pdu.
#[inline]
pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
pub async fn get_non_outlier_pdu_json(
&self,
event_id: &EventId,
) -> Result<CanonicalJsonObject> {
self.db.get_non_outlier_pdu_json(event_id).await
}
/// Returns the pdu's id.
#[inline]
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> { self.db.get_pdu_id(event_id).await }
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> {
self.db.get_pdu_id(event_id).await
}
/// Returns the pdu.
///
@ -241,19 +256,26 @@ impl Service {
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> { self.db.get_pdu(event_id).await }
pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
self.db.get_pdu(event_id).await
}
/// Checks if pdu exists
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future<Output = bool> + Send + 'a {
pub fn pdu_exists<'a>(
&'a self,
event_id: &'a EventId,
) -> impl Future<Output = bool> + Send + 'a {
self.db.pdu_exists(event_id)
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> { self.db.get_pdu_from_id(pdu_id).await }
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> {
self.db.get_pdu_from_id(pdu_id).await
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
@ -262,7 +284,12 @@ impl Service {
/// Removes a pdu and creates a new one with the same id.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
pub async fn replace_pdu(
&self,
pdu_id: &RawPduId,
pdu_json: &CanonicalJsonObject,
pdu: &PduEvent,
) -> Result<()> {
self.db.replace_pdu(pdu_id, pdu_json, pdu).await
}
@ -278,7 +305,8 @@ impl Service {
pdu: &PduEvent,
mut pdu_json: CanonicalJsonObject,
leaves: Vec<OwnedEventId>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result<RawPduId> {
// Coalesce database writes for the remainder of this scope.
let _cork = self.db.db.cork_and_flush();
@ -313,10 +341,16 @@ impl Service {
unsigned.insert(
"prev_content".to_owned(),
CanonicalJsonValue::Object(
utils::to_canonical_object(prev_state.content.clone()).map_err(|e| {
error!("Failed to convert prev_state to canonical JSON: {e}");
Error::bad_database("Failed to convert prev_state to canonical JSON.")
})?,
utils::to_canonical_object(prev_state.content.clone()).map_err(
|e| {
error!(
"Failed to convert prev_state to canonical JSON: {e}"
);
Error::bad_database(
"Failed to convert prev_state to canonical JSON.",
)
},
)?,
),
);
unsigned.insert(
@ -357,11 +391,7 @@ impl Service {
.reset_notification_counts(&pdu.sender, &pdu.room_id);
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap());
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid: count2,
}
.into();
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
// Insert pdu
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await;
@ -408,7 +438,10 @@ impl Service {
.account_data
.get_global(user, GlobalAccountDataEventType::PushRules)
.await
.map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global);
.map_or_else(
|_| Ruleset::server_default(user),
|ev: PushRulesEvent| ev.content.global,
);
let mut highlight = false;
let mut notify = false;
@ -420,11 +453,11 @@ impl Service {
.await
{
match action {
Action::Notify => notify = true,
Action::SetTweak(Tweak::Highlight(true)) => {
| Action::Notify => notify = true,
| Action::SetTweak(Tweak::Highlight(true)) => {
highlight = true;
},
_ => {},
| _ => {},
};
// Break early if both conditions are true
@ -457,12 +490,12 @@ impl Service {
.increment_notification_counts(&pdu.room_id, notifies, highlights);
match pdu.kind {
TimelineEventType::RoomRedaction => {
| TimelineEventType::RoomRedaction => {
use RoomVersionId::*;
let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?;
match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if self
.services
@ -474,7 +507,7 @@ impl Service {
}
}
},
_ => {
| _ => {
let content: RoomRedactionEventContent = pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
if self
@ -489,7 +522,7 @@ impl Service {
},
};
},
TimelineEventType::SpaceChild => {
| TimelineEventType::SpaceChild =>
if let Some(_state_key) = &pdu.state_key {
self.services
.spaces
@ -497,18 +530,18 @@ impl Service {
.lock()
.await
.remove(&pdu.room_id);
}
},
TimelineEventType::RoomMember => {
},
| TimelineEventType::RoomMember => {
if let Some(state_key) = &pdu.state_key {
// if the state_key fails
let target_user_id =
UserId::parse(state_key.clone()).expect("This state_key was previously validated");
let target_user_id = UserId::parse(state_key.clone())
.expect("This state_key was previously validated");
let content: RoomMemberEventContent = pdu.get_content()?;
let invite_state = match content.membership {
MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(),
_ => None,
| MembershipState::Invite =>
self.services.state.summary_stripped(pdu).await.into(),
| _ => None,
};
// Update our membership info, we do this here incase a user is invited
@ -527,7 +560,7 @@ impl Service {
.await?;
}
},
TimelineEventType::RoomMessage => {
| TimelineEventType::RoomMessage => {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
@ -539,7 +572,7 @@ impl Service {
}
}
},
_ => {},
| _ => {},
}
if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() {
@ -552,24 +585,23 @@ impl Service {
if let Ok(content) = pdu.get_content::<ExtractRelatesTo>() {
match content.relates_to {
Relation::Reply {
in_reply_to,
} => {
| Relation::Reply { in_reply_to } => {
// We need to do it again here, because replies don't have
// event_id as a top level field
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await {
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await
{
self.services
.pdu_metadata
.add_relation(count2, related_pducount);
}
},
Relation::Thread(thread) => {
| Relation::Thread(thread) => {
self.services
.threads
.add_to_thread(&thread.event_id, pdu)
.await?;
},
_ => {}, // TODO: Aggregate other types
| _ => {}, // TODO: Aggregate other types
}
}
@ -637,7 +669,8 @@ impl Service {
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
_mutex_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
* state mutex */
) -> Result<(PduEvent, CanonicalJsonObject)> {
let PduBuilder {
event_type,
@ -707,7 +740,8 @@ impl Service {
unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value());
unsigned.insert(
"prev_sender".to_owned(),
serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"),
serde_json::to_value(&prev_pdu.sender)
.expect("UserId::to_value always works"),
);
unsigned.insert(
"replaces_state".to_owned(),
@ -744,9 +778,7 @@ impl Service {
} else {
Some(to_raw_value(&unsigned).expect("to_raw_value always works"))
},
hashes: EventHash {
sha256: "aaa".to_owned(),
},
hashes: EventHash { sha256: "aaa".to_owned() },
signatures: None,
};
@ -769,13 +801,14 @@ impl Service {
}
// Hash and sign
let mut pdu_json = utils::to_canonical_object(&pdu)
.map_err(|e| err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}")))))?;
let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| {
err!(Request(BadJson(warn!("Failed to convert PDU to canonical JSON: {e}"))))
})?;
// room v3 and above removed the "event_id" field from remote PDU format
match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => {
| RoomVersionId::V1 | RoomVersionId::V2 => {},
| _ => {
pdu_json.remove("event_id");
},
};
@ -783,7 +816,8 @@ impl Service {
// Add origin because synapse likes that (and it's required in the spec)
pdu_json.insert(
"origin".to_owned(),
to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
to_canonical_value(self.services.globals.server_name())
.expect("server name is a valid CanonicalJsonValue"),
);
if let Err(e) = self
@ -792,17 +826,18 @@ impl Service {
.hash_and_sign_event(&mut pdu_json, &room_version_id)
{
return match e {
Error::Signatures(ruma::signatures::Error::PduSize) => {
| Error::Signatures(ruma::signatures::Error::PduSize) => {
Err!(Request(TooLarge("Message/PDU is too long (exceeds 65535 bytes)")))
},
_ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
| _ => Err!(Request(Unknown(warn!("Signing event failed: {e}")))),
};
}
// Generate event id
pdu.event_id = EventId::parse_arc(format!(
"${}",
ruma::signatures::reference_hash(&pdu_json, &room_version_id).expect("ruma can calculate reference hashes")
ruma::signatures::reference_hash(&pdu_json, &room_version_id)
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are valid event ids");
@ -830,7 +865,8 @@ impl Service {
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &RoomId,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result<Arc<EventId>> {
let (pdu, pdu_json) = self
.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
@ -844,7 +880,7 @@ impl Service {
if pdu.kind == TimelineEventType::RoomRedaction {
use RoomVersionId::*;
match self.services.state.get_room_version(&pdu.room_id).await? {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if !self
.services
@ -856,7 +892,7 @@ impl Service {
}
};
},
_ => {
| _ => {
let content: RoomRedactionEventContent = pdu.get_content()?;
if let Some(redact_id) = &content.redacts {
if !self
@ -937,7 +973,8 @@ impl Service {
new_room_leaves: Vec<OwnedEventId>,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
soft_fail: bool,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
* mutex */
) -> Result<Option<RawPduId>> {
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
@ -971,7 +1008,9 @@ impl Service {
/// items.
#[inline]
pub fn all_pdus<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId,
&'a self,
user_id: &'a UserId,
room_id: &'a RoomId,
) -> impl Stream<Item = PdusIterItem> + Send + Unpin + 'a {
self.pdus(Some(user_id), room_id, None)
.map_ok(|stream| stream.map(Ok))
@ -983,7 +1022,10 @@ impl Service {
/// Reverse iteration starting at from.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn pdus_rev<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: Option<PduCount>,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
until: Option<PduCount>,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
self.db
.pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max))
@ -993,7 +1035,10 @@ impl Service {
/// Forward iteration starting at from.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn pdus<'a>(
&'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: Option<PduCount>,
&'a self,
user_id: Option<&'a UserId>,
room_id: &'a RoomId,
from: Option<PduCount>,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
self.db
.pdus(user_id, room_id, from.unwrap_or_else(PduCount::min))
@ -1002,17 +1047,21 @@ impl Service {
/// Replace a PDU with the redacted form.
#[tracing::instrument(skip(self, reason))]
pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result {
pub async fn redact_pdu(
&self,
event_id: &EventId,
reason: &PduEvent,
shortroomid: ShortRoomId,
) -> Result {
// TODO: Don't reserialize, keep original json
let Ok(pdu_id) = self.get_pdu_id(event_id).await else {
// If event does not exist, just noop
return Ok(());
};
let mut pdu = self
.get_pdu_from_id(&pdu_id)
.await
.map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?;
let mut pdu = self.get_pdu_from_id(&pdu_id).await.map_err(|e| {
err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU.")))
})?;
if let Ok(content) = pdu.get_content::<ExtractBody>() {
if let Some(body) = content.body {
@ -1026,8 +1075,9 @@ impl Service {
pdu.redact(&room_version_id, reason)?;
let obj = utils::to_canonical_object(&pdu)
.map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?;
let obj = utils::to_canonical_object(&pdu).map_err(|e| {
err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON")))
})?;
self.replace_pdu(&pdu_id, &obj, &pdu).await
}
@ -1069,7 +1119,9 @@ impl Service {
.unwrap_or_default();
let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| {
if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) {
if level > &power_levels.users_default
&& !self.services.globals.user_is_local(user_id)
{
Some(user_id.server_name())
} else {
None
@ -1124,7 +1176,7 @@ impl Service {
)
.await;
match response {
Ok(response) => {
| Ok(response) => {
for pdu in response.pdus {
if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await {
debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}");
@ -1132,7 +1184,7 @@ impl Service {
}
return Ok(());
},
Err(e) => {
| Err(e) => {
warn!("{backfill_server} failed to provide backfill for room {room_id}: {e}");
},
}
@ -1144,7 +1196,8 @@ impl Service {
#[tracing::instrument(skip(self, pdu))]
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?;
let (event_id, value, room_id) =
self.services.event_handler.parse_incoming_pdu(&pdu).await?;
// Lock so we cannot backfill the same pdu twice at the same time
let mutex_lock = self
@ -1210,10 +1263,10 @@ impl Service {
#[tracing::instrument(skip_all, level = "debug")]
async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Result<()> {
match pdu.event_type() {
TimelineEventType::RoomEncryption => {
| TimelineEventType::RoomEncryption => {
return Err!(Request(Forbidden(error!("Encryption not supported in admins room."))));
},
TimelineEventType::RoomMember => {
| TimelineEventType::RoomMember => {
let target = pdu
.state_key()
.filter(|v| v.starts_with('@'))
@ -1223,9 +1276,11 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
let content: RoomMemberEventContent = pdu.get_content()?;
match content.membership {
MembershipState::Leave => {
| MembershipState::Leave => {
if target == server_user {
return Err!(Request(Forbidden(error!("Server user cannot leave the admins room."))));
return Err!(Request(Forbidden(error!(
"Server user cannot leave the admins room."
))));
}
let count = self
@ -1239,13 +1294,17 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
.await;
if count < 2 {
return Err!(Request(Forbidden(error!("Last admin cannot leave the admins room."))));
return Err!(Request(Forbidden(error!(
"Last admin cannot leave the admins room."
))));
}
},
MembershipState::Ban if pdu.state_key().is_some() => {
| MembershipState::Ban if pdu.state_key().is_some() => {
if target == server_user {
return Err!(Request(Forbidden(error!("Server cannot be banned from admins room."))));
return Err!(Request(Forbidden(error!(
"Server cannot be banned from admins room."
))));
}
let count = self
@ -1259,13 +1318,15 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
.await;
if count < 2 {
return Err!(Request(Forbidden(error!("Last admin cannot be banned from admins room."))));
return Err!(Request(Forbidden(error!(
"Last admin cannot be banned from admins room."
))));
}
},
_ => {},
| _ => {},
};
},
_ => {},
| _ => {},
};
Ok(())

View file

@ -52,7 +52,12 @@ impl crate::Service for Service {
impl Service {
/// Sets a user as typing until the timeout timestamp is reached or
/// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
pub async fn typing_add(
&self,
user_id: &UserId,
room_id: &RoomId,
timeout: u64,
) -> Result<()> {
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
// update clients
self.typing
@ -177,15 +182,15 @@ impl Service {
/// Returns a new typing EDU.
pub async fn typings_all(
&self, room_id: &RoomId, sender_user: &UserId,
&self,
room_id: &RoomId,
sender_user: &UserId,
) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> {
let room_typing_indicators = self.typing.read().await.get(room_id).cloned();
let Some(typing_indicators) = room_typing_indicators else {
return Ok(SyncEphemeralRoomEvent {
content: ruma::events::typing::TypingEventContent {
user_ids: Vec::new(),
},
content: ruma::events::typing::TypingEventContent { user_ids: Vec::new() },
});
};
@ -204,13 +209,16 @@ impl Service {
.await;
Ok(SyncEphemeralRoomEvent {
content: ruma::events::typing::TypingEventContent {
user_ids,
},
content: ruma::events::typing::TypingEventContent { user_ids },
})
}
async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
async fn federation_send(
&self,
room_id: &RoomId,
user_id: &UserId,
typing: bool,
) -> Result<()> {
debug_assert!(
self.services.globals.user_is_local(user_id),
"tried to broadcast typing status of remote user",

View file

@ -92,7 +92,12 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -
}
#[implement(Service)]
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) {
pub async fn associate_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
shortstatehash: ShortStateHash,
) {
let shortroomid = self
.services
.short
@ -108,7 +113,11 @@ pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64,
}
#[implement(Service)]
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<ShortStateHash> {
pub async fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<ShortStateHash> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let key: &[u64] = &[shortroomid, token];