de-arc state_full_ids

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-30 08:09:51 +00:00
parent b5266ad9f5
commit 4a3cc9fffa
9 changed files with 69 additions and 39 deletions

View file

@ -1,4 +1,4 @@
use std::iter::once; use std::{collections::HashMap, iter::once};
use axum::extract::State; use axum::extract::State;
use conduit::{ use conduit::{
@ -10,7 +10,7 @@ use futures::{future::try_join, StreamExt, TryFutureExt};
use ruma::{ use ruma::{
api::client::{context::get_context, filter::LazyLoadOptions}, api::client::{context::get_context, filter::LazyLoadOptions},
events::StateEventType, events::StateEventType,
UserId, OwnedEventId, UserId,
}; };
use crate::{ use crate::{
@ -124,7 +124,7 @@ pub(crate) async fn get_context_route(
.await .await
.map_err(|e| err!(Database("State hash not found: {e}")))?; .map_err(|e| err!(Database("State hash not found: {e}")))?;
let state_ids = services let state_ids: HashMap<_, OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)

View file

@ -32,7 +32,7 @@ use ruma::{
TimelineEventType::*, TimelineEventType::*,
}, },
serde::Raw, serde::Raw,
uint, DeviceId, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId, uint, DeviceId, EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
}; };
use tracing::{Instrument as _, Span}; use tracing::{Instrument as _, Span};
@ -398,7 +398,7 @@ async fn handle_left_room(
Err(_) => HashMap::new(), Err(_) => HashMap::new(),
}; };
let Ok(left_event_id) = services let Ok(left_event_id): Result<OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str()) .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())
@ -666,7 +666,7 @@ async fn load_joined_room(
let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?; let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?;
let current_state_ids = services let current_state_ids: HashMap<_, OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)
@ -736,7 +736,7 @@ async fn load_joined_room(
let mut delta_state_events = Vec::new(); let mut delta_state_events = Vec::new();
if since_shortstatehash != current_shortstatehash { if since_shortstatehash != current_shortstatehash {
let current_state_ids = services let current_state_ids: HashMap<_, OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)

View file

@ -1,6 +1,6 @@
use std::{ use std::{
cmp::{self, Ordering}, cmp::{self, Ordering},
collections::{BTreeMap, BTreeSet, HashSet}, collections::{BTreeMap, BTreeSet, HashMap, HashSet},
time::Duration, time::Duration,
}; };
@ -30,7 +30,7 @@ use ruma::{
TimelineEventType::{self, *}, TimelineEventType::{self, *},
}, },
state_res::Event, state_res::Event,
uint, MilliSecondsSinceUnixEpoch, OwnedRoomId, UInt, UserId, uint, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, UInt, UserId,
}; };
use service::{rooms::read_receipt::pack_receipts, Services}; use service::{rooms::read_receipt::pack_receipts, Services};
@ -211,7 +211,7 @@ pub(crate) async fn sync_events_v4_route(
let new_encrypted_room = encrypted_room && since_encryption.is_err(); let new_encrypted_room = encrypted_room && since_encryption.is_err();
if encrypted_room { if encrypted_room {
let current_state_ids = services let current_state_ids: HashMap<_, OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)

View file

@ -1,6 +1,6 @@
#![allow(deprecated)] #![allow(deprecated)]
use std::borrow::Borrow; use std::{borrow::Borrow, collections::HashMap};
use axum::extract::State; use axum::extract::State;
use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
@ -11,7 +11,7 @@ use ruma::{
room::member::{MembershipState, RoomMemberEventContent}, room::member::{MembershipState, RoomMemberEventContent},
StateEventType, StateEventType,
}, },
CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, CanonicalJsonValue, OwnedEventId, OwnedServerName, OwnedUserId, RoomId, ServerName,
}; };
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use service::Services; use service::Services;
@ -165,7 +165,7 @@ async fn create_join_event(
drop(mutex_lock); drop(mutex_lock);
let state_ids = services let state_ids: HashMap<_, OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)

View file

