diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index 22911b28..aea57114 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -161,7 +161,7 @@ impl Serialize for Point { S: Serializer, { let point_b58 = format!("{}", self); - serializer.serialize_str(&point_b58) + serializer.serialize_newtype_struct("pod2::Point", &point_b58) } } diff --git a/src/backends/plonky2/primitives/ec/schnorr.rs b/src/backends/plonky2/primitives/ec/schnorr.rs index f6b51a3b..90aa6c9d 100644 --- a/src/backends/plonky2/primitives/ec/schnorr.rs +++ b/src/backends/plonky2/primitives/ec/schnorr.rs @@ -244,7 +244,7 @@ impl Serialize for SecretKey { S: Serializer, { let sk_b64 = serialize_bytes(&self.as_bytes()); - serializer.serialize_str(&sk_b64) + serializer.serialize_newtype_struct("pod2::SecretKey", &sk_b64) } } diff --git a/src/lib.rs b/src/lib.rs index 4b7de653..76c6047e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::uninlined_format_args)] // TODO: Remove this in another PR #![allow(clippy::manual_repeat_n)] // TODO: Remove this in another PR #![allow(clippy::large_enum_variant)] // TODO: Remove this in another PR +#![feature(macro_metavar_expr_concat)] #![feature(mapped_lock_guards)] pub mod backends; diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 97ed113f..b4d363e5 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -59,6 +59,9 @@ pub const SELF_ID_HASH: Hash = Hash([F(0x5), F(0xe), F(0x1), F(0xf)]); pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +// use pod2:: prefix to help ValueSerializer recognize RawValue +// most serializers will ignore the name +#[serde(rename = "pod2::RawValue")] pub struct RawValue( #[serde( serialize_with = "serialize_value_tuple", diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 8ae37608..003195b9 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -19,6 +19,7 @@ mod error; mod operation; mod pod_deserialization; pub mod serialization; +pub mod serializer; mod statement; use std::{any::Any, fmt}; diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index 68e6efbd..3f5ce4e5 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -4,28 +4,31 @@ use std::{ }; use plonky2::field::types::Field; -use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize, Serializer}; use super::{Key, Value}; use crate::middleware::{F, HASH_SIZE, VALUE_SIZE}; -fn serialize_field_tuple( - value: &[F; N], - serializer: S, -) -> Result -where - S: serde::Serializer, -{ +pub(super) fn field_array_to_string(value: &[F; N]) -> String { // `value` is little-endian in memory. We serialize it as a big-endian hex string // for human readability. - let s = value + value .iter() .rev() .fold(String::with_capacity(N * 16), |mut s, limb| { write!(s, "{:016x}", limb.0).unwrap(); s - }); - serializer.serialize_str(&s) + }) +} + +fn serialize_field_tuple( + value: &[F; N], + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + serializer.serialize_str(&field_array_to_string(value)) } fn deserialize_field_tuple<'de, D, const N: usize>(deserializer: D) -> Result<[F; N], D::Error> @@ -131,17 +134,12 @@ where // Sets are serialized as sequences of elements, which are not ordered by // default. We want to serialize them in a deterministic way, and we can -// achieve this by sorting the elements. This takes advantage of the fact that -// Value implements Ord. +// achieve this by sorting the elements. pub fn ordered_set(value: &HashSet, serializer: S) -> Result where S: Serializer, { - let mut set = serializer.serialize_seq(Some(value.len()))?; let mut sorted_values: Vec<&Value> = value.iter().collect(); sorted_values.sort_by_key(|v| v.raw()); - for v in sorted_values { - set.serialize_element(v)?; - } - set.end() + serializer.serialize_newtype_struct("pod2::Set", &sorted_values) } diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs new file mode 100644 index 00000000..e5764390 --- /dev/null +++ b/src/middleware/serializer.rs @@ -0,0 +1,1628 @@ +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, +}; + +use base64::{prelude::BASE64_STANDARD, Engine}; +use serde::{ + de::{ + value::{ + MapAccessDeserializer, MapDeserializer, SeqDeserializer, StrDeserializer, + U32Deserializer, + }, + IntoDeserializer, Unexpected, + }, + forward_to_deserialize_any, + ser::{ + Impossible, SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, + SerializeTuple, SerializeTupleStruct, SerializeTupleVariant, + }, + Serialize, Serializer, +}; + +use super::{Key, Value}; +use crate::{ + backends::plonky2::{ + deserialize_bytes, + primitives::ec::{curve::Point, schnorr::SecretKey}, + serialize_bytes, + }, + frontend::SignedDict, + middleware::{ + containers::{Array, Dictionary, Set}, + field_array_to_string, + serialization::deserialize_value_tuple, + RawValue, TypedValue, + }, +}; + +/// Indicates that `value` should be serialized to a `Set` rather than +/// an `Array`. +/// +/// Serde regards both arrays and sets as "sequences", so the serializer +/// cannot distinguish between the two. By default, [`ValueSerializer`] will +/// serialize sequences to an `Array`. This function hints to the serializer +/// that it should serialize `value` as a `Set` instead. If `value` is not a sequence, +/// then the hint has no effect. +/// ``` +/// use pod2::middleware::{Key, TypedValue, +/// serializer::{DictionarySerializer, serialize_seq_as_set}}; +/// use std::collections::HashSet; +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct ExampleStruct { +/// arr: HashSet, +/// #[serde(serialize_with = "serialize_seq_as_set")] +/// set: HashSet +/// } +/// +/// let set: HashSet = [2].into_iter().collect(); +/// let ex = ExampleStruct { +/// arr: set.clone(), +/// set +/// }; +/// let d = ex.serialize(DictionarySerializer::new(6)).unwrap(); +/// assert!(matches!(d.get(&Key::from("arr")).unwrap().typed(), TypedValue::Array(_))); +/// assert!(matches!(d.get(&Key::from("set")).unwrap().typed(), TypedValue::Set(_))); +/// ``` +pub fn serialize_seq_as_set( + value: &T, + serializer: S, +) -> Result { + serializer.serialize_newtype_struct("pod2::AsSet", value) +} + +#[derive(Clone, Copy)] +enum ValueSerializerState { + Default, + RawValue, + Point, + SecretKey, + AsSet, +} + +#[derive(Clone, Copy)] +pub struct ValueSerializer { + container_depth: usize, + state: ValueSerializerState, +} + +#[derive(Clone, Copy)] +struct OptionValueSerializer(ValueSerializer); + +#[derive(Clone, Copy)] +pub struct DictionarySerializer { + container_depth: usize, +} + +enum SeqData { + Array(Vec), + Set(HashSet), +} + +pub struct ValueSerializeSeq { + data: SeqData, + container_depth: usize, +} + +pub struct ValueSerializeTupleVariant(DictionarySerializeTupleVariant); + +pub struct ValueSerializeMap(DictionarySerializeMap); + +pub struct ValueSerializeStructVariant(DictionarySerializeStructVariant); + +pub struct DictionarySerializeTupleVariant { + name: &'static str, + inner: ValueSerializeSeq, +} + +pub struct DictionarySerializeMap { + kvs: HashMap, + next_key: Option, + container_depth: usize, +} + +pub struct DictionarySerializeStructVariant { + name: &'static str, + inner: ValueSerializeMap, +} + +struct OptionValueSerializeSeq(ValueSerializeSeq); +struct OptionValueSerializeMap(ValueSerializeMap); +struct OptionValueSerializeStructVariant(ValueSerializeStructVariant); +struct OptionValueSerializeTupleVariant(ValueSerializeTupleVariant); + +impl ValueSerializer { + pub fn new(container_depth: usize) -> Self { + Self { + container_depth, + state: ValueSerializerState::Default, + } + } + + fn dictionary_serializer(self) -> DictionarySerializer { + DictionarySerializer { + container_depth: self.container_depth, + } + } + + fn update_state_newtype(&mut self, name: &'static str) { + match name { + "pod2::RawValue" => self.state = ValueSerializerState::RawValue, + "pod2::Point" => self.state = ValueSerializerState::Point, + "pod2::SecretKey" => self.state = ValueSerializerState::SecretKey, + "pod2::AsSet" => self.state = ValueSerializerState::AsSet, + _ => (), + } + } +} + +impl DictionarySerializer { + pub fn new(container_depth: usize) -> Self { + Self { container_depth } + } + + fn value_serializer(self) -> ValueSerializer { + ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + } + } +} + +impl Serializer for ValueSerializer { + type Ok = Value; + type Error = serde::de::value::Error; + type SerializeSeq = ValueSerializeSeq; + type SerializeTuple = ValueSerializeSeq; + type SerializeTupleStruct = ValueSerializeSeq; + type SerializeMap = ValueSerializeMap; + type SerializeStruct = ValueSerializeMap; + type SerializeTupleVariant = ValueSerializeTupleVariant; + type SerializeStructVariant = ValueSerializeStructVariant; + + fn serialize_bool(self, v: bool) -> Result { + Ok(Value::from(v)) + } + + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i32(self, v: i32) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i64(self, v: i64) -> Result { + Ok(Value::from(v)) + } + + fn serialize_u8(self, v: u8) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u16(self, v: u16) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u32(self, v: u32) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u64(self, v: u64) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_f32(self, v: f32) -> Result { + let s = format!("{v:.8e}"); + self.serialize_str(&s) + } + + fn serialize_f64(self, v: f64) -> Result { + let s = format!("{v:.16e}"); + self.serialize_str(&s) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(match self.state { + ValueSerializerState::RawValue => { + let arr = deserialize_value_tuple(StrDeserializer::new(v))?; + Value::from(RawValue(arr)) + } + ValueSerializerState::Point => { + Value::from(Point::from_str(v).map_err(serde::ser::Error::custom)?) + } + ValueSerializerState::SecretKey => { + let bytes = deserialize_bytes(v).map_err(serde::ser::Error::custom)?; + let sk = SecretKey::from_bytes(&bytes).map_err(serde::ser::Error::custom)?; + Value::from(sk) + } + _ => Value::from(v), + }) + } + + fn serialize_char(self, v: char) -> Result { + self.serialize_str(&String::from(v)) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + let s = BASE64_STANDARD.encode(v); + self.serialize_str(&s) + } + + fn serialize_seq(self, _len: Option) -> Result { + Ok(ValueSerializeSeq { + data: match self.state { + ValueSerializerState::AsSet => SeqData::Set(HashSet::new()), + _ => SeqData::Array(vec![]), + }, + container_depth: self.container_depth, + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_unit(self) -> Result { + Ok(Value::from(false)) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct( + mut self, + name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.update_state_newtype(name); + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.dictionary_serializer() + .serialize_newtype_variant(name, variant_index, variant, value) + .map(Value::from) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + let value_serialized = value.serialize(self)?; + let mut hash_set = HashSet::new(); + hash_set.insert(value_serialized); + let set = Set::new(self.container_depth, hash_set).map_err(serde::ser::Error::custom)?; + Ok(Value::from(set)) + } + + fn serialize_none(self) -> Result { + let set = + Set::new(self.container_depth, HashSet::new()).map_err(serde::ser::Error::custom)?; + Ok(Value::from(set)) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + self.dictionary_serializer() + .serialize_tuple_variant(name, variant_index, variant, len) + .map(ValueSerializeTupleVariant) + } + + fn serialize_map(self, len: Option) -> Result { + self.dictionary_serializer() + .serialize_map(len) + .map(ValueSerializeMap) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + self.dictionary_serializer() + .serialize_struct_variant(name, variant_index, variant, len) + .map(ValueSerializeStructVariant) + } +} + +impl SerializeSeq for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let val = value.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + match &mut self.data { + SeqData::Set(s) => { + s.insert(val); + } + SeqData::Array(a) => a.push(val), + } + Ok(()) + } + + fn end(self) -> Result { + match self.data { + SeqData::Set(s) => { + let set = Set::new(self.container_depth, s).map_err(serde::de::Error::custom)?; + Ok(Value::from(set)) + } + SeqData::Array(a) => { + let arr = Array::new(self.container_depth, a).map_err(serde::de::Error::custom)?; + Ok(Value::from(arr)) + } + } + } +} + +impl SerializeTuple for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self) + } +} + +impl SerializeTupleStruct for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self) + } +} + +impl SerializeTupleVariant for ValueSerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + SerializeTupleVariant::end(self.0).map(Value::from) + } +} + +impl SerializeMap for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_key(key) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_value(value) + } + + fn end(self) -> Result { + SerializeMap::end(self.0).map(Value::from) + } +} + +impl SerializeStruct for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeStruct::serialize_field(&mut self.0, key, value) + } + + fn end(self) -> Result { + SerializeStruct::end(self.0).map(Value::from) + } +} + +impl SerializeStructVariant for ValueSerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + SerializeStructVariant::end(self.0).map(Value::from) + } +} + +macro_rules! serialize_type { + ( bytes ) => { + &[u8] + }; + ( str ) => { + &str + }; + ( unit_struct ) => { + &'static str + }; + ( $name: ident ) => { + $name + }; +} + +macro_rules! map_expected { + ( $($item: ident )* ) => { + $( + fn ${concat(serialize_, $item)}(self, _: serialize_type!($item)) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + )* + } +} + +impl Serializer for DictionarySerializer { + type Ok = Dictionary; + type Error = serde::de::value::Error; + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = DictionarySerializeTupleVariant; + type SerializeMap = DictionarySerializeMap; + type SerializeStruct = DictionarySerializeMap; + type SerializeStructVariant = DictionarySerializeStructVariant; + + map_expected!(bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 bytes str char unit_struct); + + fn serialize_none(self) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_unit(self) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + let ser_value = value.serialize(self.value_serializer())?; + let mut map = HashMap::new(); + map.insert(Key::from(variant), ser_value); + Dictionary::new(self.container_depth, map).map_err(serde::ser::Error::custom) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + Ok(DictionarySerializeTupleVariant { + name: variant, + inner: ValueSerializeSeq { + data: SeqData::Array(Vec::new()), + container_depth: self.container_depth, + }, + }) + } + + fn serialize_map(self, _len: Option) -> Result { + Ok(DictionarySerializeMap { + kvs: HashMap::new(), + container_depth: self.container_depth, + next_key: None, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(DictionarySerializeStructVariant { + name: variant, + inner: ValueSerializeMap(self.serialize_map(Some(len))?), + }) + } +} + +impl SerializeTupleVariant for DictionarySerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(&mut self.inner, value) + } + + fn end(self) -> Result { + let max_depth = self.inner.container_depth; + let arr = SerializeSeq::end(self.inner)?; + let mut map = HashMap::new(); + map.insert(Key::new(self.name.to_string()), arr); + let dict = Dictionary::new(max_depth, map).map_err(serde::de::Error::custom)?; + Ok(dict) + } +} + +impl DictionarySerializeMap { + fn convert_key(&self, key: &T) -> Result::Error> + where + T: ?Sized + Serialize, + { + let key_ser = key.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + if let TypedValue::String(s) = key_ser.typed() { + Ok(Key::new(s.clone())) + } else { + Err(serde::de::Error::invalid_value( + Unexpected::Other("non-string key in map"), + &"string", + )) + } + } +} + +impl SerializeMap for DictionarySerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.next_key = Some(self.convert_key(key)?); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let val_ser = value.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + self.kvs.insert( + self.next_key + .take() + .expect("serialize_key should be called before serialize_value"), + val_ser, + ); + Ok(()) + } + + fn end(self) -> Result { + let dict = + Dictionary::new(self.container_depth, self.kvs).map_err(serde::ser::Error::custom)?; + Ok(dict) + } +} + +impl SerializeStruct for DictionarySerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + //SerializeMap::serialize_entry(self, key, value) + let opt_ser = value.serialize(OptionValueSerializer(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + }))?; + if let Some(val_ser) = opt_ser { + let key = self.convert_key(key)?; + self.kvs.insert(key, val_ser); + } + Ok(()) + } + + fn end(self) -> Result { + SerializeMap::end(self) + } +} + +impl SerializeStructVariant for DictionarySerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeMap::serialize_entry(&mut self.inner, key, value) + } + + fn end(self) -> Result { + let depth = self.inner.0.container_depth; + let value = SerializeMap::end(self.inner)?; + let mut kvs = HashMap::new(); + kvs.insert(Key::new(self.name.to_string()), value); + let dict = Dictionary::new(depth, kvs).map_err(serde::ser::Error::custom)?; + Ok(dict) + } +} + +macro_rules! option_serializer_forward { + ( $($item: ident )* ) => { + $( + fn ${concat(serialize_, $item)}(self, v: serialize_type!($item)) -> Result { + self.0.${concat(serialize_, $item)}(v).map(Some) + } + )* + } +} + +impl Serializer for OptionValueSerializer { + type Ok = Option; + type Error = ::Error; + type SerializeSeq = OptionValueSerializeSeq; + type SerializeTuple = OptionValueSerializeSeq; + type SerializeTupleStruct = OptionValueSerializeSeq; + type SerializeMap = OptionValueSerializeMap; + type SerializeStruct = OptionValueSerializeMap; + type SerializeTupleVariant = OptionValueSerializeTupleVariant; + type SerializeStructVariant = OptionValueSerializeStructVariant; + + option_serializer_forward!(bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char bytes str); + + fn serialize_none(self) -> Result { + Ok(None) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self.0).map(Some) + } + + fn serialize_unit(self) -> Result { + self.0.serialize_unit().map(Some) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + self.0.serialize_unit_struct(name).map(Some) + } + + fn serialize_unit_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result { + self.0 + .serialize_unit_variant(name, variant_index, variant) + .map(Some) + } + + fn serialize_newtype_struct( + mut self, + name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.0.update_state_newtype(name); + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.0 + .serialize_newtype_variant(name, variant_index, variant, value) + .map(Some) + } + + fn serialize_seq(self, len: Option) -> Result { + self.0.serialize_seq(len).map(OptionValueSerializeSeq) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.0.serialize_tuple(len).map(OptionValueSerializeSeq) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_tuple_struct(name, len) + .map(OptionValueSerializeSeq) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_tuple_variant(name, variant_index, variant, len) + .map(OptionValueSerializeTupleVariant) + } + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_struct(name, len) + .map(OptionValueSerializeMap) + } + + fn serialize_map(self, len: Option) -> Result { + self.0.serialize_map(len).map(OptionValueSerializeMap) + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_struct_variant(name, variant_index, variant, len) + .map(OptionValueSerializeStructVariant) + } +} + +impl SerializeSeq for OptionValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(&mut self.0, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self.0).map(Some) + } +} + +impl SerializeTuple for OptionValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeTuple::serialize_element(&mut self.0, value) + } + + fn end(self) -> Result { + SerializeTuple::end(self.0).map(Some) + } +} + +impl SerializeTupleStruct for OptionValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + SerializeTupleStruct::end(self.0).map(Some) + } +} + +impl SerializeMap for OptionValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_key(key) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_value(value) + } + + fn end(self) -> Result { + SerializeMap::end(self.0).map(Some) + } +} + +impl SerializeStruct for OptionValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + SerializeStruct::end(self.0).map(Some) + } +} + +impl SerializeStructVariant for OptionValueSerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + self.0.end().map(Some) + } +} + +impl SerializeTupleVariant for OptionValueSerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + self.0.end().map(Some) + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a TypedValue { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a Value { + type Deserializer = &'a TypedValue; + fn into_deserializer(self) -> Self::Deserializer { + self.typed() + } +} + +impl<'a, 'de, E: serde::de::Error> IntoDeserializer<'de, E> for &'a Key { + type Deserializer = StrDeserializer<'a, E>; + fn into_deserializer(self) -> Self::Deserializer { + StrDeserializer::new(&self.name) + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a Dictionary { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a SignedDict { + type Deserializer = &'a Dictionary; + fn into_deserializer(self) -> Self::Deserializer { + &self.dict + } +} + +impl<'de> serde::Deserializer<'de> for &Dictionary { + type Error = serde::de::value::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_map(MapDeserializer::new(self.kvs().iter())) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_enum(MapAccessDeserializer::new(MapDeserializer::new( + self.kvs().iter(), + ))) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_map(MapDeserializer::new( + self.kvs().iter().map(|(k, v)| (k, StructField(v.typed()))), + )) + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 f32 f64 u8 u16 u32 u64 char str bytes byte_buf string option unit unit_struct seq tuple_struct tuple map newtype_struct identifier } +} + +impl<'de> serde::Deserializer<'de> for &TypedValue { + type Error = serde::de::value::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Int(i) => visitor.visit_i64(*i), + TypedValue::Raw(v) => visitor.visit_string(field_array_to_string(&v.0)), + TypedValue::PublicKey(k) => visitor.visit_string(k.to_string()), + TypedValue::SecretKey(k) => visitor.visit_string(serialize_bytes(&k.as_bytes())), + TypedValue::Bool(b) => visitor.visit_bool(*b), + TypedValue::Array(a) => visitor.visit_seq(SeqDeserializer::new(a.array().iter())), + TypedValue::Set(s) => visitor.visit_seq(SeqDeserializer::new(s.set().iter())), + TypedValue::String(s) => visitor.visit_str(s), + TypedValue::Dictionary(d) => d.deserialize_any(visitor), + } + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => visitor.visit_enum(StrDeserializer::new(s)), + TypedValue::Int(i) => { + if let Ok(u) = u32::try_from(*i) { + visitor.visit_enum(U32Deserializer::new(u)) + } else { + self.deserialize_any(visitor) + } + } + TypedValue::Dictionary(d) => d.deserialize_enum(name, variants, visitor), + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Set(s) if s.set().len() <= 1 => match s.set().iter().next() { + Some(x) => visitor.visit_some(x.typed()), + None => visitor.visit_none(), + }, + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + let b = BASE64_STANDARD + .decode(s) + .map_err(|e| serde::de::Error::custom(e.to_string()))?; + visitor.visit_bytes(&b) + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + let b = BASE64_STANDARD + .decode(s) + .map_err(|e| serde::de::Error::custom(e.to_string()))?; + visitor.visit_byte_buf(b) + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if let Ok(f) = f32::from_str(s) { + visitor.visit_f32(f) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if let Ok(f) = f64::from_str(s) { + visitor.visit_f64(f) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Dictionary(d) => d.deserialize_struct(name, fields, visitor), + _ => self.deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 char str string unit_struct seq tuple_struct tuple map identifier ignored_any } +} + +struct StructField<'a>(&'a TypedValue); + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for StructField<'a> { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +macro_rules! deserialize_forward { + ( $($item: ident )* ) => { + $( + fn ${concat(deserialize_, $item)}(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.${concat(deserialize_, $item)}(visitor) + } + )* + } +} + +impl<'de> serde::Deserializer<'de> for StructField<'_> { + type Error = serde::de::value::Error; + + deserialize_forward!(any bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str char string bytes byte_buf unit seq ignored_any map identifier); + + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_unit_struct(name, visitor) + } + + fn deserialize_newtype_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_newtype_struct(name, visitor) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_tuple(len, visitor) + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_struct(name, fields, visitor) + } + + fn deserialize_tuple_struct( + self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_tuple_struct(name, len, visitor) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_enum(name, variants, visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_some(self.0) + } +} + +#[cfg(test)] +mod test { + + use std::collections::{HashMap, HashSet}; + + use serde::{ + de::{DeserializeOwned, Visitor}, + Deserialize, Serialize, + }; + + use crate::{ + backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, + middleware::{ + serializer::{serialize_seq_as_set, DictionarySerializer, ValueSerializer}, + Params, RawValue, TypedValue, Value, + }, + }; + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + enum Method { + Search, + Mine, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + struct Inner { + ch: char, + b: bool, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + struct Tuple(u8, u32); + + #[derive(PartialEq, Eq, Debug, Clone)] + struct Bytes(Vec); + + struct BytesVisitor; + + impl<'de> Visitor<'de> for BytesVisitor { + type Value = Bytes; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a byte buffer") + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: serde::de::Error, + { + Ok(Bytes(v)) + } + } + + impl Serialize for Bytes { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0) + } + } + + impl<'de> Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_byte_buf(BytesVisitor) + } + } + + #[derive(Serialize, Deserialize, Debug, Clone)] + struct Float(f32); + + impl PartialEq for Float { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 || (self.0.is_nan() && other.0.is_nan()) + } + } + + impl Eq for Float {} + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + enum Fancy { + B(i64, i64), + C { x: i64, y: Vec }, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + struct ComplicatedStruct { + id: i64, + name: String, + seed_range: Vec<(Method, RawValue)>, + option1: Option, + option2: Option, + vec_opt: Vec>, + optopt1: Option>, + optopt2: Option>, + optopt3: Option>, + unit: (), + sk: SecretKey, + pk: Point, + fancy1: Fancy, + fancy2: Fancy, + inner: Inner, + tuple: Tuple, + bytes: Bytes, + float: Float, + inf: Float, + nan: Float, + map: HashMap, + #[serde(serialize_with = "serialize_seq_as_set")] + set: HashSet, + } + + fn test_roundtrip(t: T) { + let depth = Params::default().max_depth_mt_containers; + let val = t.serialize(ValueSerializer::new(depth)).unwrap(); + let out: T = Deserialize::deserialize(val.typed()).unwrap(); + assert_eq!(t, out); + } + + fn test_roundtrip_dict(t: T) { + let depth = Params::default().max_depth_mt_containers; + let dict = t.serialize(DictionarySerializer::new(depth)).unwrap(); + let out: T = Deserialize::deserialize(&dict).unwrap(); + assert_eq!(t, out); + } + + fn test_dict_consistency(t: T) { + let depth = Params::default().max_depth_mt_containers; + let val = t.serialize(ValueSerializer::new(depth)).unwrap(); + let dict = t.serialize(DictionarySerializer::new(depth)).unwrap(); + assert_eq!(val, Value::from(dict)); + } + + fn test_dict_consistency_and_roundtrip< + T: Serialize + DeserializeOwned + Eq + Clone + std::fmt::Debug, + >( + t: T, + ) { + test_dict_consistency(t.clone()); + test_roundtrip_dict(t) + } + + fn test_preserved_ser(t: T) + where + TypedValue: From, + { + let depth = Params::default().max_depth_mt_containers; + let ser = t.serialize(ValueSerializer::new(depth)).unwrap(); + let val = Value::from(TypedValue::from(t)); + assert_eq!(ser, val); + } + + fn test_preserved_de(t: T) + where + TypedValue: From, + { + let de = T::deserialize(&TypedValue::from(t.clone())).unwrap(); + assert_eq!(de, t); + } + + fn test_preserved(t: T) + where + TypedValue: From, + { + test_preserved_ser(t.clone()); + test_preserved_de(t); + } + + fn a_hash_map() -> HashMap { + let mut map = HashMap::new(); + map.insert("a".to_string(), 1); + map.insert("b".to_string(), 2); + map + } + + fn a_hash_set() -> HashSet { + [2, 3].into_iter().collect() + } + + #[test] + fn test_complicated_struct() { + let seed_range = vec![ + (Method::Search, RawValue::default()), + (Method::Mine, RawValue::default()), + ]; + let sk = SecretKey::new_rand(); + let pk = sk.public_key(); + let desc = ComplicatedStruct { + id: 1, + name: "a frog".to_string(), + seed_range, + option1: Some(2), + option2: None, + optopt1: Some(Some(4)), + optopt2: Some(None), + optopt3: None, + vec_opt: vec![Some(3), None], + unit: (), + sk, + pk, + fancy1: Fancy::B(0, 1), + fancy2: Fancy::C { + x: 1, + y: vec![2, 3], + }, + inner: Inner { + ch: '\u{200b}', + b: true, + }, + tuple: Tuple(5, 6), + bytes: Bytes(b"abc".to_vec()), + float: Float(3.0), + inf: Float(f32::NEG_INFINITY), + nan: Float(f32::NAN), + map: a_hash_map(), + set: a_hash_set(), + }; + test_roundtrip(desc.clone()); + test_dict_consistency_and_roundtrip(desc); + } + + #[test] + fn test_dict_serialization() { + test_dict_consistency_and_roundtrip(Fancy::B(0, 1)); + test_dict_consistency_and_roundtrip(Fancy::C { + x: 1, + y: vec![2, 3], + }); + test_dict_consistency_and_roundtrip(a_hash_map()); + } + + #[test] + fn test_pod_types() { + let raw = RawValue::default(); + let sk = SecretKey::new_rand(); + let pt = sk.public_key(); + test_preserved(raw); + test_preserved(sk); + test_preserved(pt); + } +}