From ad0b0af955cda8b93b6d8c9c665905a2c4dd93d3 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Sat, 25 Jan 2025 23:07:50 +0000
Subject: [PATCH] combine state_accessor data into mod

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/service/rooms/state_accessor/data.rs | 253 -----------------------
 src/service/rooms/state_accessor/mod.rs  | 183 +++++++++++++---
 2 files changed, 149 insertions(+), 287 deletions(-)
 delete mode 100644 src/service/rooms/state_accessor/data.rs

diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs
deleted file mode 100644
index 29b27a05..00000000
--- a/src/service/rooms/state_accessor/data.rs
+++ /dev/null
@@ -1,253 +0,0 @@
-use std::{borrow::Borrow, collections::HashMap, sync::Arc};
-
-use conduwuit::{
-	at, err,
-	utils::stream::{BroadbandExt, IterStream, ReadyExt},
-	PduEvent, Result,
-};
-use database::{Deserialized, Map};
-use futures::{FutureExt, StreamExt, TryFutureExt};
-use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId};
-use serde::Deserialize;
-
-use crate::{
-	rooms,
-	rooms::{
-		short::{ShortEventId, ShortStateHash, ShortStateKey},
-		state_compressor::parse_compressed_state_event,
-	},
-	Dep,
-};
-
-pub(super) struct Data {
-	shorteventid_shortstatehash: Arc<Map>,
-	services: Services,
-}
-
-struct Services {
-	short: Dep<rooms::short::Service>,
-	state: Dep<rooms::state::Service>,
-	state_compressor: Dep<rooms::state_compressor::Service>,
-	timeline: Dep<rooms::timeline::Service>,
-}
-
-impl Data {
-	pub(super) fn new(args: &crate::Args<'_>) -> Self {
-		let db = &args.db;
-		Self {
-			shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
-			services: Services {
-				short: args.depend::<rooms::short::Service>("rooms::short"),
-				state: args.depend::<rooms::state::Service>("rooms::state"),
-				state_compressor: args
-					.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
-				timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
-			},
-		}
-	}
-
-	pub(super) async fn state_full(
-		&self,
-		shortstatehash: ShortStateHash,
-	) -> Result<HashMap<(StateEventType, String), PduEvent>> {
-		let state = self
-			.state_full_pdus(shortstatehash)
-			.await?
-			.into_iter()
-			.filter_map(|pdu| Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)))
-			.collect();
-
-		Ok(state)
-	}
-
-	pub(super) async fn state_full_pdus(
-		&self,
-		shortstatehash: ShortStateHash,
-	) -> Result<Vec<PduEvent>> {
-		let short_ids = self.state_full_shortids(shortstatehash).await?;
-
-		let full_pdus = self
-			.services
-			.short
-			.multi_get_eventid_from_short(short_ids.into_iter().map(at!(1)).stream())
-			.ready_filter_map(Result::ok)
-			.broad_filter_map(|event_id: OwnedEventId| async move {
-				self.services.timeline.get_pdu(&event_id).await.ok()
-			})
-			.collect()
-			.await;
-
-		Ok(full_pdus)
-	}
-
-	pub(super) async fn state_full_ids<Id>(
-		&self,
-		shortstatehash: ShortStateHash,
-	) -> Result<HashMap<ShortStateKey, Id>>
-	where
-		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
-		<Id as ToOwned>::Owned: Borrow<EventId>,
-	{
-		let short_ids = self.state_full_shortids(shortstatehash).await?;
-
-		let full_ids = self
-			.services
-			.short
-			.multi_get_eventid_from_short(short_ids.iter().map(at!(1)).stream())
-			.zip(short_ids.iter().stream().map(at!(0)))
-			.ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?)))
-			.collect()
-			.boxed()
-			.await;
-
-		Ok(full_ids)
-	}
-
-	pub(super) async fn state_full_shortids(
-		&self,
-		shortstatehash: ShortStateHash,
-	) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
-		let shortids = self
-			.services
-			.state_compressor
-			.load_shortstatehash_info(shortstatehash)
-			.await
-			.map_err(|e| err!(Database("Missing state IDs: {e}")))?
-			.pop()
-			.expect("there is always one layer")
-			.full_state
-			.iter()
-			.copied()
-			.map(parse_compressed_state_event)
-			.collect();
-
-		Ok(shortids)
-	}
-
-	/// Returns a single EventId from `room_id` with key
-	/// (`event_type`,`state_key`).
-	pub(super) async fn state_get_id<Id>(
-		&self,
-		shortstatehash: ShortStateHash,
-		event_type: &StateEventType,
-		state_key: &str,
-	) -> Result<Id>
-	where
-		Id: for<'de> Deserialize<'de> + Sized + ToOwned,
-		<Id as ToOwned>::Owned: Borrow<EventId>,
-	{
-		let shortstatekey = self
-			.services
-			.short
-			.get_shortstatekey(event_type, state_key)
-			.await?;
-
-		let full_state = self
-			.services
-			.state_compressor
-			.load_shortstatehash_info(shortstatehash)
-			.await
-			.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
-			.pop()
-			.expect("there is always one layer")
-			.full_state;
-
-		let compressed = full_state
-			.iter()
-			.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
-			.ok_or(err!(Database("No shortstatekey in compressed state")))?;
-
-		let (_, shorteventid) = parse_compressed_state_event(*compressed);
-
-		self.services
-			.short
-			.get_eventid_from_short(shorteventid)
-			.await
-	}
-
-	/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
-	pub(super) async fn state_get(
-		&self,
-		shortstatehash: ShortStateHash,
-		event_type: &StateEventType,
-		state_key: &str,
-	) -> Result<PduEvent> {
-		self.state_get_id(shortstatehash, event_type, state_key)
-			.and_then(|event_id: OwnedEventId| async move {
-				self.services.timeline.get_pdu(&event_id).await
-			})
-			.await
-	}
-
-	/// Returns the state hash for this pdu.
-	pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
-		const BUFSIZE: usize = size_of::<ShortEventId>();
-
-		self.services
-			.short
-			.get_shorteventid(event_id)
-			.and_then(|shorteventid| {
-				self.shorteventid_shortstatehash
-					.aqry::<BUFSIZE, _>(&shorteventid)
-			})
-			.await
-			.deserialized()
-	}
-
-	/// Returns the full room state.
-	pub(super) async fn room_state_full(
-		&self,
-		room_id: &RoomId,
-	) -> Result<HashMap<(StateEventType, String), PduEvent>> {
-		self.services
-			.state
-			.get_room_shortstatehash(room_id)
-			.and_then(|shortstatehash| self.state_full(shortstatehash))
-			.map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}")))
-			.await
-	}
-
-	/// Returns the full room state's pdus.
-	#[allow(unused_qualifications)] // async traits
-	pub(super) async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<PduEvent>> {
-		self.services
-			.state
-			.get_room_shortstatehash(room_id)
-			.and_then(|shortstatehash| self.state_full_pdus(shortstatehash))
-			.map_err(|e| err!(Database("Missing state pdus for {room_id:?}: {e:?}")))
-			.await
-	}
-
-	/// Returns a single EventId from `room_id` with key
-	/// (`event_type`,`state_key`).
-	pub(super) async fn room_state_get_id<Id>(
-		&self,
-		room_id: &RoomId,
-		event_type: &StateEventType,
-		state_key: &str,
-	) -> Result<Id>
-	where
-		Id: for<'de> Deserialize<'de> + Sized + ToOwned,
-		<Id as ToOwned>::Owned: Borrow<EventId>,
-	{
-		self.services
-			.state
-			.get_room_shortstatehash(room_id)
-			.and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key))
-			.await
-	}
-
-	/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
-	pub(super) async fn room_state_get(
-		&self,
-		room_id: &RoomId,
-		event_type: &StateEventType,
-		state_key: &str,
-	) -> Result<PduEvent> {
-		self.services
-			.state
-			.get_room_shortstatehash(room_id)
-			.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
-			.await
-	}
-}
diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs
index d89c8835..3d87534b 100644
--- a/src/service/rooms/state_accessor/mod.rs
+++ b/src/service/rooms/state_accessor/mod.rs
@@ -1,5 +1,3 @@
-mod data;
-
 use std::{
 	borrow::Borrow,
 	collections::HashMap,
@@ -8,16 +6,18 @@ use std::{
 };
 
 use conduwuit::{
-	err, error,
+	at, err, error,
 	pdu::PduBuilder,
 	utils,
 	utils::{
 		math::{usize_from_f64, Expected},
-		ReadyExt,
+		stream::BroadbandExt,
+		IterStream, ReadyExt,
 	},
 	Err, Error, PduEvent, Result,
 };
-use futures::StreamExt;
+use database::{Deserialized, Map};
+use futures::{FutureExt, StreamExt, TryFutureExt};
 use lru_cache::LruCache;
 use ruma::{
 	events::{
@@ -38,33 +38,40 @@ use ruma::{
 	},
 	room::RoomType,
 	space::SpaceRoomJoinRule,
-	EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName,
-	OwnedUserId, RoomId, ServerName, UserId,
+	EventEncryptionAlgorithm, EventId, JsOption, OwnedEventId, OwnedRoomAliasId, OwnedRoomId,
+	OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
 };
 use serde::Deserialize;
 
-use self::data::Data;
 use crate::{
 	rooms,
 	rooms::{
 		short::{ShortEventId, ShortStateHash, ShortStateKey},
 		state::RoomMutexGuard,
+		state_compressor::parse_compressed_state_event,
 	},
 	Dep,
 };
 
 pub struct Service {
-	services: Services,
-	db: Data,
 	pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, ShortStateHash), bool>>,
 	pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, ShortStateHash), bool>>,
+	services: Services,
+	db: Data,
 }
 
 struct Services {
+	short: Dep<rooms::short::Service>,
+	state: Dep<rooms::state::Service>,
+	state_compressor: Dep<rooms::state_compressor::Service>,
 	state_cache: Dep<rooms::state_cache::Service>,
 	timeline: Dep<rooms::timeline::Service>,
 }
 
