diff --git a/src/api/client/context.rs b/src/api/client/context.rs index bf87f5e1..652e17f4 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,4 +1,4 @@ -use std::iter::once; +use std::{collections::HashMap, iter::once}; use axum::extract::State; use conduit::{ @@ -10,7 +10,7 @@ use futures::{future::try_join, StreamExt, TryFutureExt}; use ruma::{ api::client::{context::get_context, filter::LazyLoadOptions}, events::StateEventType, - UserId, + OwnedEventId, UserId, }; use crate::{ @@ -124,7 +124,7 @@ pub(crate) async fn get_context_route( .await .map_err(|e| err!(Database("State hash not found: {e}")))?; - let state_ids = services + let state_ids: HashMap<_, OwnedEventId> = services .rooms .state_accessor .state_full_ids(shortstatehash) diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 9c1cefdb..5578077f 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -32,7 +32,7 @@ use ruma::{ TimelineEventType::*, }, serde::Raw, - uint, DeviceId, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId, + uint, DeviceId, EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; use tracing::{Instrument as _, Span}; @@ -398,7 +398,7 @@ async fn handle_left_room( Err(_) => HashMap::new(), }; - let Ok(left_event_id) = services + let Ok(left_event_id): Result = services .rooms .state_accessor .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str()) @@ -666,7 +666,7 @@ async fn load_joined_room( let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?; - let current_state_ids = services + let current_state_ids: HashMap<_, OwnedEventId> = services .rooms .state_accessor .state_full_ids(current_shortstatehash) @@ -736,7 +736,7 @@ async fn load_joined_room( let mut delta_state_events = Vec::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = services + let current_state_ids: HashMap<_, OwnedEventId> = services .rooms .state_accessor .state_full_ids(current_shortstatehash) diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 62c313e2..14d79c19 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -1,6 +1,6 @@ use std::{ cmp::{self, Ordering}, - collections::{BTreeMap, BTreeSet, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, time::Duration, }; @@ -30,7 +30,7 @@ use ruma::{ TimelineEventType::{self, *}, }, state_res::Event, - uint, MilliSecondsSinceUnixEpoch, OwnedRoomId, UInt, UserId, + uint, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, UInt, UserId, }; use service::{rooms::read_receipt::pack_receipts, Services}; @@ -211,7 +211,7 @@ pub(crate) async fn sync_events_v4_route( let new_encrypted_room = encrypted_room && since_encryption.is_err(); if encrypted_room { - let current_state_ids = services + let current_state_ids: HashMap<_, OwnedEventId> = services .rooms .state_accessor .state_full_ids(current_shortstatehash) diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 0ad07b1e..92ab3b50 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::borrow::Borrow; +use std::{borrow::Borrow, collections::HashMap}; use axum::extract::State; use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; @@ -11,7 +11,7 @@ use ruma::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, + CanonicalJsonValue, OwnedEventId, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::Services; @@ -165,7 +165,7 @@ async fn create_join_event( drop(mutex_lock); - let state_ids = services + let state_ids: HashMap<_, OwnedEventId> = services .rooms .state_accessor .state_full_ids(shortstatehash) diff --git a/src/api/server/state.rs b/src/api/server/state.rs index b21fce68..400b9237 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -3,7 +3,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; use conduit::{err, result::LogErr, utils::IterStream, Result}; use futures::{FutureExt, StreamExt, TryStreamExt}; -use ruma::api::federation::event::get_room_state; +use ruma::{api::federation::event::get_room_state, OwnedEventId}; use super::AccessCheck; use crate::Ruma; @@ -30,14 +30,18 @@ pub(crate) async fn get_room_state_route( .await .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; - let pdus = services + let state_ids: Vec = services .rooms .state_accessor .state_full_ids(shortstatehash) .await .log_err() .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? - .values() + .into_values() + .collect(); + + let pdus = state_ids + .iter() .try_stream() .and_then(|id| services.rooms.timeline.get_pdu_json(id)) .and_then(|pdu| { diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 0c023bf0..55662a40 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -3,7 +3,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; use conduit::{err, Result}; use futures::StreamExt; -use ruma::api::federation::event::get_room_state_ids; +use ruma::{api::federation::event::get_room_state_ids, OwnedEventId}; use super::AccessCheck; use crate::Ruma; @@ -31,14 +31,13 @@ pub(crate) async fn get_room_state_ids_route( .await .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; - let pdu_ids = services + let pdu_ids: Vec = services .rooms .state_accessor .state_full_ids(shortstatehash) .await .map_err(|_| err!(Request(NotFound("State ids not found"))))? .into_values() - .map(|id| (*id).to_owned()) .collect(); let auth_chain_ids = services diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 2b80e3dc..3e972ca6 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,7 +1,7 @@ mod tests; use std::{ - collections::VecDeque, + collections::{HashMap, VecDeque}, fmt::{Display, Formatter}, str::FromStr, sync::Arc, @@ -572,7 +572,7 @@ impl Service { return Ok(None); }; - let state = self + let state: HashMap<_, Arc<_>> = self .services .state_accessor .state_full_ids(current_shortstatehash) diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 6c67b856..7760d5b6 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{borrow::Borrow, collections::HashMap, sync::Arc}; use conduit::{ at, err, @@ -8,6 +8,7 @@ use conduit::{ use database::{Deserialized, Map}; use futures::{StreamExt, TryFutureExt}; use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId}; +use serde::Deserialize; use crate::{ rooms, @@ -84,7 +85,11 @@ impl Data { Ok(full_pdus) } - pub(super) async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result>> { + pub(super) async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result> + where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + ::Owned: Borrow, + { let short_ids = self.state_full_shortids(shortstatehash).await?; let event_ids = self @@ -123,11 +128,15 @@ impl Data { Ok(shortids) } - /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). - #[allow(clippy::unused_self)] - pub(super) async fn state_get_id( + /// Returns a single EventId from `room_id` with key + /// (`event_type`,`state_key`). + pub(super) async fn state_get_id( &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, - ) -> Result> { + ) -> Result + where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + ::Owned: Borrow, + { let shortstatekey = self .services .short @@ -162,7 +171,7 @@ impl Data { &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result { self.state_get_id(shortstatehash, event_type, state_key) - .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) + .and_then(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await }) .await } @@ -204,10 +213,15 @@ impl Data { .await } - /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). - pub(super) async fn room_state_get_id( + /// Returns a single EventId from `room_id` with key + /// (`event_type`,`state_key`). + pub(super) async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result> { + ) -> Result + where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + ::Owned: Borrow, + { self.services .state .get_room_shortstatehash(room_id) diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 18f999b4..e42d3764 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,6 +1,7 @@ mod data; use std::{ + borrow::Borrow, collections::HashMap, fmt::Write, sync::{Arc, Mutex as StdMutex, Mutex}, @@ -101,8 +102,12 @@ impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result>> { - self.db.state_full_ids(shortstatehash).await + pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result> + where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + ::Owned: Borrow, + { + self.db.state_full_ids::(shortstatehash).await } #[inline] @@ -118,12 +123,16 @@ impl Service { self.db.state_full(shortstatehash).await } - /// Returns a single PDU from `room_id` with key (`event_type`, + /// Returns a single EventId from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_get_id( + pub async fn state_get_id( &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, - ) -> Result> { + ) -> Result + where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + ::Owned: Borrow, + { self.db .state_get_id(shortstatehash, event_type, state_key) .await @@ -321,12 +330,16 @@ impl Service { self.db.room_state_full_pdus(room_id).await } - /// Returns a single PDU from `room_id` with key (`event_type`, + /// Returns a single EventId from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub async fn room_state_get_id( + pub async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result> { + ) -> Result + where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + ::Owned: Borrow, + { self.db .room_state_get_id(room_id, event_type, state_key) .await