@ -3,7 +3,7 @@ use std::{borrow::Borrow, iter::once};
use axum::extract::State; use axum::extract::State;
use conduit::{err, result::LogErr, utils::IterStream, Result}; use conduit::{err, result::LogErr, utils::IterStream, Result};
use futures::{FutureExt, StreamExt, TryStreamExt}; use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::api::federation::event::get_room_state; use ruma::{api::federation::event::get_room_state, OwnedEventId};
use super::AccessCheck; use super::AccessCheck;
use crate::Ruma; use crate::Ruma;
@ -30,14 +30,18 @@ pub(crate) async fn get_room_state_route(
.await .await
.map_err(|_| err!(Request(NotFound("PDU state not found."))))?; .map_err(|_| err!(Request(NotFound("PDU state not found."))))?;
let pdus = services let state_ids: Vec<OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await .await
.log_err() .log_err()
.map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))?
.values() .into_values()
.collect();
let pdus = state_ids
.iter()
.try_stream() .try_stream()
.and_then(|id| services.rooms.timeline.get_pdu_json(id)) .and_then(|id| services.rooms.timeline.get_pdu_json(id))
.and_then(|pdu| { .and_then(|pdu| {

View file

@ -3,7 +3,7 @@ use std::{borrow::Borrow, iter::once};
use axum::extract::State; use axum::extract::State;
use conduit::{err, Result}; use conduit::{err, Result};
use futures::StreamExt; use futures::StreamExt;
use ruma::api::federation::event::get_room_state_ids; use ruma::{api::federation::event::get_room_state_ids, OwnedEventId};
use super::AccessCheck; use super::AccessCheck;
use crate::Ruma; use crate::Ruma;
@ -31,14 +31,13 @@ pub(crate) async fn get_room_state_ids_route(
.await .await
.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
let pdu_ids = services let pdu_ids: Vec<OwnedEventId> = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await .await
.map_err(|_| err!(Request(NotFound("State ids not found"))))? .map_err(|_| err!(Request(NotFound("State ids not found"))))?
.into_values() .into_values()
.map(|id| (*id).to_owned())
.collect(); .collect();
let auth_chain_ids = services let auth_chain_ids = services

View file

@ -1,7 +1,7 @@
mod tests; mod tests;
use std::{ use std::{
collections::VecDeque, collections::{HashMap, VecDeque},
fmt::{Display, Formatter}, fmt::{Display, Formatter},
str::FromStr, str::FromStr,
sync::Arc, sync::Arc,
@ -572,7 +572,7 @@ impl Service {
return Ok(None); return Ok(None);
}; };
let state = self let state: HashMap<_, Arc<_>> = self
.services .services
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc}; use std::{borrow::Borrow, collections::HashMap, sync::Arc};
use conduit::{ use conduit::{
at, err, at, err,
@ -8,6 +8,7 @@ use conduit::{
use database::{Deserialized, Map}; use database::{Deserialized, Map};
use futures::{StreamExt, TryFutureExt}; use futures::{StreamExt, TryFutureExt};
use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId}; use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId};
use serde::Deserialize;
use crate::{ use crate::{
rooms, rooms,
@ -84,7 +85,11 @@ impl Data {
Ok(full_pdus) Ok(full_pdus)
} }
pub(super) async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> { pub(super) async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<ShortStateKey, Id>>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
let short_ids = self.state_full_shortids(shortstatehash).await?; let short_ids = self.state_full_shortids(shortstatehash).await?;
let event_ids = self let event_ids = self
@ -123,11 +128,15 @@ impl Data {
Ok(shortids) Ok(shortids)
} }
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). /// Returns a single EventId from `room_id` with key
#[allow(clippy::unused_self)] /// (`event_type`,`state_key`).
pub(super) async fn state_get_id( pub(super) async fn state_get_id<Id>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<EventId>> { ) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
let shortstatekey = self let shortstatekey = self
.services .services
.short .short
@ -162,7 +171,7 @@ impl Data {
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<PduEvent> { ) -> Result<PduEvent> {
self.state_get_id(shortstatehash, event_type, state_key) self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) .and_then(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await })
.await .await
} }
@ -204,10 +213,15 @@ impl Data {
.await .await
} }
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). /// Returns a single EventId from `room_id` with key
pub(super) async fn room_state_get_id( /// (`event_type`,`state_key`).
pub(super) async fn room_state_get_id<Id>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<EventId>> { ) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.services self.services
.state .state
.get_room_shortstatehash(room_id) .get_room_shortstatehash(room_id)

View file

@ -1,6 +1,7 @@
mod data; mod data;
use std::{ use std::{
borrow::Borrow,
collections::HashMap, collections::HashMap,
fmt::Write, fmt::Write,
sync::{Arc, Mutex as StdMutex, Mutex}, sync::{Arc, Mutex as StdMutex, Mutex},
@ -101,8 +102,12 @@ impl Service {
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> { pub async fn state_full_ids<Id>(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Id>>
self.db.state_full_ids(shortstatehash).await where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db.state_full_ids::<Id>(shortstatehash).await
} }
#[inline] #[inline]
@ -118,12 +123,16 @@ impl Service {
self.db.state_full(shortstatehash).await self.db.state_full(shortstatehash).await
} }
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single EventId from `room_id` with key (`event_type`,
/// `state_key`). /// `state_key`).
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_id( pub async fn state_get_id<Id>(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<EventId>> { ) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db self.db
.state_get_id(shortstatehash, event_type, state_key) .state_get_id(shortstatehash, event_type, state_key)
.await .await
@ -321,12 +330,16 @@ impl Service {
self.db.room_state_full_pdus(room_id).await self.db.room_state_full_pdus(room_id).await
} }
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single EventId from `room_id` with key (`event_type`,
/// `state_key`). /// `state_key`).
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_get_id( pub async fn room_state_get_id<Id>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<EventId>> { ) -> Result<Id>
where
Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>,
{
self.db self.db
.room_state_get_id(room_id, event_type, state_key) .room_state_get_id(room_id, event_type, state_key)
.await .await