diff --git a/src/database/mod.rs b/src/database/mod.rs index e66abf68..c39b2b2f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -10,6 +10,7 @@ pub mod maps; mod opts; mod ser; mod stream; +mod tests; mod util; mod watchers; @@ -28,7 +29,7 @@ pub use self::{ handle::Handle, keyval::{KeyVal, Slice}, map::Map, - ser::{Interfix, Separator}, + ser::{serialize, serialize_to_array, serialize_to_vec, Interfix, Json, Separator}, }; conduit::mod_ctor! {} diff --git a/src/database/ser.rs b/src/database/ser.rs index bd4bbd9a..742f1e34 100644 --- a/src/database/ser.rs +++ b/src/database/ser.rs @@ -1,12 +1,24 @@ use std::io::Write; -use conduit::{err, result::DebugInspect, utils::exchange, Error, Result}; +use arrayvec::ArrayVec; +use conduit::{debug::type_name, err, result::DebugInspect, utils::exchange, Error, Result}; use serde::{ser, Serialize}; #[inline] -pub(crate) fn serialize_to_vec(val: &T) -> Result> +pub fn serialize_to_array(val: T) -> Result> where - T: Serialize + ?Sized, + T: Serialize, +{ + let mut buf = ArrayVec::::new(); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub fn serialize_to_vec(val: T) -> Result> +where + T: Serialize, { let mut buf = Vec::with_capacity(64); serialize(&mut buf, val)?; @@ -15,10 +27,10 @@ where } #[inline] -pub(crate) fn serialize<'a, W, T>(out: &'a mut W, val: &'a T) -> Result<&'a [u8]> +pub fn serialize<'a, W, T>(out: &'a mut W, val: T) -> Result<&'a [u8]> where - W: Write + AsRef<[u8]>, - T: Serialize + ?Sized, + W: Write + AsRef<[u8]> + 'a, + T: Serialize, { let mut serializer = Serializer { out, @@ -43,6 +55,10 @@ pub(crate) struct Serializer<'a, W: Write> { fin: bool, } +/// Newtype for JSON serialization. +#[derive(Debug, Serialize)] +pub struct Json(pub T); + /// Directive to force separator serialization specifically for prefix keying /// use. This is a quirk of the database schema and prefix iterations. #[derive(Debug, Serialize)] @@ -56,38 +72,43 @@ pub struct Separator; impl Serializer<'_, W> { const SEP: &'static [u8] = b"\xFF"; + fn tuple_start(&mut self) { + debug_assert!(!self.sep, "Tuple start with separator set"); + self.sequence_start(); + } + + fn tuple_end(&mut self) -> Result { + self.sequence_end()?; + Ok(()) + } + fn sequence_start(&mut self) { debug_assert!(!self.is_finalized(), "Sequence start with finalization set"); - debug_assert!(!self.sep, "Sequence start with separator set"); - if cfg!(debug_assertions) { - self.depth = self.depth.saturating_add(1); - } + cfg!(debug_assertions).then(|| self.depth = self.depth.saturating_add(1)); } - fn sequence_end(&mut self) { - self.sep = false; - if cfg!(debug_assertions) { - self.depth = self.depth.saturating_sub(1); - } + fn sequence_end(&mut self) -> Result { + cfg!(debug_assertions).then(|| self.depth = self.depth.saturating_sub(1)); + Ok(()) } - fn record_start(&mut self) -> Result<()> { + fn record_start(&mut self) -> Result { debug_assert!(!self.is_finalized(), "Starting a record after serialization finalized"); exchange(&mut self.sep, true) .then(|| self.separator()) .unwrap_or(Ok(())) } - fn separator(&mut self) -> Result<()> { + fn separator(&mut self) -> Result { debug_assert!(!self.is_finalized(), "Writing a separator after serialization finalized"); self.out.write_all(Self::SEP).map_err(Into::into) } + fn write(&mut self, buf: &[u8]) -> Result { self.out.write_all(buf).map_err(Into::into) } + fn set_finalized(&mut self) { debug_assert!(!self.is_finalized(), "Finalization already set"); - if cfg!(debug_assertions) { - self.fin = true; - } + cfg!(debug_assertions).then(|| self.fin = true); } fn is_finalized(&self) -> bool { self.fin } @@ -104,53 +125,65 @@ impl ser::Serializer for &mut Serializer<'_, W> { type SerializeTupleStruct = Self; type SerializeTupleVariant = Self; - fn serialize_map(self, _len: Option) -> Result { - unimplemented!("serialize Map not implemented") - } - fn serialize_seq(self, _len: Option) -> Result { self.sequence_start(); - self.record_start()?; Ok(self) } fn serialize_tuple(self, _len: usize) -> Result { - self.sequence_start(); + self.tuple_start(); Ok(self) } fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { - self.sequence_start(); + self.tuple_start(); Ok(self) } fn serialize_tuple_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, ) -> Result { - self.sequence_start(); - Ok(self) + unimplemented!("serialize Tuple Variant not implemented") + } + + fn serialize_map(self, _len: Option) -> Result { + unimplemented!( + "serialize Map not implemented; did you mean to use database::Json() around your serde_json::Value?" + ) } fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - self.sequence_start(); - Ok(self) + unimplemented!( + "serialize Struct not implemented at this time; did you mean to use database::Json() around your struct?" + ) } fn serialize_struct_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, ) -> Result { - self.sequence_start(); - Ok(self) + unimplemented!("serialize Struct Variant not implemented") } - fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result { - unimplemented!("serialize New Type Struct not implemented") + #[allow(clippy::needless_borrows_for_generic_args)] // buggy + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result + where + T: Serialize + ?Sized, + { + debug_assert!( + name != "Json" || type_name::() != "alloc::boxed::Box", + "serializing a Json(RawValue); you can skip serialization instead" + ); + + match name { + "Json" => serde_json::to_writer(&mut self.out, value).map_err(Into::into), + _ => unimplemented!("Unrecognized serialization Newtype {name:?}"), + } } fn serialize_newtype_variant( self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, ) -> Result { - unimplemented!("serialize New Type Variant not implemented") + unimplemented!("serialize Newtype Variant not implemented") } fn serialize_unit_struct(self, name: &'static str) -> Result { @@ -180,35 +213,94 @@ impl ser::Serializer for &mut Serializer<'_, W> { self.serialize_str(v.encode_utf8(&mut buf)) } - fn serialize_str(self, v: &str) -> Result { self.serialize_bytes(v.as_bytes()) } + fn serialize_str(self, v: &str) -> Result { + debug_assert!( + self.depth > 0, + "serializing string at the top-level; you can skip serialization instead" + ); - fn serialize_bytes(self, v: &[u8]) -> Result { self.out.write_all(v).map_err(Error::Io) } + self.serialize_bytes(v.as_bytes()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + debug_assert!( + self.depth > 0, + "serializing byte array at the top-level; you can skip serialization instead" + ); + + self.write(v) + } fn serialize_f64(self, _v: f64) -> Result { unimplemented!("serialize f64 not implemented") } fn serialize_f32(self, _v: f32) -> Result { unimplemented!("serialize f32 not implemented") } - fn serialize_i64(self, v: i64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + fn serialize_i64(self, v: i64) -> Result { self.write(&v.to_be_bytes()) } - fn serialize_i32(self, _v: i32) -> Result { unimplemented!("serialize i32 not implemented") } + fn serialize_i32(self, v: i32) -> Result { self.write(&v.to_be_bytes()) } fn serialize_i16(self, _v: i16) -> Result { unimplemented!("serialize i16 not implemented") } fn serialize_i8(self, _v: i8) -> Result { unimplemented!("serialize i8 not implemented") } - fn serialize_u64(self, v: u64) -> Result { self.out.write_all(&v.to_be_bytes()).map_err(Error::Io) } + fn serialize_u64(self, v: u64) -> Result { self.write(&v.to_be_bytes()) } - fn serialize_u32(self, _v: u32) -> Result { unimplemented!("serialize u32 not implemented") } + fn serialize_u32(self, v: u32) -> Result { self.write(&v.to_be_bytes()) } fn serialize_u16(self, _v: u16) -> Result { unimplemented!("serialize u16 not implemented") } - fn serialize_u8(self, v: u8) -> Result { self.out.write_all(&[v]).map_err(Error::Io) } + fn serialize_u8(self, v: u8) -> Result { self.write(&[v]) } fn serialize_bool(self, _v: bool) -> Result { unimplemented!("serialize bool not implemented") } fn serialize_unit(self) -> Result { unimplemented!("serialize unit not implemented") } } +impl ser::SerializeSeq for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } + + fn end(self) -> Result { self.sequence_end() } +} + +impl ser::SerializeTuple for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { self.tuple_end() } +} + +impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { self.tuple_end() } +} + +impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field(&mut self, val: &T) -> Result { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result { self.tuple_end() } +} + impl ser::SerializeMap for &mut Serializer<'_, W> { type Error = Error; type Ok = (); @@ -221,95 +313,27 @@ impl ser::SerializeMap for &mut Serializer<'_, W> { unimplemented!("serialize Map Val not implemented") } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } -} - -impl ser::SerializeSeq for &mut Serializer<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_element(&mut self, val: &T) -> Result { val.serialize(&mut **self) } - - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { unimplemented!("serialize Map End not implemented") } } impl ser::SerializeStruct for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) + fn serialize_field(&mut self, _key: &'static str, _val: &T) -> Result { + unimplemented!("serialize Struct Field not implemented") } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { unimplemented!("serialize Struct End not implemented") } } impl ser::SerializeStructVariant for &mut Serializer<'_, W> { type Error = Error; type Ok = (); - fn serialize_field(&mut self, _key: &'static str, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) + fn serialize_field(&mut self, _key: &'static str, _val: &T) -> Result { + unimplemented!("serialize Struct Variant Field not implemented") } - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } -} - -impl ser::SerializeTuple for &mut Serializer<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_element(&mut self, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) - } - - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } -} - -impl ser::SerializeTupleStruct for &mut Serializer<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_field(&mut self, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) - } - - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } -} - -impl ser::SerializeTupleVariant for &mut Serializer<'_, W> { - type Error = Error; - type Ok = (); - - fn serialize_field(&mut self, val: &T) -> Result { - self.record_start()?; - val.serialize(&mut **self) - } - - fn end(self) -> Result { - self.sequence_end(); - Ok(()) - } + fn end(self) -> Result { unimplemented!("serialize Struct Variant End not implemented") } } diff --git a/src/database/tests.rs b/src/database/tests.rs new file mode 100644 index 00000000..47dfb32c --- /dev/null +++ b/src/database/tests.rs @@ -0,0 +1,232 @@ +#![cfg(test)] +#![allow(clippy::needless_borrows_for_generic_args)] + +use std::fmt::Debug; + +use arrayvec::ArrayVec; +use conduit::ruma::{serde::Raw, RoomId, UserId}; +use serde::Serialize; + +use crate::{ + de, ser, + ser::{serialize_to_vec, Json}, + Interfix, +}; + +#[test] +#[should_panic(expected = "serializing string at the top-level")] +fn ser_str() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let s = serialize_to_vec(&user_id).expect("failed to serialize user_id"); + assert_eq!(&s, user_id.as_bytes()); +} + +#[test] +fn ser_tuple() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let mut a = user_id.as_bytes().to_vec(); + a.push(0xFF); + a.extend_from_slice(room_id.as_bytes()); + + let b = (user_id, room_id); + let b = serialize_to_vec(&b).expect("failed to serialize tuple"); + + assert_eq!(a, b); +} + +#[test] +#[should_panic(expected = "I/O error: failed to write whole buffer")] +fn ser_overflow() { + const BUFSIZE: usize = 10; + + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + assert!(BUFSIZE < user_id.as_str().len() + room_id.as_str().len()); + let mut buf = ArrayVec::::new(); + + let val = (user_id, room_id); + _ = ser::serialize(&mut buf, val).unwrap(); +} + +#[test] +fn ser_complex() { + use conduit::ruma::Mxc; + + #[derive(Debug, Serialize)] + struct Dim { + width: u32, + height: u32, + } + + let mxc = Mxc { + server_name: "example.com".try_into().unwrap(), + media_id: "AbCdEfGhIjK", + }; + + let dim = Dim { + width: 123, + height: 456, + }; + + let mut a = Vec::new(); + a.extend_from_slice(b"mxc://"); + a.extend_from_slice(mxc.server_name.as_bytes()); + a.extend_from_slice(b"/"); + a.extend_from_slice(mxc.media_id.as_bytes()); + a.push(0xFF); + a.extend_from_slice(&dim.width.to_be_bytes()); + a.extend_from_slice(&dim.height.to_be_bytes()); + a.push(0xFF); + + let d: &[u32] = &[dim.width, dim.height]; + let b = (mxc, d, Interfix); + let b = serialize_to_vec(b).expect("failed to serialize complex"); + + assert_eq!(a, b); +} + +#[test] +fn ser_json() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let serialized = serialize_to_vec(Json(&filter)).expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn ser_json_value() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::to_value(filter).expect("failed to serialize to serde_json::value"); + let serialized = serialize_to_vec(Json(value)).expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn ser_json_macro() { + use serde_json::json; + + #[derive(Serialize)] + struct Foo { + foo: String, + } + + let content = Foo { + foo: "bar".to_owned(), + }; + let content = serde_json::to_value(content).expect("failed to serialize content"); + let sender: &UserId = "@foo:example.com".try_into().unwrap(); + let serialized = serialize_to_vec(Json(json!({ + "sender": sender, + "content": content, + }))) + .expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"content":{"foo":"bar"},"sender":"@foo:example.com"}"#); +} + +#[test] +#[should_panic(expected = "serializing string at the top-level")] +fn ser_json_raw() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::value::to_raw_value(&filter).expect("failed to serialize to raw value"); + let a = serialize_to_vec(value.get()).expect("failed to serialize raw value"); + let s = String::from_utf8_lossy(&a); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +#[should_panic(expected = "you can skip serialization instead")] +fn ser_json_raw_json() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::value::to_raw_value(&filter).expect("failed to serialize to raw value"); + let a = serialize_to_vec(Json(value)).expect("failed to serialize json value"); + let s = String::from_utf8_lossy(&a); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn de_tuple() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF!room:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +fn de_json_array() { + let a = &["foo", "bar", "baz"]; + let s = serde_json::to_vec(a).expect("failed to serialize to JSON array"); + + let b: Raw>> = de::from_slice(&s).expect("failed to deserialize"); + + let d: Vec = serde_json::from_str(b.json().get()).expect("failed to deserialize JSON"); + + for (i, a) in a.iter().enumerate() { + assert_eq!(*a, d[i]); + } +} + +#[test] +fn de_json_raw_array() { + let a = &["foo", "bar", "baz"]; + let s = serde_json::to_vec(a).expect("failed to serialize to JSON array"); + + let b: Raw>> = de::from_slice(&s).expect("failed to deserialize"); + + let c: Vec> = serde_json::from_str(b.json().get()).expect("failed to deserialize JSON"); + + for (i, a) in a.iter().enumerate() { + let c = serde_json::to_value(c[i].json()).expect("failed to deserialize JSON to string"); + assert_eq!(*a, c); + } +} + +#[test] +fn ser_array() { + let a: u64 = 123_456; + let b: u64 = 987_654; + + let arr: &[u64] = &[a, b]; + + let mut v = Vec::new(); + v.extend_from_slice(&a.to_be_bytes()); + v.extend_from_slice(&b.to_be_bytes()); + + let s = serialize_to_vec(arr).expect("failed to serialize"); + assert_eq!(&s, &v, "serialization does not match"); +}