From 4776fe66c4a9d5cbb0153e8ff23009d21ed5010e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 28 Sep 2024 15:14:48 +0000 Subject: [PATCH] handle serde_json for deserialized() Signed-off-by: Jason Volk --- src/database/de.rs | 94 ++++++++++++++++++++------- src/database/deserialized.rs | 14 ---- src/database/handle.rs | 28 -------- src/service/account_data/mod.rs | 2 +- src/service/appservice/data.rs | 2 +- src/service/globals/data.rs | 5 +- src/service/key_backups/mod.rs | 12 +--- src/service/pusher/mod.rs | 2 +- src/service/rooms/outlier/mod.rs | 4 +- src/service/rooms/state_cache/data.rs | 10 +-- src/service/rooms/timeline/data.rs | 15 ++--- src/service/uiaa/mod.rs | 2 +- src/service/users/mod.rs | 12 ++-- 13 files changed, 95 insertions(+), 107 deletions(-) diff --git a/src/database/de.rs b/src/database/de.rs index 8ce25aa3..a5d2c127 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -58,10 +58,15 @@ impl<'de> Deserializer<'de> { } #[inline] - fn record_trail(&mut self) -> &'de [u8] { - let record = &self.buf[self.pos..]; - self.inc_pos(record.len()); - record + fn record_next_peek_byte(&self) -> Option { + let started = self.pos != 0; + let buf = &self.buf[self.pos..]; + debug_assert!( + !started || buf[0] == Self::SEP, + "Missing expected record separator at current position" + ); + + buf.get::(started.into()).copied() } #[inline] @@ -75,6 +80,13 @@ impl<'de> Deserializer<'de> { self.inc_pos(started.into()); } + #[inline] + fn record_trail(&mut self) -> &'de [u8] { + let record = &self.buf[self.pos..]; + self.inc_pos(record.len()); + record + } + #[inline] fn inc_pos(&mut self, n: usize) { self.pos = self.pos.saturating_add(n); @@ -85,13 +97,6 @@ impl<'de> Deserializer<'de> { impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; - fn deserialize_map(self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - unimplemented!("deserialize Map not implemented") - } - fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, @@ -113,13 +118,23 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor.visit_seq(self) } - fn deserialize_struct( - self, _name: &'static str, _fields: &'static [&'static str], _visitor: V, - ) -> Result + fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { - unimplemented!("deserialize Struct not implemented") + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_map(visitor).map_err(Into::into) + } + + fn deserialize_struct(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'de>, + { + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_struct(name, fields, visitor) + .map_err(Into::into) } fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result @@ -134,11 +149,14 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor.visit_unit() } - fn deserialize_newtype_struct(self, _name: &'static str, _visitor: V) -> Result + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { - unimplemented!("deserialize Newtype Struct not implemented") + match name { + "$serde_json::private::RawValue" => visitor.visit_map(self), + _ => visitor.visit_newtype_struct(self), + } } fn deserialize_enum( @@ -228,19 +246,31 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } fn deserialize_unit>(self, _visitor: V) -> Result { - unimplemented!("deserialize Unit Struct not implemented") + unimplemented!("deserialize Unit not implemented") } - fn deserialize_identifier>(self, _visitor: V) -> Result { - unimplemented!("deserialize Identifier not implemented") + // this only used for $serde_json::private::RawValue at this time; see MapAccess + fn deserialize_identifier>(self, visitor: V) -> Result { + let input = "$serde_json::private::RawValue"; + visitor.visit_borrowed_str(input) } fn deserialize_ignored_any>(self, _visitor: V) -> Result { unimplemented!("deserialize Ignored Any not implemented") } - fn deserialize_any>(self, _visitor: V) -> Result { - unimplemented!("deserialize any not implemented") + fn deserialize_any>(self, visitor: V) -> Result { + debug_assert_eq!( + conduit::debug::type_name::(), + "serde_json::value::de::::deserialize::ValueVisitor", + "deserialize_any: type not expected" + ); + + match self.record_next_peek_byte() { + Some(b'{') => self.deserialize_map(visitor), + _ => self.deserialize_str(visitor), + } } } @@ -259,3 +289,23 @@ impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> { seed.deserialize(&mut **self).map(Some) } } + +// this only used for $serde_json::private::RawValue at this time. our db +// schema doesn't have its own map format; we use json for that anyway +impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de>, + { + seed.deserialize(&mut **self).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + seed.deserialize(&mut **self) + } +} diff --git a/src/database/deserialized.rs b/src/database/deserialized.rs index 7da112d5..a59b2ce5 100644 --- a/src/database/deserialized.rs +++ b/src/database/deserialized.rs @@ -9,11 +9,6 @@ pub trait Deserialized { F: FnOnce(T) -> U, T: for<'de> Deserialize<'de>; - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>; - #[inline] fn deserialized(self) -> Result where @@ -22,13 +17,4 @@ pub trait Deserialized { { self.map_de(identity::) } - - #[inline] - fn deserialized_json(self) -> Result - where - T: for<'de> Deserialize<'de>, - Self: Sized, - { - self.map_json(identity::) - } } diff --git a/src/database/handle.rs b/src/database/handle.rs index 89d87137..0d4bd02e 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -48,15 +48,6 @@ impl AsRef for Handle<'_> { } impl Deserialized for Result> { - #[inline] - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>, - { - self?.map_json(f) - } - #[inline] fn map_de(self, f: F) -> Result where @@ -68,15 +59,6 @@ impl Deserialized for Result> { } impl<'a> Deserialized for Result<&'a Handle<'a>> { - #[inline] - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>, - { - self.and_then(|handle| handle.map_json(f)) - } - #[inline] fn map_de(self, f: F) -> Result where @@ -88,16 +70,6 @@ impl<'a> Deserialized for Result<&'a Handle<'a>> { } impl<'a> Deserialized for &'a Handle<'a> { - fn map_json(self, f: F) -> Result - where - F: FnOnce(T) -> U, - T: for<'de> Deserialize<'de>, - { - serde_json::from_slice::(self.as_ref()) - .map_err(Into::into) - .map(f) - } - fn map_de(self, f: F) -> Result where F: FnOnce(T) -> U, diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index b4eb143d..4f00cff1 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -108,7 +108,7 @@ pub async fn get( .qry(&key) .and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.qry(&roomuserdataid)) .await - .deserialized_json() + .deserialized() } /// Returns all changes to the account data that happened after `since`. diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index d5fa5476..f31c5e63 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -40,7 +40,7 @@ impl Data { self.id_appserviceregistrations .qry(id) .await - .deserialized_json() + .deserialized() .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 3286e40c..76f97944 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -305,10 +305,7 @@ impl Data { } pub async fn signing_keys_for(&self, origin: &ServerName) -> Result { - self.server_signingkeys - .qry(origin) - .await - .deserialized_json() + self.server_signingkeys.qry(origin).await.deserialized() } pub async fn database_version(&self) -> u64 { self.global.qry("version").await.deserialized().unwrap_or(0) } diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 12712e79..decf32f7 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -166,11 +166,7 @@ pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw Result> { let key = (user_id, version); - self.db - .backupid_algorithm - .qry(&key) - .await - .deserialized_json() + self.db.backupid_algorithm.qry(&key).await.deserialized() } #[implement(Service)] @@ -278,11 +274,7 @@ pub async fn get_session( ) -> Result> { let key = (user_id, version, room_id, session_id); - self.db - .backupkeyid_backup - .qry(&key) - .await - .deserialized_json() + self.db.backupkeyid_backup.qry(&key).await.deserialized() } #[implement(Service)] diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 44ff1945..8d8b553f 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -90,7 +90,7 @@ impl Service { .senderkey_pusher .qry(&senderkey) .await - .deserialized_json() + .deserialized() } pub async fn get_pushers(&self, sender: &UserId) -> Vec { diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 277b5982..4c9225ae 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -33,7 +33,7 @@ pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result Result { .eventid_outlierpdu .qry(event_id) .await - .deserialized_json() + .deserialized() } /// Append the PDU as an outlier. diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 38e504f6..f3ccaf10 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -156,10 +156,7 @@ impl Data { &self, user_id: &UserId, room_id: &RoomId, ) -> Result>> { let key = (user_id, room_id); - self.userroomid_invitestate - .qry(&key) - .await - .deserialized_json() + self.userroomid_invitestate.qry(&key).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] @@ -167,10 +164,7 @@ impl Data { &self, user_id: &UserId, room_id: &RoomId, ) -> Result>> { let key = (user_id, room_id); - self.userroomid_leftstate - .qry(&key) - .await - .deserialized_json() + self.userroomid_leftstate.qry(&key).await.deserialized() } /// Returns an iterator over all rooms a user left. diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index cd746be4..314dcb9f 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -90,17 +90,14 @@ impl Data { return Ok(pdu); } - self.eventid_outlierpdu - .qry(event_id) - .await - .deserialized_json() + self.eventid_outlierpdu.qry(event_id).await.deserialized() } /// Returns the json of a pdu. pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await.deserialized_json() + self.pduid_pdu.qry(&pduid).await.deserialized() } /// Returns the pdu's id. @@ -113,7 +110,7 @@ impl Data { pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.qry(&pduid).await.deserialized_json() + self.pduid_pdu.qry(&pduid).await.deserialized() } /// Like get_non_outlier_pdu(), but without the expense of fetching and @@ -137,7 +134,7 @@ impl Data { self.eventid_outlierpdu .qry(event_id) .await - .deserialized_json() + .deserialized() .map(Arc::new) } @@ -162,12 +159,12 @@ impl Data { /// /// This does __NOT__ check the outliers `Tree`. pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { - self.pduid_pdu.qry(pdu_id).await.deserialized_json() + self.pduid_pdu.qry(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap`. pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { - self.pduid_pdu.qry(pdu_id).await.deserialized_json() + self.pduid_pdu.qry(pdu_id).await.deserialized() } pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 7e231514..0415bfc2 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -238,6 +238,6 @@ async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session .userdevicesessionid_uiaainfo .qry(&key) .await - .deserialized_json() + .deserialized() .map_err(|_| err!(Request(Forbidden("UIAA session does not exist.")))) } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 9a058ba9..ca37ed9d 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -577,7 +577,7 @@ impl Service { .qry(&key) .await .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))? - .deserialized_json() + .deserialized() .map_err(|e| err!(Database("key in keyid_key is invalid. {e:?}")))?; let signatures = cross_signing_key @@ -652,7 +652,7 @@ impl Service { pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result> { let key_id = (user_id, device_id); - self.db.keyid_key.qry(&key_id).await.deserialized_json() + self.db.keyid_key.qry(&key_id).await.deserialized() } pub async fn get_key( @@ -666,7 +666,7 @@ impl Service { .keyid_key .qry(key_id) .await - .deserialized_json::()?; + .deserialized::()?; let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; let raw_value = serde_json::value::to_raw_value(&cleaned)?; @@ -700,7 +700,7 @@ impl Service { pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { let key_id = self.db.userid_usersigningkeyid.qry(user_id).await?; - self.db.keyid_key.qry(&*key_id).await.deserialized_json() + self.db.keyid_key.qry(&*key_id).await.deserialized() } pub async fn add_to_device_event( @@ -791,7 +791,7 @@ impl Service { .userdeviceid_metadata .qry(&(user_id, device_id)) .await - .deserialized_json() + .deserialized() } pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { @@ -830,7 +830,7 @@ impl Service { .userfilterid_filter .qry(&(user_id, filter_id)) .await - .deserialized_json() + .deserialized() } /// Creates an OpenID token, which can be used to prove that a user has