+struct Data {
+	shorteventid_shortstatehash: Arc<Map>,
+}
+
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		let config = &args.server.config;
@@ -74,17 +81,23 @@ impl crate::Service for Service {
 			f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier;
 
 		Ok(Arc::new(Self {
-			services: Services {
-				state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
-				timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
-			},
-			db: Data::new(&args),
 			server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
 				server_visibility_cache_capacity,
 			)?)),
 			user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(
 				user_visibility_cache_capacity,
 			)?)),
+			services: Services {
+				state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
+				timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
+				short: args.depend::<rooms::short::Service>("rooms::short"),
+				state: args.depend::<rooms::state::Service>("rooms::state"),
+				state_compressor: args
+					.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
+			},
+			db: Data {
+				shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(),
+			},
 		}))
 	}
 
@@ -130,6 +143,37 @@ impl crate::Service for Service {
 }
 
 impl Service {
+	pub async fn state_full(
+		&self,
+		shortstatehash: ShortStateHash,
+	) -> Result<HashMap<(StateEventType, String), PduEvent>> {
+		let state = self
+			.state_full_pdus(shortstatehash)
+			.await?
+			.into_iter()
+			.filter_map(|pdu| Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)))
+			.collect();
+
+		Ok(state)
+	}
+
+	pub async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> {
+		let short_ids = self.state_full_shortids(shortstatehash).await?;
+
+		let full_pdus = self
+			.services
+			.short
+			.multi_get_eventid_from_short(short_ids.into_iter().map(at!(1)).stream())
+			.ready_filter_map(Result::ok)
+			.broad_filter_map(|event_id: OwnedEventId| async move {
+				self.services.timeline.get_pdu(&event_id).await.ok()
+			})
+			.collect()
+			.await;
+
+		Ok(full_pdus)
+	}
+
 	/// Builds a StateMap by iterating over all keys that start
 	/// with state_hash, this gives the full state for the given state_hash.
 	#[tracing::instrument(skip(self), level = "debug")]
