From 4a3cc9fffa62636ba1c3e76494142da226f9aefd Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Sat, 30 Nov 2024 08:09:51 +0000
Subject: [PATCH] de-arc state_full_ids

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/api/client/context.rs                |  6 ++---
 src/api/client/sync/v3.rs                |  8 +++---
 src/api/client/sync/v4.rs                |  6 ++---
 src/api/server/send_join.rs              |  6 ++---
 src/api/server/state.rs                  | 10 ++++---
 src/api/server/state_ids.rs              |  5 ++--
 src/service/rooms/spaces/mod.rs          |  4 +--
 src/service/rooms/state_accessor/data.rs | 34 +++++++++++++++++-------
 src/service/rooms/state_accessor/mod.rs  | 29 ++++++++++++++------
 9 files changed, 69 insertions(+), 39 deletions(-)

diff --git a/src/api/client/context.rs b/src/api/client/context.rs
index bf87f5e1..652e17f4 100644
--- a/src/api/client/context.rs
+++ b/src/api/client/context.rs
@@ -1,4 +1,4 @@
-use std::iter::once;
+use std::{collections::HashMap, iter::once};
 
 use axum::extract::State;
 use conduit::{
@@ -10,7 +10,7 @@ use futures::{future::try_join, StreamExt, TryFutureExt};
 use ruma::{
 	api::client::{context::get_context, filter::LazyLoadOptions},
 	events::StateEventType,
-	UserId,
+	OwnedEventId, UserId,
 };
 
 use crate::{
@@ -124,7 +124,7 @@ pub(crate) async fn get_context_route(
 		.await
 		.map_err(|e| err!(Database("State hash not found: {e}")))?;
 
-	let state_ids = services
+	let state_ids: HashMap<_, OwnedEventId> = services
 		.rooms
 		.state_accessor
 		.state_full_ids(shortstatehash)
diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs
index 9c1cefdb..5578077f 100644
--- a/src/api/client/sync/v3.rs
+++ b/src/api/client/sync/v3.rs
@@ -32,7 +32,7 @@ use ruma::{
 		TimelineEventType::*,
 	},
 	serde::Raw,
-	uint, DeviceId, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
+	uint, DeviceId, EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
 };
 use tracing::{Instrument as _, Span};
 
@@ -398,7 +398,7 @@ async fn handle_left_room(
 		Err(_) => HashMap::new(),
 	};
 
-	let Ok(left_event_id) = services
+	let Ok(left_event_id): Result<OwnedEventId> = services
 		.rooms
 		.state_accessor
 		.room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())
@@ -666,7 +666,7 @@ async fn load_joined_room(
 
 			let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?;
 
-			let current_state_ids = services
+			let current_state_ids: HashMap<_, OwnedEventId> = services
 				.rooms
 				.state_accessor
 				.state_full_ids(current_shortstatehash)
@@ -736,7 +736,7 @@ async fn load_joined_room(
 			let mut delta_state_events = Vec::new();
 
 			if since_shortstatehash != current_shortstatehash {
-				let current_state_ids = services
+				let current_state_ids: HashMap<_, OwnedEventId> = services
 					.rooms
 					.state_accessor
 					.state_full_ids(current_shortstatehash)
diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs
index 62c313e2..14d79c19 100644
--- a/src/api/client/sync/v4.rs
+++ b/src/api/client/sync/v4.rs
@@ -1,6 +1,6 @@
 use std::{
 	cmp::{self, Ordering},
-	collections::{BTreeMap, BTreeSet, HashSet},
+	collections::{BTreeMap, BTreeSet, HashMap, HashSet},
 	time::Duration,
 };
 
@@ -30,7 +30,7 @@ use ruma::{
 		TimelineEventType::{self, *},
 	},
 	state_res::Event,
-	uint, MilliSecondsSinceUnixEpoch, OwnedRoomId, UInt, UserId,
+	uint, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, UInt, UserId,
 };
 use service::{rooms::read_receipt::pack_receipts, Services};
 
@@ -211,7 +211,7 @@ pub(crate) async fn sync_events_v4_route(
 				let new_encrypted_room = encrypted_room && since_encryption.is_err();
 
 				if encrypted_room {
-					let current_state_ids = services
+					let current_state_ids: HashMap<_, OwnedEventId> = services
 						.rooms
 						.state_accessor
 						.state_full_ids(current_shortstatehash)
diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs
index 0ad07b1e..92ab3b50 100644
--- a/src/api/server/send_join.rs
+++ b/src/api/server/send_join.rs
@@ -1,6 +1,6 @@
 #![allow(deprecated)]
 
-use std::borrow::Borrow;
+use std::{borrow::Borrow, collections::HashMap};
 
 use axum::extract::State;
 use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
@@ -11,7 +11,7 @@ use ruma::{
 		room::member::{MembershipState, RoomMemberEventContent},
 		StateEventType,
 	},
-	CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName,
+	CanonicalJsonValue, OwnedEventId, OwnedServerName, OwnedUserId, RoomId, ServerName,
 };
 use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
 use service::Services;
@@ -165,7 +165,7 @@ async fn create_join_event(
 
 	drop(mutex_lock);
 
-	let state_ids = services
+	let state_ids: HashMap<_, OwnedEventId> = services
 		.rooms
 		.state_accessor
 		.state_full_ids(shortstatehash)
diff --git a/src/api/server/state.rs b/src/api/server/state.rs
index b21fce68..400b9237 100644
--- a/src/api/server/state.rs
+++ b/src/api/server/state.rs
@@ -3,7 +3,7 @@ use std::{borrow::Borrow, iter::once};
 use axum::extract::State;
 use conduit::{err, result::LogErr, utils::IterStream, Result};
 use futures::{FutureExt, StreamExt, TryStreamExt};
-use ruma::api::federation::event::get_room_state;
+use ruma::{api::federation::event::get_room_state, OwnedEventId};
 
 use super::AccessCheck;
 use crate::Ruma;
@@ -30,14 +30,18 @@ pub(crate) async fn get_room_state_route(
 		.await
 		.map_err(|_| err!(Request(NotFound("PDU state not found."))))?;
 
-	let pdus = services
+	let state_ids: Vec<OwnedEventId> = services
 		.rooms
 		.state_accessor
 		.state_full_ids(shortstatehash)
 		.await
 		.log_err()
 		.map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))?
-		.values()
+		.into_values()
+		.collect();
+
+	let pdus = state_ids
+		.iter()
 		.try_stream()
 		.and_then(|id| services.rooms.timeline.get_pdu_json(id))
 		.and_then(|pdu| {
diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs
index 0c023bf0..55662a40 100644
--- a/src/api/server/state_ids.rs
+++ b/src/api/server/state_ids.rs
@@ -3,7 +3,7 @@ use std::{borrow::Borrow, iter::once};
 use axum::extract::State;
 use conduit::{err, Result};
 use futures::StreamExt;
-use ruma::api::federation::event::get_room_state_ids;
+use ruma::{api::federation::event::get_room_state_ids, OwnedEventId};
 
 use super::AccessCheck;
 use crate::Ruma;
@@ -31,14 +31,13 @@ pub(crate) async fn get_room_state_ids_route(
 		.await
 		.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
 
-	let pdu_ids = services
+	let pdu_ids: Vec<OwnedEventId> = services
 		.rooms
 		.state_accessor
 		.state_full_ids(shortstatehash)
 		.await
 		.map_err(|_| err!(Request(NotFound("State ids not found"))))?
 		.into_values()
-		.map(|id| (*id).to_owned())
 		.collect();
 
 	let auth_chain_ids = services
diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs
index 2b80e3dc..3e972ca6 100644
--- a/src/service/rooms/spaces/mod.rs
+++ b/src/service/rooms/spaces/mod.rs
@@ -1,7 +1,7 @@
 mod tests;
 
 use std::{
-	collections::VecDeque,
+	collections::{HashMap, VecDeque},
 	fmt::{Display, Formatter},
 	str::FromStr,
 	sync::Arc,
@@ -572,7 +572,7 @@ impl Service {
 			return Ok(None);
 		};
 
-		let state = self
+		let state: HashMap<_, Arc<_>> = self
 			.services
 			.state_accessor
 			.state_full_ids(current_shortstatehash)
diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs
index 6c67b856..7760d5b6 100644
--- a/src/service/rooms/state_accessor/data.rs
+++ b/src/service/rooms/state_accessor/data.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, sync::Arc};
+use std::{borrow::Borrow, collections::HashMap, sync::Arc};
 
 use conduit::{
 	at, err,
@@ -8,6 +8,7 @@ use conduit::{
 use database::{Deserialized, Map};
 use futures::{StreamExt, TryFutureExt};
 use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId};
+use serde::Deserialize;
 
 use crate::{
 	rooms,
@@ -84,7 +85,11 @@ impl Data {
 		Ok(full_pdus)
 	}
 
-	pub(super) async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> {
+	pub(super) async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
+	where
+		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
+		<Id as ToOwned>::Owned: Borrow<EventId>,
+	{
 		let short_ids = self.state_full_shortids(shortstatehash).await?;
 
 		let event_ids = self
@@ -123,11 +128,15 @@ impl Data {
 		Ok(shortids)
 	}
 
-	/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
-	#[allow(clippy::unused_self)]
-	pub(super) async fn state_get_id(
+	/// Returns a single EventId from `room_id` with key
+	/// (`event_type`,`state_key`).
+	pub(super) async fn state_get_id<Id>(
 		&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
-	) -> Result<Arc<EventId>> {
+	) -> Result<Id>
+	where
+		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
+		<Id as ToOwned>::Owned: Borrow<EventId>,
+	{
 		let shortstatekey = self
 			.services
 			.short
@@ -162,7 +171,7 @@ impl Data {
 		&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
 	) -> Result<PduEvent> {
 		self.state_get_id(shortstatehash, event_type, state_key)
-			.and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await })
+			.and_then(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await })
 			.await
 	}
 
@@ -204,10 +213,15 @@ impl Data {
 			.await
 	}
 
-	/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
-	pub(super) async fn room_state_get_id(
+	/// Returns a single EventId from `room_id` with key
+	/// (`event_type`,`state_key`).
+	pub(super) async fn room_state_get_id<Id>(
 		&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
-	) -> Result<Arc<EventId>> {
+	) -> Result<Id>
+	where
+		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
+		<Id as ToOwned>::Owned: Borrow<EventId>,
+	{
 		self.services
 			.state
 			.get_room_shortstatehash(room_id)
diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs
index 18f999b4..e42d3764 100644
--- a/src/service/rooms/state_accessor/mod.rs
+++ b/src/service/rooms/state_accessor/mod.rs
@@ -1,6 +1,7 @@
 mod data;
 
 use std::{
+	borrow::Borrow,
 	collections::HashMap,
 	fmt::Write,
 	sync::{Arc, Mutex as StdMutex, Mutex},
@@ -101,8 +102,12 @@ impl Service {
 	/// Builds a StateMap by iterating over all keys that start
 	/// with state_hash, this gives the full state for the given state_hash.
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> {
-		self.db.state_full_ids(shortstatehash).await
+	pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Id>>
+	where
+		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
+		<Id as ToOwned>::Owned: Borrow<EventId>,
+	{
+		self.db.state_full_ids::<Id>(shortstatehash).await
 	}
 
 	#[inline]
@@ -118,12 +123,16 @@ impl Service {
 		self.db.state_full(shortstatehash).await
 	}
 
-	/// Returns a single PDU from `room_id` with key (`event_type`,
+	/// Returns a single EventId from `room_id` with key (`event_type`,
 	/// `state_key`).
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub async fn state_get_id(
+	pub async fn state_get_id<Id>(
 		&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
-	) -> Result<Arc<EventId>> {
+	) -> Result<Id>
+	where
+		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
+		<Id as ToOwned>::Owned: Borrow<EventId>,
+	{
 		self.db
 			.state_get_id(shortstatehash, event_type, state_key)
 			.await
@@ -321,12 +330,16 @@ impl Service {
 		self.db.room_state_full_pdus(room_id).await
 	}
 
-	/// Returns a single PDU from `room_id` with key (`event_type`,
+	/// Returns a single EventId from `room_id` with key (`event_type`,
 	/// `state_key`).
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub async fn room_state_get_id(
+	pub async fn room_state_get_id<Id>(
 		&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
-	) -> Result<Arc<EventId>> {
+	) -> Result<Id>
+	where
+		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
+		<Id as ToOwned>::Owned: Borrow<EventId>,
+	{
 		self.db
 			.room_state_get_id(room_id, event_type, state_key)
 			.await