@@ -141,7 +185,19 @@ impl Service {
 		Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
 		<Id as ToOwned>::Owned: Borrow<EventId>,
 	{
-		self.db.state_full_ids::<Id>(shortstatehash).await
+		let short_ids = self.state_full_shortids(shortstatehash).await?;
+
+		let full_ids = self
+			.services
+			.short
+			.multi_get_eventid_from_short(short_ids.iter().map(at!(1)).stream())
+			.zip(short_ids.iter().stream().map(at!(0)))
+			.ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?)))
+			.collect()
+			.boxed()
+			.await;
+
+		Ok(full_ids)
 	}
 
 	#[inline]
@@ -149,14 +205,21 @@ impl Service {
 		&self,
 		shortstatehash: ShortStateHash,
 	) -> Result<Vec<(ShortStateKey, ShortEventId)>> {
-		self.db.state_full_shortids(shortstatehash).await
-	}
+		let shortids = self
+			.services
+			.state_compressor
+			.load_shortstatehash_info(shortstatehash)
+			.await
+			.map_err(|e| err!(Database("Missing state IDs: {e}")))?
+			.pop()
+			.expect("there is always one layer")
+			.full_state
+			.iter()
+			.copied()
+			.map(parse_compressed_state_event)
+			.collect();
 
-	pub async fn state_full(
-		&self,
-		shortstatehash: ShortStateHash,
-	) -> Result<HashMap<(StateEventType, String), PduEvent>> {
-		self.db.state_full(shortstatehash).await
+		Ok(shortids)
 	}
 
 	/// Returns a single EventId from `room_id` with key (`event_type`,
@@ -172,22 +235,47 @@ impl Service {
 		Id: for<'de> Deserialize<'de> + Sized + ToOwned,
 		<Id as ToOwned>::Owned: Borrow<EventId>,
 	{
-		self.db
-			.state_get_id(shortstatehash, event_type, state_key)
+		let shortstatekey = self
+			.services
+			.short
+			.get_shortstatekey(event_type, state_key)
+			.await?;
+
+		let full_state = self
+			.services
+			.state_compressor
+			.load_shortstatehash_info(shortstatehash)
+			.await
+			.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
+			.pop()
+			.expect("there is always one layer")
+			.full_state;
+
+		let compressed = full_state
+			.iter()
+			.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
+			.ok_or(err!(Database("No shortstatekey in compressed state")))?;
+
+		let (_, shorteventid) = parse_compressed_state_event(*compressed);
+
+		self.services
+			.short
+			.get_eventid_from_short(shorteventid)
 			.await
 	}
 
 	/// Returns a single PDU from `room_id` with key (`event_type`,
 	/// `state_key`).
-	#[inline]
 	pub async fn state_get(
 		&self,
 		shortstatehash: ShortStateHash,
 		event_type: &StateEventType,
 		state_key: &str,
 	) -> Result<PduEvent> {
-		self.db
-			.state_get(shortstatehash, event_type, state_key)
+		self.state_get_id(shortstatehash, event_type, state_key)
+			.and_then(|event_id: OwnedEventId| async move {
+				self.services.timeline.get_pdu(&event_id).await
+			})
 			.await
 	}
 
@@ -375,7 +463,18 @@ impl Service {
 
 	/// Returns the state hash for this pdu.
 	pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
-		self.db.pdu_shortstatehash(event_id).await
+		const BUFSIZE: usize = size_of::<ShortEventId>();
+
+		self.services
+			.short
+			.get_shorteventid(event_id)
+			.and_then(|shorteventid| {
+				self.db
+					.shorteventid_shortstatehash
+					.aqry::<BUFSIZE, _>(&shorteventid)
+			})
+			.await
+			.deserialized()
 	}
 
 	/// Returns the full room state.
@@ -384,13 +483,23 @@ impl Service {
 		&self,
 		room_id: &RoomId,
 	) -> Result<HashMap<(StateEventType, String), PduEvent>> {
-		self.db.room_state_full(room_id).await
+		self.services
+			.state
+			.get_room_shortstatehash(room_id)
+			.and_then(|shortstatehash| self.state_full(shortstatehash))
+			.map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}")))
+			.await
 	}
 
 	/// Returns the full room state pdus
 	#[tracing::instrument(skip(self), level = "debug")]
 	pub async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<PduEvent>> {
-		self.db.room_state_full_pdus(room_id).await
+		self.services
+			.state
+			.get_room_shortstatehash(room_id)
+			.and_then(|shortstatehash| self.state_full_pdus(shortstatehash))
+			.map_err(|e| err!(Database("Missing state pdus for {room_id:?}: {e:?}")))
+			.await
 	}
 
 	/// Returns a single EventId from `room_id` with key (`event_type`,
@@ -406,8 +515,10 @@ impl Service {
 		Id: for<'de> Deserialize<'de> + Sized + ToOwned,
 		<Id as ToOwned>::Owned: Borrow<EventId>,
 	{
-		self.db
-			.room_state_get_id(room_id, event_type, state_key)
+		self.services
+			.state
+			.get_room_shortstatehash(room_id)
+			.and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key))
 			.await
 	}
 
@@ -420,7 +531,11 @@ impl Service {
 		event_type: &StateEventType,
 		state_key: &str,
 	) -> Result<PduEvent> {
-		self.db.room_state_get(room_id, event_type, state_key).await
+		self.services
+			.state
+			.get_room_shortstatehash(room_id)
+			.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
+			.await
 	}
 
 	/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).