From 44f58d6f7d22d6ebed955d2ccfe9a85e2ccd3ad4 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Tue, 19 Aug 2025 12:52:43 -0700 Subject: [PATCH 1/9] serialization --- src/backends/plonky2/primitives/ec/curve.rs | 2 +- src/backends/plonky2/primitives/ec/schnorr.rs | 2 +- src/middleware/serialization.rs | 646 +++++++++++++++++- 3 files changed, 630 insertions(+), 20 deletions(-) diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index d9e7ab5b..d71191d3 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -150,7 +150,7 @@ impl Serialize for Point { S: Serializer, { let point_b58 = format!("{}", self); - serializer.serialize_str(&point_b58) + serializer.serialize_newtype_struct("Point", &point_b58) } } diff --git a/src/backends/plonky2/primitives/ec/schnorr.rs b/src/backends/plonky2/primitives/ec/schnorr.rs index 6de3bffd..146c7b85 100644 --- a/src/backends/plonky2/primitives/ec/schnorr.rs +++ b/src/backends/plonky2/primitives/ec/schnorr.rs @@ -233,7 +233,7 @@ impl Serialize for SecretKey { S: Serializer, { let sk_b64 = serialize_bytes(&self.as_bytes()); - serializer.serialize_str(&sk_b64) + serializer.serialize_newtype_struct("SecretKey", &sk_b64) } } diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index 68e6efbd..76399f87 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -1,31 +1,56 @@ use std::{ collections::{HashMap, HashSet}, fmt::Write, + str::FromStr, }; use plonky2::field::types::Field; -use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; +use serde::{ + de::{ + value::{MapDeserializer, SeqDeserializer, StrDeserializer, U32Deserializer}, + IntoDeserializer, Unexpected, + }, + forward_to_deserialize_any, + ser::{ + SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, + }, + Deserialize, Serialize, Serializer, +}; use super::{Key, Value}; -use crate::middleware::{F, HASH_SIZE, VALUE_SIZE}; +use crate::{ + backends::plonky2::{ + deserialize_bytes, + primitives::ec::{curve::Point, schnorr::SecretKey}, + serialize_bytes, + }, + middleware::{ + containers::{Array, Dictionary}, + RawValue, TypedValue, F, HASH_SIZE, VALUE_SIZE, + }, +}; -fn serialize_field_tuple( - value: &[F; N], - serializer: S, -) -> Result -where - S: serde::Serializer, -{ +fn field_array_to_string(value: &[F; N]) -> String { // `value` is little-endian in memory. We serialize it as a big-endian hex string // for human readability. - let s = value + value .iter() .rev() .fold(String::with_capacity(N * 16), |mut s, limb| { write!(s, "{:016x}", limb.0).unwrap(); s - }); - serializer.serialize_str(&s) + }) +} + +fn serialize_field_tuple( + value: &[F; N], + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + serializer.serialize_str(&field_array_to_string(value)) } fn deserialize_field_tuple<'de, D, const N: usize>(deserializer: D) -> Result<[F; N], D::Error> @@ -131,17 +156,602 @@ where // Sets are serialized as sequences of elements, which are not ordered by // default. We want to serialize them in a deterministic way, and we can -// achieve this by sorting the elements. This takes advantage of the fact that -// Value implements Ord. +// achieve this by sorting the elements. pub fn ordered_set(value: &HashSet, serializer: S) -> Result where S: Serializer, { - let mut set = serializer.serialize_seq(Some(value.len()))?; let mut sorted_values: Vec<&Value> = value.iter().collect(); sorted_values.sort_by_key(|v| v.raw()); - for v in sorted_values { - set.serialize_element(v)?; + serializer.serialize_newtype_struct("Set", &sorted_values) +} + +#[derive(Clone, Copy)] +enum ValueSerializerState { + Default, + RawValue, + Point, + SecretKey, +} + +#[derive(Clone, Copy)] +pub struct ValueSerializer { + container_depth: usize, + state: ValueSerializerState, +} + +pub struct ValueSerializeSeq { + data: Vec, + container_depth: usize, +} + +pub struct ValueSerializeTupleVariant { + name: &'static str, + inner: ValueSerializeSeq, +} + +pub struct ValueSerializeMap { + kvs: HashMap, + next_key: Option, + container_depth: usize, +} + +pub struct ValueSerializeStructVariant { + name: &'static str, + inner: ValueSerializeMap, +} + +impl ValueSerializer { + pub fn new(container_depth: usize) -> Self { + Self { + container_depth, + state: ValueSerializerState::Default, + } + } +} + +impl Serializer for ValueSerializer { + type Ok = Value; + type Error = serde::de::value::Error; + type SerializeSeq = ValueSerializeSeq; + type SerializeTuple = ValueSerializeSeq; + type SerializeTupleStruct = ValueSerializeSeq; + type SerializeMap = ValueSerializeMap; + type SerializeStruct = ValueSerializeMap; + type SerializeTupleVariant = ValueSerializeTupleVariant; + type SerializeStructVariant = ValueSerializeStructVariant; + + fn serialize_bool(self, v: bool) -> Result { + Ok(Value::from(v)) + } + + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i32(self, v: i32) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i64(self, v: i64) -> Result { + Ok(Value::from(v)) + } + + fn serialize_u8(self, v: u8) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u16(self, v: u16) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u32(self, v: u32) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u64(self, v: u64) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_f32(self, v: f32) -> Result { + self.serialize_f64(v as f64) + } + + fn serialize_f64(self, v: f64) -> Result { + // serialize as string? + Err(serde::de::Error::invalid_value( + Unexpected::Float(v), + &"pod-compatible type", + )) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(match self.state { + ValueSerializerState::RawValue => { + let arr = deserialize_value_tuple(StrDeserializer::new(v))?; + Value::from(RawValue(arr)) + } + ValueSerializerState::Point => { + Value::from(Point::from_str(v).map_err(serde::ser::Error::custom)?) + } + ValueSerializerState::SecretKey => { + let bytes = deserialize_bytes(v).map_err(serde::ser::Error::custom)?; + let sk = SecretKey::from_bytes(&bytes).map_err(serde::ser::Error::custom)?; + Value::from(sk) + } + _ => Value::from(v), + }) + } + + fn serialize_char(self, v: char) -> Result { + self.serialize_str(&String::from(v)) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + // TODO: serialize as Array, or base64 string? + Err(serde::de::Error::invalid_value( + Unexpected::Bytes(v), + &"pod-compatible type", + )) + } + + fn serialize_seq(self, _len: Option) -> Result { + Ok(ValueSerializeSeq { + data: Vec::new(), + container_depth: self.container_depth, + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_unit(self) -> Result { + SerializeTuple::end(self.serialize_tuple(0)?) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct( + mut self, + name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + println!("nt struct {name}"); + match name { + "RawValue" => self.state = ValueSerializerState::RawValue, + "Point" => self.state = ValueSerializerState::Point, + "SecretKey" => self.state = ValueSerializerState::SecretKey, + _ => (), + } + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + let ser_value = value.serialize(self)?; + let mut map = HashMap::new(); + map.insert(Key::from(variant), ser_value); + Ok(Value::from( + Dictionary::new(self.container_depth, map).map_err(serde::de::Error::custom)?, + )) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + self.serialize_newtype_variant("Option", 0, "Some", value) + } + + fn serialize_none(self) -> Result { + self.serialize_unit_variant("Option", 1, "None") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Ok(ValueSerializeTupleVariant { + name, + inner: ValueSerializeSeq { + data: Vec::new(), + container_depth: self.container_depth, + }, + }) + } + + fn serialize_map(self, _len: Option) -> Result { + Ok(ValueSerializeMap { + kvs: HashMap::new(), + container_depth: self.container_depth, + next_key: None, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(ValueSerializeStructVariant { + name: variant, + inner: self.serialize_map(Some(len))?, + }) + } +} + +impl SerializeSeq for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.data.push(value.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?); + Ok(()) + } + + fn end(self) -> Result { + let arr = Array::new(self.container_depth, self.data).map_err(serde::de::Error::custom)?; + Ok(Value::from(arr)) + } +} + +impl SerializeTuple for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self) + } +} + +impl SerializeTupleStruct for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self) + } +} + +impl SerializeTupleVariant for ValueSerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(&mut self.inner, value) + } + + fn end(self) -> Result { + let max_depth = self.inner.container_depth; + let arr = SerializeSeq::end(self.inner)?; + let mut map = HashMap::new(); + map.insert(Key::new(self.name.to_string()), arr); + let dict = Dictionary::new(max_depth, map).map_err(serde::de::Error::custom)?; + Ok(Value::from(dict)) + } +} + +impl SerializeMap for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let key_ser = key.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + if let TypedValue::String(s) = key_ser.typed() { + self.next_key = Some(Key::new(s.clone())); + Ok(()) + } else { + Err(serde::de::Error::invalid_value( + Unexpected::Other("non-string key in map"), + &"string", + )) + } + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let val_ser = value.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + self.kvs.insert( + self.next_key + .take() + .expect("serialize_key should be called before serialize_value"), + val_ser, + ); + Ok(()) + } + + fn end(self) -> Result { + let dict = + Dictionary::new(self.container_depth, self.kvs).map_err(serde::ser::Error::custom)?; + Ok(Value::from(dict)) + } +} + +impl SerializeStruct for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeMap::serialize_entry(self, key, value) + } + + fn end(self) -> Result { + SerializeMap::end(self) + } +} + +impl SerializeStructVariant for ValueSerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeMap::serialize_entry(&mut self.inner, key, value) + } + + fn end(self) -> Result { + let depth = self.inner.container_depth; + let value = SerializeMap::end(self.inner)?; + let mut kvs = HashMap::new(); + kvs.insert(Key::new(self.name.to_string()), value); + let dict = Dictionary::new(depth, kvs).map_err(serde::ser::Error::custom)?; + Ok(Value::from(dict)) + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a TypedValue { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a Value { + type Deserializer = &'a TypedValue; + fn into_deserializer(self) -> Self::Deserializer { + self.typed() + } +} + +impl<'a, 'de, E: serde::de::Error> IntoDeserializer<'de, E> for &'a Key { + type Deserializer = StrDeserializer<'a, E>; + fn into_deserializer(self) -> Self::Deserializer { + StrDeserializer::new(&self.name) + } +} + +impl<'de> serde::Deserializer<'de> for &TypedValue { + type Error = serde::de::value::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Int(i) => visitor.visit_i64(*i), + TypedValue::Raw(v) => visitor.visit_string(field_array_to_string(&v.0)), + TypedValue::PublicKey(k) => visitor.visit_string(k.to_string()), + TypedValue::SecretKey(k) => visitor.visit_string(serialize_bytes(&k.as_bytes())), + TypedValue::Bool(b) => visitor.visit_bool(*b), + TypedValue::Array(a) => visitor.visit_seq(SeqDeserializer::new(a.array().iter())), + TypedValue::Set(s) => visitor.visit_seq(SeqDeserializer::new(s.set().iter())), + TypedValue::String(s) => visitor.visit_str(s), + TypedValue::PodId(i) => { + visitor.visit_seq(SeqDeserializer::new(i.0 .0.iter().map(|x| x.0))) + } + TypedValue::Dictionary(d) => visitor.visit_map(MapDeserializer::new(d.kvs().iter())), + } + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => visitor.visit_enum(StrDeserializer::new(s)), + TypedValue::Int(i) => { + if let Ok(u) = u32::try_from(*i) { + visitor.visit_enum(U32Deserializer::new(u)) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if s == "None" { + visitor.visit_none() + } else { + self.deserialize_any(visitor) + } + } + TypedValue::Dictionary(d) => { + if let Ok(v) = d.get(&Key::from("Some")) { + visitor.visit_some(v.typed()) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Array(a) if a.array().is_empty() => visitor.visit_unit(), + _ => self.deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf unit_struct seq tuple_struct tuple map struct identifier ignored_any } +} + +#[cfg(test)] +mod test { + use serde::{de::DeserializeOwned, Deserialize, Serialize}; + + use crate::{ + backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, + middleware::{serialization::ValueSerializer, Params, RawValue}, + }; + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + enum Method { + Search, + Mine, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct FrogDesc { + frog_id: i64, + name: String, + seed_range: Vec<(Method, RawValue)>, + option1: Option, + option2: Option, + unit: (), + sk: SecretKey, + pk: Point, + } + + fn test_roundtrip(t: T) { + let depth = Params::default().max_depth_mt_containers; + let val = t.serialize(ValueSerializer::new(depth)).unwrap(); + println!("{val:?}"); + let out: T = Deserialize::deserialize(val.typed()).unwrap(); + assert_eq!(t, out); + } + + #[test] + fn test_frog_desc() { + let seed_range = vec![ + (Method::Search, RawValue::default()), + (Method::Mine, RawValue::default()), + ]; + let sk = SecretKey::new_rand(); + let pk = sk.public_key(); + let desc = FrogDesc { + frog_id: 1, + name: "a frog".to_string(), + seed_range, + option1: Some(2), + option2: None, + unit: (), + sk, + pk, + }; + test_roundtrip(desc); } - set.end() } From ee8b87f2bbb8bfb40630186535e0307faac0b759 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Mon, 25 Aug 2025 10:49:28 -0700 Subject: [PATCH 2/9] support for more types --- src/middleware/serialization.rs | 214 +++++++++++++++++++++++++++++--- 1 file changed, 195 insertions(+), 19 deletions(-) diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index 76399f87..8ede5f23 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -4,10 +4,14 @@ use std::{ str::FromStr, }; +use base64::{prelude::BASE64_STANDARD, Engine}; use plonky2::field::types::Field; use serde::{ de::{ - value::{MapDeserializer, SeqDeserializer, StrDeserializer, U32Deserializer}, + value::{ + MapAccessDeserializer, MapDeserializer, SeqDeserializer, StrDeserializer, + U32Deserializer, + }, IntoDeserializer, Unexpected, }, forward_to_deserialize_any, @@ -258,15 +262,13 @@ impl Serializer for ValueSerializer { } fn serialize_f32(self, v: f32) -> Result { - self.serialize_f64(v as f64) + let s = format!("{v:.8e}"); + self.serialize_str(&s) } fn serialize_f64(self, v: f64) -> Result { - // serialize as string? - Err(serde::de::Error::invalid_value( - Unexpected::Float(v), - &"pod-compatible type", - )) + let s = format!("{v:.16e}"); + self.serialize_str(&s) } fn serialize_str(self, v: &str) -> Result { @@ -292,11 +294,8 @@ impl Serializer for ValueSerializer { } fn serialize_bytes(self, v: &[u8]) -> Result { - // TODO: serialize as Array, or base64 string? - Err(serde::de::Error::invalid_value( - Unexpected::Bytes(v), - &"pod-compatible type", - )) + let s = BASE64_STANDARD.encode(v); + self.serialize_str(&s) } fn serialize_seq(self, _len: Option) -> Result { @@ -359,7 +358,7 @@ impl Serializer for ValueSerializer { let mut map = HashMap::new(); map.insert(Key::from(variant), ser_value); Ok(Value::from( - Dictionary::new(self.container_depth, map).map_err(serde::de::Error::custom)?, + Dictionary::new(self.container_depth, map).map_err(serde::ser::Error::custom)?, )) } @@ -384,13 +383,13 @@ impl Serializer for ValueSerializer { fn serialize_tuple_variant( self, - name: &'static str, + _name: &'static str, _variant_index: u32, - _variant: &'static str, + variant: &'static str, _len: usize, ) -> Result { Ok(ValueSerializeTupleVariant { - name, + name: variant, inner: ValueSerializeSeq { data: Vec::new(), container_depth: self.container_depth, @@ -648,6 +647,9 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { self.deserialize_any(visitor) } } + TypedValue::Dictionary(d) => visitor.visit_enum(MapAccessDeserializer::new( + MapDeserializer::new(d.kvs().iter()), + )), _ => self.deserialize_any(visitor), } } @@ -696,16 +698,82 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { } } - forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf unit_struct seq tuple_struct tuple map struct identifier ignored_any } + fn deserialize_bytes(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + let b = BASE64_STANDARD + .decode(s) + .map_err(|e| serde::de::Error::custom(e.to_string()))?; + visitor.visit_bytes(&b) + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + let b = BASE64_STANDARD + .decode(s) + .map_err(|e| serde::de::Error::custom(e.to_string()))?; + visitor.visit_byte_buf(b) + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if let Ok(f) = f32::from_str(s) { + visitor.visit_f32(f) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if let Ok(f) = f64::from_str(s) { + visitor.visit_f64(f) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 char str string unit_struct seq tuple_struct tuple map struct identifier ignored_any } } #[cfg(test)] mod test { - use serde::{de::DeserializeOwned, Deserialize, Serialize}; + + use serde::{ + de::{DeserializeOwned, Visitor}, + Deserialize, Serialize, + }; use crate::{ backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, - middleware::{serialization::ValueSerializer, Params, RawValue}, + middleware::{serialization::ValueSerializer, Params, RawValue, TypedValue, Value}, }; #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] @@ -714,6 +782,64 @@ mod test { Mine, } + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct Inner { + ch: char, + b: bool, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct Tuple(u8, u32); + + #[derive(PartialEq, Eq, Debug)] + struct Bytes(Vec); + + struct BytesVisitor; + + impl<'de> Visitor<'de> for BytesVisitor { + type Value = Bytes; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a byte buffer") + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: serde::de::Error, + { + Ok(Bytes(v)) + } + } + + impl Serialize for Bytes { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0) + } + } + + impl<'de> Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_byte_buf(BytesVisitor) + } + } + + #[derive(Serialize, Deserialize, Debug)] + struct Float(f32); + + impl PartialEq for Float { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 || (self.0.is_nan() && other.0.is_nan()) + } + } + + impl Eq for Float {} + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] struct FrogDesc { frog_id: i64, @@ -724,6 +850,14 @@ mod test { unit: (), sk: SecretKey, pk: Point, + fancy1: Fancy, + fancy2: Fancy, + inner: Inner, + tuple: Tuple, + bytes: Bytes, + float: Float, + inf: Float, + nan: Float, } fn test_roundtrip(t: T) { @@ -734,6 +868,24 @@ mod test { assert_eq!(t, out); } + fn test_preserved(t: T) + where + TypedValue: From, + { + let depth = Params::default().max_depth_mt_containers; + let ser = t.serialize(ValueSerializer::new(depth)).unwrap(); + let val = Value::from(TypedValue::from(t)); + println!("{}", serde_json::to_string(&ser).unwrap()); + println!("{}", serde_json::to_string(&val).unwrap()); + assert_eq!(ser, val); + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + enum Fancy { + B(i64, i64), + C { x: i64, y: Vec }, + } + #[test] fn test_frog_desc() { let seed_range = vec![ @@ -751,7 +903,31 @@ mod test { unit: (), sk, pk, + fancy1: Fancy::B(0, 1), + fancy2: Fancy::C { + x: 1, + y: vec![2, 3], + }, + inner: Inner { + ch: '\u{200b}', + b: true, + }, + tuple: Tuple(5, 6), + bytes: Bytes(b"abc".to_vec()), + float: Float(3.0), + inf: Float(f32::NEG_INFINITY), + nan: Float(f32::NAN), }; test_roundtrip(desc); } + + #[test] + fn test_pod_types() { + let raw = RawValue::default(); + let sk = SecretKey::new_rand(); + let pt = sk.public_key(); + test_preserved(raw); + test_preserved(sk); + test_preserved(pt); + } } From c3b557222c19c2caa82e4bce103cbb391615bd76 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Mon, 25 Aug 2025 14:30:05 -0700 Subject: [PATCH 3/9] move serializer to separate file --- src/middleware/mod.rs | 1 + src/middleware/serialization.rs | 794 +------------------------------ src/middleware/serializer.rs | 797 ++++++++++++++++++++++++++++++++ 3 files changed, 801 insertions(+), 791 deletions(-) create mode 100644 src/middleware/serializer.rs diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index aadc7786..5fd26764 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -18,6 +18,7 @@ mod error; mod operation; mod pod_deserialization; pub mod serialization; +pub mod serializer; mod statement; use std::{any::Any, collections::HashMap, fmt}; diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index 8ede5f23..b309c3c8 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -1,41 +1,15 @@ use std::{ collections::{HashMap, HashSet}, fmt::Write, - str::FromStr, }; -use base64::{prelude::BASE64_STANDARD, Engine}; use plonky2::field::types::Field; -use serde::{ - de::{ - value::{ - MapAccessDeserializer, MapDeserializer, SeqDeserializer, StrDeserializer, - U32Deserializer, - }, - IntoDeserializer, Unexpected, - }, - forward_to_deserialize_any, - ser::{ - SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, - SerializeTupleStruct, SerializeTupleVariant, - }, - Deserialize, Serialize, Serializer, -}; +use serde::{Deserialize, Serialize, Serializer}; use super::{Key, Value}; -use crate::{ - backends::plonky2::{ - deserialize_bytes, - primitives::ec::{curve::Point, schnorr::SecretKey}, - serialize_bytes, - }, - middleware::{ - containers::{Array, Dictionary}, - RawValue, TypedValue, F, HASH_SIZE, VALUE_SIZE, - }, -}; +use crate::middleware::{F, HASH_SIZE, VALUE_SIZE}; -fn field_array_to_string(value: &[F; N]) -> String { +pub(super) fn field_array_to_string(value: &[F; N]) -> String { // `value` is little-endian in memory. We serialize it as a big-endian hex string // for human readability. value @@ -169,765 +143,3 @@ where sorted_values.sort_by_key(|v| v.raw()); serializer.serialize_newtype_struct("Set", &sorted_values) } - -#[derive(Clone, Copy)] -enum ValueSerializerState { - Default, - RawValue, - Point, - SecretKey, -} - -#[derive(Clone, Copy)] -pub struct ValueSerializer { - container_depth: usize, - state: ValueSerializerState, -} - -pub struct ValueSerializeSeq { - data: Vec, - container_depth: usize, -} - -pub struct ValueSerializeTupleVariant { - name: &'static str, - inner: ValueSerializeSeq, -} - -pub struct ValueSerializeMap { - kvs: HashMap, - next_key: Option, - container_depth: usize, -} - -pub struct ValueSerializeStructVariant { - name: &'static str, - inner: ValueSerializeMap, -} - -impl ValueSerializer { - pub fn new(container_depth: usize) -> Self { - Self { - container_depth, - state: ValueSerializerState::Default, - } - } -} - -impl Serializer for ValueSerializer { - type Ok = Value; - type Error = serde::de::value::Error; - type SerializeSeq = ValueSerializeSeq; - type SerializeTuple = ValueSerializeSeq; - type SerializeTupleStruct = ValueSerializeSeq; - type SerializeMap = ValueSerializeMap; - type SerializeStruct = ValueSerializeMap; - type SerializeTupleVariant = ValueSerializeTupleVariant; - type SerializeStructVariant = ValueSerializeStructVariant; - - fn serialize_bool(self, v: bool) -> Result { - Ok(Value::from(v)) - } - - fn serialize_i8(self, v: i8) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_i16(self, v: i16) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_i32(self, v: i32) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_i64(self, v: i64) -> Result { - Ok(Value::from(v)) - } - - fn serialize_u8(self, v: u8) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_u16(self, v: u16) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_u32(self, v: u32) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_u64(self, v: u64) -> Result { - self.serialize_i64(v as i64) - } - - fn serialize_f32(self, v: f32) -> Result { - let s = format!("{v:.8e}"); - self.serialize_str(&s) - } - - fn serialize_f64(self, v: f64) -> Result { - let s = format!("{v:.16e}"); - self.serialize_str(&s) - } - - fn serialize_str(self, v: &str) -> Result { - Ok(match self.state { - ValueSerializerState::RawValue => { - let arr = deserialize_value_tuple(StrDeserializer::new(v))?; - Value::from(RawValue(arr)) - } - ValueSerializerState::Point => { - Value::from(Point::from_str(v).map_err(serde::ser::Error::custom)?) - } - ValueSerializerState::SecretKey => { - let bytes = deserialize_bytes(v).map_err(serde::ser::Error::custom)?; - let sk = SecretKey::from_bytes(&bytes).map_err(serde::ser::Error::custom)?; - Value::from(sk) - } - _ => Value::from(v), - }) - } - - fn serialize_char(self, v: char) -> Result { - self.serialize_str(&String::from(v)) - } - - fn serialize_bytes(self, v: &[u8]) -> Result { - let s = BASE64_STANDARD.encode(v); - self.serialize_str(&s) - } - - fn serialize_seq(self, _len: Option) -> Result { - Ok(ValueSerializeSeq { - data: Vec::new(), - container_depth: self.container_depth, - }) - } - - fn serialize_tuple(self, len: usize) -> Result { - self.serialize_seq(Some(len)) - } - - fn serialize_unit(self) -> Result { - SerializeTuple::end(self.serialize_tuple(0)?) - } - - fn serialize_unit_struct(self, _name: &'static str) -> Result { - self.serialize_unit() - } - - fn serialize_unit_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - ) -> Result { - self.serialize_str(variant) - } - - fn serialize_newtype_struct( - mut self, - name: &'static str, - value: &T, - ) -> Result - where - T: ?Sized + Serialize, - { - println!("nt struct {name}"); - match name { - "RawValue" => self.state = ValueSerializerState::RawValue, - "Point" => self.state = ValueSerializerState::Point, - "SecretKey" => self.state = ValueSerializerState::SecretKey, - _ => (), - } - value.serialize(self) - } - - fn serialize_newtype_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - value: &T, - ) -> Result - where - T: ?Sized + Serialize, - { - let ser_value = value.serialize(self)?; - let mut map = HashMap::new(); - map.insert(Key::from(variant), ser_value); - Ok(Value::from( - Dictionary::new(self.container_depth, map).map_err(serde::ser::Error::custom)?, - )) - } - - fn serialize_some(self, value: &T) -> Result - where - T: ?Sized + Serialize, - { - self.serialize_newtype_variant("Option", 0, "Some", value) - } - - fn serialize_none(self) -> Result { - self.serialize_unit_variant("Option", 1, "None") - } - - fn serialize_tuple_struct( - self, - _name: &'static str, - len: usize, - ) -> Result { - self.serialize_seq(Some(len)) - } - - fn serialize_tuple_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - _len: usize, - ) -> Result { - Ok(ValueSerializeTupleVariant { - name: variant, - inner: ValueSerializeSeq { - data: Vec::new(), - container_depth: self.container_depth, - }, - }) - } - - fn serialize_map(self, _len: Option) -> Result { - Ok(ValueSerializeMap { - kvs: HashMap::new(), - container_depth: self.container_depth, - next_key: None, - }) - } - - fn serialize_struct( - self, - _name: &'static str, - len: usize, - ) -> Result { - self.serialize_map(Some(len)) - } - - fn serialize_struct_variant( - self, - _name: &'static str, - _variant_index: u32, - variant: &'static str, - len: usize, - ) -> Result { - Ok(ValueSerializeStructVariant { - name: variant, - inner: self.serialize_map(Some(len))?, - }) - } -} - -impl SerializeSeq for ValueSerializeSeq { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - self.data.push(value.serialize(ValueSerializer { - container_depth: self.container_depth, - state: ValueSerializerState::Default, - })?); - Ok(()) - } - - fn end(self) -> Result { - let arr = Array::new(self.container_depth, self.data).map_err(serde::de::Error::custom)?; - Ok(Value::from(arr)) - } -} - -impl SerializeTuple for ValueSerializeSeq { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - SerializeSeq::serialize_element(self, value) - } - - fn end(self) -> Result { - SerializeSeq::end(self) - } -} - -impl SerializeTupleStruct for ValueSerializeSeq { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - SerializeSeq::serialize_element(self, value) - } - - fn end(self) -> Result { - SerializeSeq::end(self) - } -} - -impl SerializeTupleVariant for ValueSerializeTupleVariant { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - SerializeSeq::serialize_element(&mut self.inner, value) - } - - fn end(self) -> Result { - let max_depth = self.inner.container_depth; - let arr = SerializeSeq::end(self.inner)?; - let mut map = HashMap::new(); - map.insert(Key::new(self.name.to_string()), arr); - let dict = Dictionary::new(max_depth, map).map_err(serde::de::Error::custom)?; - Ok(Value::from(dict)) - } -} - -impl SerializeMap for ValueSerializeMap { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - let key_ser = key.serialize(ValueSerializer { - container_depth: self.container_depth, - state: ValueSerializerState::Default, - })?; - if let TypedValue::String(s) = key_ser.typed() { - self.next_key = Some(Key::new(s.clone())); - Ok(()) - } else { - Err(serde::de::Error::invalid_value( - Unexpected::Other("non-string key in map"), - &"string", - )) - } - } - - fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - let val_ser = value.serialize(ValueSerializer { - container_depth: self.container_depth, - state: ValueSerializerState::Default, - })?; - self.kvs.insert( - self.next_key - .take() - .expect("serialize_key should be called before serialize_value"), - val_ser, - ); - Ok(()) - } - - fn end(self) -> Result { - let dict = - Dictionary::new(self.container_depth, self.kvs).map_err(serde::ser::Error::custom)?; - Ok(Value::from(dict)) - } -} - -impl SerializeStruct for ValueSerializeMap { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - SerializeMap::serialize_entry(self, key, value) - } - - fn end(self) -> Result { - SerializeMap::end(self) - } -} - -impl SerializeStructVariant for ValueSerializeStructVariant { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> - where - T: ?Sized + Serialize, - { - SerializeMap::serialize_entry(&mut self.inner, key, value) - } - - fn end(self) -> Result { - let depth = self.inner.container_depth; - let value = SerializeMap::end(self.inner)?; - let mut kvs = HashMap::new(); - kvs.insert(Key::new(self.name.to_string()), value); - let dict = Dictionary::new(depth, kvs).map_err(serde::ser::Error::custom)?; - Ok(Value::from(dict)) - } -} - -impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a TypedValue { - type Deserializer = Self; - fn into_deserializer(self) -> Self::Deserializer { - self - } -} - -impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a Value { - type Deserializer = &'a TypedValue; - fn into_deserializer(self) -> Self::Deserializer { - self.typed() - } -} - -impl<'a, 'de, E: serde::de::Error> IntoDeserializer<'de, E> for &'a Key { - type Deserializer = StrDeserializer<'a, E>; - fn into_deserializer(self) -> Self::Deserializer { - StrDeserializer::new(&self.name) - } -} - -impl<'de> serde::Deserializer<'de> for &TypedValue { - type Error = serde::de::value::Error; - - fn deserialize_any(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::Int(i) => visitor.visit_i64(*i), - TypedValue::Raw(v) => visitor.visit_string(field_array_to_string(&v.0)), - TypedValue::PublicKey(k) => visitor.visit_string(k.to_string()), - TypedValue::SecretKey(k) => visitor.visit_string(serialize_bytes(&k.as_bytes())), - TypedValue::Bool(b) => visitor.visit_bool(*b), - TypedValue::Array(a) => visitor.visit_seq(SeqDeserializer::new(a.array().iter())), - TypedValue::Set(s) => visitor.visit_seq(SeqDeserializer::new(s.set().iter())), - TypedValue::String(s) => visitor.visit_str(s), - TypedValue::PodId(i) => { - visitor.visit_seq(SeqDeserializer::new(i.0 .0.iter().map(|x| x.0))) - } - TypedValue::Dictionary(d) => visitor.visit_map(MapDeserializer::new(d.kvs().iter())), - } - } - - fn deserialize_enum( - self, - _name: &'static str, - _variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::String(s) => visitor.visit_enum(StrDeserializer::new(s)), - TypedValue::Int(i) => { - if let Ok(u) = u32::try_from(*i) { - visitor.visit_enum(U32Deserializer::new(u)) - } else { - self.deserialize_any(visitor) - } - } - TypedValue::Dictionary(d) => visitor.visit_enum(MapAccessDeserializer::new( - MapDeserializer::new(d.kvs().iter()), - )), - _ => self.deserialize_any(visitor), - } - } - - fn deserialize_newtype_struct( - self, - _name: &'static str, - visitor: V, - ) -> Result - where - V: serde::de::Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - fn deserialize_option(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::String(s) => { - if s == "None" { - visitor.visit_none() - } else { - self.deserialize_any(visitor) - } - } - TypedValue::Dictionary(d) => { - if let Ok(v) = d.get(&Key::from("Some")) { - visitor.visit_some(v.typed()) - } else { - self.deserialize_any(visitor) - } - } - _ => self.deserialize_any(visitor), - } - } - - fn deserialize_unit(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::Array(a) if a.array().is_empty() => visitor.visit_unit(), - _ => self.deserialize_any(visitor), - } - } - - fn deserialize_bytes(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::String(s) => { - let b = BASE64_STANDARD - .decode(s) - .map_err(|e| serde::de::Error::custom(e.to_string()))?; - visitor.visit_bytes(&b) - } - _ => self.deserialize_any(visitor), - } - } - - fn deserialize_byte_buf(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::String(s) => { - let b = BASE64_STANDARD - .decode(s) - .map_err(|e| serde::de::Error::custom(e.to_string()))?; - visitor.visit_byte_buf(b) - } - _ => self.deserialize_any(visitor), - } - } - - fn deserialize_f32(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::String(s) => { - if let Ok(f) = f32::from_str(s) { - visitor.visit_f32(f) - } else { - self.deserialize_any(visitor) - } - } - _ => self.deserialize_any(visitor), - } - } - - fn deserialize_f64(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - match self { - TypedValue::String(s) => { - if let Ok(f) = f64::from_str(s) { - visitor.visit_f64(f) - } else { - self.deserialize_any(visitor) - } - } - _ => self.deserialize_any(visitor), - } - } - - forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 char str string unit_struct seq tuple_struct tuple map struct identifier ignored_any } -} - -#[cfg(test)] -mod test { - - use serde::{ - de::{DeserializeOwned, Visitor}, - Deserialize, Serialize, - }; - - use crate::{ - backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, - middleware::{serialization::ValueSerializer, Params, RawValue, TypedValue, Value}, - }; - - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - enum Method { - Search, - Mine, - } - - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - struct Inner { - ch: char, - b: bool, - } - - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - struct Tuple(u8, u32); - - #[derive(PartialEq, Eq, Debug)] - struct Bytes(Vec); - - struct BytesVisitor; - - impl<'de> Visitor<'de> for BytesVisitor { - type Value = Bytes; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a byte buffer") - } - - fn visit_byte_buf(self, v: Vec) -> Result - where - E: serde::de::Error, - { - Ok(Bytes(v)) - } - } - - impl Serialize for Bytes { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_bytes(&self.0) - } - } - - impl<'de> Deserialize<'de> for Bytes { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_byte_buf(BytesVisitor) - } - } - - #[derive(Serialize, Deserialize, Debug)] - struct Float(f32); - - impl PartialEq for Float { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 || (self.0.is_nan() && other.0.is_nan()) - } - } - - impl Eq for Float {} - - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - struct FrogDesc { - frog_id: i64, - name: String, - seed_range: Vec<(Method, RawValue)>, - option1: Option, - option2: Option, - unit: (), - sk: SecretKey, - pk: Point, - fancy1: Fancy, - fancy2: Fancy, - inner: Inner, - tuple: Tuple, - bytes: Bytes, - float: Float, - inf: Float, - nan: Float, - } - - fn test_roundtrip(t: T) { - let depth = Params::default().max_depth_mt_containers; - let val = t.serialize(ValueSerializer::new(depth)).unwrap(); - println!("{val:?}"); - let out: T = Deserialize::deserialize(val.typed()).unwrap(); - assert_eq!(t, out); - } - - fn test_preserved(t: T) - where - TypedValue: From, - { - let depth = Params::default().max_depth_mt_containers; - let ser = t.serialize(ValueSerializer::new(depth)).unwrap(); - let val = Value::from(TypedValue::from(t)); - println!("{}", serde_json::to_string(&ser).unwrap()); - println!("{}", serde_json::to_string(&val).unwrap()); - assert_eq!(ser, val); - } - - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - enum Fancy { - B(i64, i64), - C { x: i64, y: Vec }, - } - - #[test] - fn test_frog_desc() { - let seed_range = vec![ - (Method::Search, RawValue::default()), - (Method::Mine, RawValue::default()), - ]; - let sk = SecretKey::new_rand(); - let pk = sk.public_key(); - let desc = FrogDesc { - frog_id: 1, - name: "a frog".to_string(), - seed_range, - option1: Some(2), - option2: None, - unit: (), - sk, - pk, - fancy1: Fancy::B(0, 1), - fancy2: Fancy::C { - x: 1, - y: vec![2, 3], - }, - inner: Inner { - ch: '\u{200b}', - b: true, - }, - tuple: Tuple(5, 6), - bytes: Bytes(b"abc".to_vec()), - float: Float(3.0), - inf: Float(f32::NEG_INFINITY), - nan: Float(f32::NAN), - }; - test_roundtrip(desc); - } - - #[test] - fn test_pod_types() { - let raw = RawValue::default(); - let sk = SecretKey::new_rand(); - let pt = sk.public_key(); - test_preserved(raw); - test_preserved(sk); - test_preserved(pt); - } -} diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs new file mode 100644 index 00000000..10745130 --- /dev/null +++ b/src/middleware/serializer.rs @@ -0,0 +1,797 @@ +use std::{ + collections::HashMap, + str::FromStr, +}; + +use base64::{prelude::BASE64_STANDARD, Engine}; +use serde::{ + de::{ + value::{ + MapAccessDeserializer, MapDeserializer, SeqDeserializer, StrDeserializer, + U32Deserializer, + }, + IntoDeserializer, Unexpected, + }, + forward_to_deserialize_any, + ser::{ + SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, + }, Serialize, Serializer, +}; + +use super::{Key, Value}; +use crate::{ + backends::plonky2::{ + deserialize_bytes, + primitives::ec::{curve::Point, schnorr::SecretKey}, + serialize_bytes, + }, + middleware::{ + containers::{Array, Dictionary}, + field_array_to_string, + serialization::deserialize_value_tuple, + RawValue, TypedValue, + }, +}; + +#[derive(Clone, Copy)] +enum ValueSerializerState { + Default, + RawValue, + Point, + SecretKey, +} + +#[derive(Clone, Copy)] +pub struct ValueSerializer { + container_depth: usize, + state: ValueSerializerState, +} + +pub struct ValueSerializeSeq { + data: Vec, + container_depth: usize, +} + +pub struct ValueSerializeTupleVariant { + name: &'static str, + inner: ValueSerializeSeq, +} + +pub struct ValueSerializeMap { + kvs: HashMap, + next_key: Option, + container_depth: usize, +} + +pub struct ValueSerializeStructVariant { + name: &'static str, + inner: ValueSerializeMap, +} + +impl ValueSerializer { + pub fn new(container_depth: usize) -> Self { + Self { + container_depth, + state: ValueSerializerState::Default, + } + } +} + +impl Serializer for ValueSerializer { + type Ok = Value; + type Error = serde::de::value::Error; + type SerializeSeq = ValueSerializeSeq; + type SerializeTuple = ValueSerializeSeq; + type SerializeTupleStruct = ValueSerializeSeq; + type SerializeMap = ValueSerializeMap; + type SerializeStruct = ValueSerializeMap; + type SerializeTupleVariant = ValueSerializeTupleVariant; + type SerializeStructVariant = ValueSerializeStructVariant; + + fn serialize_bool(self, v: bool) -> Result { + Ok(Value::from(v)) + } + + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i32(self, v: i32) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_i64(self, v: i64) -> Result { + Ok(Value::from(v)) + } + + fn serialize_u8(self, v: u8) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u16(self, v: u16) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u32(self, v: u32) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_u64(self, v: u64) -> Result { + self.serialize_i64(v as i64) + } + + fn serialize_f32(self, v: f32) -> Result { + let s = format!("{v:.8e}"); + self.serialize_str(&s) + } + + fn serialize_f64(self, v: f64) -> Result { + let s = format!("{v:.16e}"); + self.serialize_str(&s) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(match self.state { + ValueSerializerState::RawValue => { + let arr = deserialize_value_tuple(StrDeserializer::new(v))?; + Value::from(RawValue(arr)) + } + ValueSerializerState::Point => { + Value::from(Point::from_str(v).map_err(serde::ser::Error::custom)?) + } + ValueSerializerState::SecretKey => { + let bytes = deserialize_bytes(v).map_err(serde::ser::Error::custom)?; + let sk = SecretKey::from_bytes(&bytes).map_err(serde::ser::Error::custom)?; + Value::from(sk) + } + _ => Value::from(v), + }) + } + + fn serialize_char(self, v: char) -> Result { + self.serialize_str(&String::from(v)) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + let s = BASE64_STANDARD.encode(v); + self.serialize_str(&s) + } + + fn serialize_seq(self, _len: Option) -> Result { + Ok(ValueSerializeSeq { + data: Vec::new(), + container_depth: self.container_depth, + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_unit(self) -> Result { + SerializeTuple::end(self.serialize_tuple(0)?) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct( + mut self, + name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + println!("nt struct {name}"); + match name { + "RawValue" => self.state = ValueSerializerState::RawValue, + "Point" => self.state = ValueSerializerState::Point, + "SecretKey" => self.state = ValueSerializerState::SecretKey, + _ => (), + } + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + let ser_value = value.serialize(self)?; + let mut map = HashMap::new(); + map.insert(Key::from(variant), ser_value); + Ok(Value::from( + Dictionary::new(self.container_depth, map).map_err(serde::ser::Error::custom)?, + )) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + self.serialize_newtype_variant("Option", 0, "Some", value) + } + + fn serialize_none(self) -> Result { + self.serialize_unit_variant("Option", 1, "None") + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + Ok(ValueSerializeTupleVariant { + name: variant, + inner: ValueSerializeSeq { + data: Vec::new(), + container_depth: self.container_depth, + }, + }) + } + + fn serialize_map(self, _len: Option) -> Result { + Ok(ValueSerializeMap { + kvs: HashMap::new(), + container_depth: self.container_depth, + next_key: None, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(ValueSerializeStructVariant { + name: variant, + inner: self.serialize_map(Some(len))?, + }) + } +} + +impl SerializeSeq for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.data.push(value.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?); + Ok(()) + } + + fn end(self) -> Result { + let arr = Array::new(self.container_depth, self.data).map_err(serde::de::Error::custom)?; + Ok(Value::from(arr)) + } +} + +impl SerializeTuple for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self) + } +} + +impl SerializeTupleStruct for ValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self) + } +} + +impl SerializeTupleVariant for ValueSerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(&mut self.inner, value) + } + + fn end(self) -> Result { + let max_depth = self.inner.container_depth; + let arr = SerializeSeq::end(self.inner)?; + let mut map = HashMap::new(); + map.insert(Key::new(self.name.to_string()), arr); + let dict = Dictionary::new(max_depth, map).map_err(serde::de::Error::custom)?; + Ok(Value::from(dict)) + } +} + +impl SerializeMap for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let key_ser = key.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + if let TypedValue::String(s) = key_ser.typed() { + self.next_key = Some(Key::new(s.clone())); + Ok(()) + } else { + Err(serde::de::Error::invalid_value( + Unexpected::Other("non-string key in map"), + &"string", + )) + } + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let val_ser = value.serialize(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + })?; + self.kvs.insert( + self.next_key + .take() + .expect("serialize_key should be called before serialize_value"), + val_ser, + ); + Ok(()) + } + + fn end(self) -> Result { + let dict = + Dictionary::new(self.container_depth, self.kvs).map_err(serde::ser::Error::custom)?; + Ok(Value::from(dict)) + } +} + +impl SerializeStruct for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeMap::serialize_entry(self, key, value) + } + + fn end(self) -> Result { + SerializeMap::end(self) + } +} + +impl SerializeStructVariant for ValueSerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeMap::serialize_entry(&mut self.inner, key, value) + } + + fn end(self) -> Result { + let depth = self.inner.container_depth; + let value = SerializeMap::end(self.inner)?; + let mut kvs = HashMap::new(); + kvs.insert(Key::new(self.name.to_string()), value); + let dict = Dictionary::new(depth, kvs).map_err(serde::ser::Error::custom)?; + Ok(Value::from(dict)) + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a TypedValue { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a Value { + type Deserializer = &'a TypedValue; + fn into_deserializer(self) -> Self::Deserializer { + self.typed() + } +} + +impl<'a, 'de, E: serde::de::Error> IntoDeserializer<'de, E> for &'a Key { + type Deserializer = StrDeserializer<'a, E>; + fn into_deserializer(self) -> Self::Deserializer { + StrDeserializer::new(&self.name) + } +} + +impl<'de> serde::Deserializer<'de> for &TypedValue { + type Error = serde::de::value::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Int(i) => visitor.visit_i64(*i), + TypedValue::Raw(v) => visitor.visit_string(field_array_to_string(&v.0)), + TypedValue::PublicKey(k) => visitor.visit_string(k.to_string()), + TypedValue::SecretKey(k) => visitor.visit_string(serialize_bytes(&k.as_bytes())), + TypedValue::Bool(b) => visitor.visit_bool(*b), + TypedValue::Array(a) => visitor.visit_seq(SeqDeserializer::new(a.array().iter())), + TypedValue::Set(s) => visitor.visit_seq(SeqDeserializer::new(s.set().iter())), + TypedValue::String(s) => visitor.visit_str(s), + TypedValue::PodId(i) => { + visitor.visit_seq(SeqDeserializer::new(i.0 .0.iter().map(|x| x.0))) + } + TypedValue::Dictionary(d) => visitor.visit_map(MapDeserializer::new(d.kvs().iter())), + } + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => visitor.visit_enum(StrDeserializer::new(s)), + TypedValue::Int(i) => { + if let Ok(u) = u32::try_from(*i) { + visitor.visit_enum(U32Deserializer::new(u)) + } else { + self.deserialize_any(visitor) + } + } + TypedValue::Dictionary(d) => visitor.visit_enum(MapAccessDeserializer::new( + MapDeserializer::new(d.kvs().iter()), + )), + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if s == "None" { + visitor.visit_none() + } else { + self.deserialize_any(visitor) + } + } + TypedValue::Dictionary(d) => { + if let Ok(v) = d.get(&Key::from("Some")) { + visitor.visit_some(v.typed()) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Array(a) if a.array().is_empty() => visitor.visit_unit(), + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + let b = BASE64_STANDARD + .decode(s) + .map_err(|e| serde::de::Error::custom(e.to_string()))?; + visitor.visit_bytes(&b) + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + let b = BASE64_STANDARD + .decode(s) + .map_err(|e| serde::de::Error::custom(e.to_string()))?; + visitor.visit_byte_buf(b) + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if let Ok(f) = f32::from_str(s) { + visitor.visit_f32(f) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::String(s) => { + if let Ok(f) = f64::from_str(s) { + visitor.visit_f64(f) + } else { + self.deserialize_any(visitor) + } + } + _ => self.deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 char str string unit_struct seq tuple_struct tuple map struct identifier ignored_any } +} + +#[cfg(test)] +mod test { + + use serde::{ + de::{DeserializeOwned, Visitor}, + Deserialize, Serialize, + }; + + use crate::{ + backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, + middleware::{serializer::ValueSerializer, Params, RawValue, TypedValue, Value}, + }; + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + enum Method { + Search, + Mine, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct Inner { + ch: char, + b: bool, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct Tuple(u8, u32); + + #[derive(PartialEq, Eq, Debug)] + struct Bytes(Vec); + + struct BytesVisitor; + + impl<'de> Visitor<'de> for BytesVisitor { + type Value = Bytes; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a byte buffer") + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: serde::de::Error, + { + Ok(Bytes(v)) + } + } + + impl Serialize for Bytes { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0) + } + } + + impl<'de> Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_byte_buf(BytesVisitor) + } + } + + #[derive(Serialize, Deserialize, Debug)] + struct Float(f32); + + impl PartialEq for Float { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 || (self.0.is_nan() && other.0.is_nan()) + } + } + + impl Eq for Float {} + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct FrogDesc { + frog_id: i64, + name: String, + seed_range: Vec<(Method, RawValue)>, + option1: Option, + option2: Option, + unit: (), + sk: SecretKey, + pk: Point, + fancy1: Fancy, + fancy2: Fancy, + inner: Inner, + tuple: Tuple, + bytes: Bytes, + float: Float, + inf: Float, + nan: Float, + } + + fn test_roundtrip(t: T) { + let depth = Params::default().max_depth_mt_containers; + let val = t.serialize(ValueSerializer::new(depth)).unwrap(); + println!("{val:?}"); + let out: T = Deserialize::deserialize(val.typed()).unwrap(); + assert_eq!(t, out); + } + + fn test_preserved(t: T) + where + TypedValue: From, + { + let depth = Params::default().max_depth_mt_containers; + let ser = t.serialize(ValueSerializer::new(depth)).unwrap(); + let val = Value::from(TypedValue::from(t)); + println!("{}", serde_json::to_string(&ser).unwrap()); + println!("{}", serde_json::to_string(&val).unwrap()); + assert_eq!(ser, val); + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + enum Fancy { + B(i64, i64), + C { x: i64, y: Vec }, + } + + #[test] + fn test_frog_desc() { + let seed_range = vec![ + (Method::Search, RawValue::default()), + (Method::Mine, RawValue::default()), + ]; + let sk = SecretKey::new_rand(); + let pk = sk.public_key(); + let desc = FrogDesc { + frog_id: 1, + name: "a frog".to_string(), + seed_range, + option1: Some(2), + option2: None, + unit: (), + sk, + pk, + fancy1: Fancy::B(0, 1), + fancy2: Fancy::C { + x: 1, + y: vec![2, 3], + }, + inner: Inner { + ch: '\u{200b}', + b: true, + }, + tuple: Tuple(5, 6), + bytes: Bytes(b"abc".to_vec()), + float: Float(3.0), + inf: Float(f32::NEG_INFINITY), + nan: Float(f32::NAN), + }; + test_roundtrip(desc); + } + + #[test] + fn test_pod_types() { + let raw = RawValue::default(); + let sk = SecretKey::new_rand(); + let pt = sk.public_key(); + test_preserved(raw); + test_preserved(sk); + test_preserved(pt); + } +} From 51eb101c2aa8ec7aca5cf1ea0491fbfdd023361e Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Thu, 4 Sep 2025 14:59:53 -0700 Subject: [PATCH 4/9] Dictionary deserializer --- src/backends/plonky2/primitives/ec/curve.rs | 2 +- src/backends/plonky2/primitives/ec/schnorr.rs | 2 +- src/middleware/serializer.rs | 78 ++++++++++++++----- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index fa0f239b..ed94908c 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -161,7 +161,7 @@ impl Serialize for Point { S: Serializer, { let point_b58 = format!("{}", self); - serializer.serialize_newtype_struct("Point", &point_b58) + serializer.serialize_newtype_struct("pod2::Point", &point_b58) } } diff --git a/src/backends/plonky2/primitives/ec/schnorr.rs b/src/backends/plonky2/primitives/ec/schnorr.rs index e8e81ce0..90aa6c9d 100644 --- a/src/backends/plonky2/primitives/ec/schnorr.rs +++ b/src/backends/plonky2/primitives/ec/schnorr.rs @@ -244,7 +244,7 @@ impl Serialize for SecretKey { S: Serializer, { let sk_b64 = serialize_bytes(&self.as_bytes()); - serializer.serialize_newtype_struct("SecretKey", &sk_b64) + serializer.serialize_newtype_struct("pod2::SecretKey", &sk_b64) } } diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs index 10745130..b018d3c3 100644 --- a/src/middleware/serializer.rs +++ b/src/middleware/serializer.rs @@ -1,7 +1,4 @@ -use std::{ - collections::HashMap, - str::FromStr, -}; +use std::{collections::HashMap, str::FromStr}; use base64::{prelude::BASE64_STANDARD, Engine}; use serde::{ @@ -16,7 +13,8 @@ use serde::{ ser::{ SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, SerializeTupleStruct, SerializeTupleVariant, - }, Serialize, Serializer, + }, + Serialize, Serializer, }; use super::{Key, Value}; @@ -26,6 +24,7 @@ use crate::{ primitives::ec::{curve::Point, schnorr::SecretKey}, serialize_bytes, }, + frontend::SignedDict, middleware::{ containers::{Array, Dictionary}, field_array_to_string, @@ -198,11 +197,10 @@ impl Serializer for ValueSerializer { where T: ?Sized + Serialize, { - println!("nt struct {name}"); match name { "RawValue" => self.state = ValueSerializerState::RawValue, - "Point" => self.state = ValueSerializerState::Point, - "SecretKey" => self.state = ValueSerializerState::SecretKey, + "pod2::Point" => self.state = ValueSerializerState::Point, + "pod2::SecretKey" => self.state = ValueSerializerState::SecretKey, _ => (), } value.serialize(self) @@ -470,6 +468,54 @@ impl<'a, 'de, E: serde::de::Error> IntoDeserializer<'de, E> for &'a Key { } } +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a Dictionary { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a SignedDict { + type Deserializer = &'a Dictionary; + fn into_deserializer(self) -> Self::Deserializer { + &self.dict + } +} + +impl<'de> serde::Deserializer<'de> for &Dictionary { + type Error = serde::de::value::Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_map(MapDeserializer::new(self.kvs().iter())) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_enum(MapAccessDeserializer::new(MapDeserializer::new( + self.kvs().iter(), + ))) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 f32 f64 u8 u16 u32 u64 char str bytes byte_buf string option unit unit_struct seq tuple_struct tuple map newtype_struct struct identifier } +} + impl<'de> serde::Deserializer<'de> for &TypedValue { type Error = serde::de::value::Error; @@ -486,17 +532,14 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { TypedValue::Array(a) => visitor.visit_seq(SeqDeserializer::new(a.array().iter())), TypedValue::Set(s) => visitor.visit_seq(SeqDeserializer::new(s.set().iter())), TypedValue::String(s) => visitor.visit_str(s), - TypedValue::PodId(i) => { - visitor.visit_seq(SeqDeserializer::new(i.0 .0.iter().map(|x| x.0))) - } - TypedValue::Dictionary(d) => visitor.visit_map(MapDeserializer::new(d.kvs().iter())), + TypedValue::Dictionary(d) => d.deserialize_any(visitor), } } fn deserialize_enum( self, - _name: &'static str, - _variants: &'static [&'static str], + name: &'static str, + variants: &'static [&'static str], visitor: V, ) -> Result where @@ -511,9 +554,7 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { self.deserialize_any(visitor) } } - TypedValue::Dictionary(d) => visitor.visit_enum(MapAccessDeserializer::new( - MapDeserializer::new(d.kvs().iter()), - )), + TypedValue::Dictionary(d) => d.deserialize_enum(name, variants, visitor), _ => self.deserialize_any(visitor), } } @@ -727,7 +768,6 @@ mod test { fn test_roundtrip(t: T) { let depth = Params::default().max_depth_mt_containers; let val = t.serialize(ValueSerializer::new(depth)).unwrap(); - println!("{val:?}"); let out: T = Deserialize::deserialize(val.typed()).unwrap(); assert_eq!(t, out); } @@ -739,8 +779,6 @@ mod test { let depth = Params::default().max_depth_mt_containers; let ser = t.serialize(ValueSerializer::new(depth)).unwrap(); let val = Value::from(TypedValue::from(t)); - println!("{}", serde_json::to_string(&ser).unwrap()); - println!("{}", serde_json::to_string(&val).unwrap()); assert_eq!(ser, val); } From 12db2f64c986be0ad2695c7c0e05430a75b7ab76 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Tue, 9 Sep 2025 10:33:31 -0700 Subject: [PATCH 5/9] more tests, prefix RawValue, type changes --- src/middleware/basetypes.rs | 3 + src/middleware/serializer.rs | 128 ++++++++++++++++++++++++----------- 2 files changed, 90 insertions(+), 41 deletions(-) diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 97ed113f..b4d363e5 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -59,6 +59,9 @@ pub const SELF_ID_HASH: Hash = Hash([F(0x5), F(0xe), F(0x1), F(0xf)]); pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +// use pod2:: prefix to help ValueSerializer recognize RawValue +// most serializers will ignore the name +#[serde(rename = "pod2::RawValue")] pub struct RawValue( #[serde( serialize_with = "serialize_value_tuple", diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs index b018d3c3..e6c2760a 100644 --- a/src/middleware/serializer.rs +++ b/src/middleware/serializer.rs @@ -1,4 +1,7 @@ -use std::{collections::HashMap, str::FromStr}; +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, +}; use base64::{prelude::BASE64_STANDARD, Engine}; use serde::{ @@ -26,7 +29,7 @@ use crate::{ }, frontend::SignedDict, middleware::{ - containers::{Array, Dictionary}, + containers::{Array, Dictionary, Set}, field_array_to_string, serialization::deserialize_value_tuple, RawValue, TypedValue, @@ -173,7 +176,7 @@ impl Serializer for ValueSerializer { } fn serialize_unit(self) -> Result { - SerializeTuple::end(self.serialize_tuple(0)?) + Ok(Value::from(false)) } fn serialize_unit_struct(self, _name: &'static str) -> Result { @@ -198,7 +201,7 @@ impl Serializer for ValueSerializer { T: ?Sized + Serialize, { match name { - "RawValue" => self.state = ValueSerializerState::RawValue, + "pod2::RawValue" => self.state = ValueSerializerState::RawValue, "pod2::Point" => self.state = ValueSerializerState::Point, "pod2::SecretKey" => self.state = ValueSerializerState::SecretKey, _ => (), @@ -228,11 +231,17 @@ impl Serializer for ValueSerializer { where T: ?Sized + Serialize, { - self.serialize_newtype_variant("Option", 0, "Some", value) + let value_serialized = value.serialize(self)?; + let mut hash_set = HashSet::new(); + hash_set.insert(value_serialized); + let set = Set::new(self.container_depth, hash_set).map_err(serde::ser::Error::custom)?; + Ok(Value::from(set)) } fn serialize_none(self) -> Result { - self.serialize_unit_variant("Option", 1, "None") + let set = + Set::new(self.container_depth, HashSet::new()).map_err(serde::ser::Error::custom)?; + Ok(Value::from(set)) } fn serialize_tuple_struct( @@ -575,20 +584,10 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { V: serde::de::Visitor<'de>, { match self { - TypedValue::String(s) => { - if s == "None" { - visitor.visit_none() - } else { - self.deserialize_any(visitor) - } - } - TypedValue::Dictionary(d) => { - if let Ok(v) = d.get(&Key::from("Some")) { - visitor.visit_some(v.typed()) - } else { - self.deserialize_any(visitor) - } - } + TypedValue::Set(s) if s.set().len() <= 1 => match s.set().iter().next() { + Some(x) => visitor.visit_some(x.typed()), + None => visitor.visit_none(), + }, _ => self.deserialize_any(visitor), } } @@ -597,10 +596,7 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { where V: serde::de::Visitor<'de>, { - match self { - TypedValue::Array(a) if a.array().is_empty() => visitor.visit_unit(), - _ => self.deserialize_any(visitor), - } + visitor.visit_unit() } fn deserialize_bytes(self, visitor: V) -> Result @@ -671,6 +667,8 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { #[cfg(test)] mod test { + use std::collections::HashMap; + use serde::{ de::{DeserializeOwned, Visitor}, Deserialize, Serialize, @@ -681,22 +679,22 @@ mod test { middleware::{serializer::ValueSerializer, Params, RawValue, TypedValue, Value}, }; - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] enum Method { Search, Mine, } - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] struct Inner { ch: char, b: bool, } - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] struct Tuple(u8, u32); - #[derive(PartialEq, Eq, Debug)] + #[derive(PartialEq, Eq, Debug, Clone)] struct Bytes(Vec); struct BytesVisitor; @@ -734,7 +732,7 @@ mod test { } } - #[derive(Serialize, Deserialize, Debug)] + #[derive(Serialize, Deserialize, Debug, Clone)] struct Float(f32); impl PartialEq for Float { @@ -745,9 +743,15 @@ mod test { impl Eq for Float {} - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - struct FrogDesc { - frog_id: i64, + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + enum Fancy { + B(i64, i64), + C { x: i64, y: Vec }, + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + struct ComplicatedStruct { + id: i64, name: String, seed_range: Vec<(Method, RawValue)>, option1: Option, @@ -763,6 +767,7 @@ mod test { float: Float, inf: Float, nan: Float, + map: HashMap, } fn test_roundtrip(t: T) { @@ -772,7 +777,19 @@ mod test { assert_eq!(t, out); } - fn test_preserved(t: T) + fn test_roundtrip_dict(t: T) { + let depth = Params::default().max_depth_mt_containers; + let val = t.serialize(ValueSerializer::new(depth)).unwrap(); + match val.typed() { + TypedValue::Dictionary(d) => { + let out: T = Deserialize::deserialize(d).unwrap(); + assert_eq!(t, out); + } + _ => panic!("Expected value to be serialized to a dict"), + } + } + + fn test_preserved_ser(t: T) where TypedValue: From, { @@ -782,22 +799,39 @@ mod test { assert_eq!(ser, val); } - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] - enum Fancy { - B(i64, i64), - C { x: i64, y: Vec }, + fn test_preserved_de(t: T) + where + TypedValue: From, + { + let de = T::deserialize(&TypedValue::from(t.clone())).unwrap(); + assert_eq!(de, t); + } + + fn test_preserved(t: T) + where + TypedValue: From, + { + test_preserved_ser(t.clone()); + test_preserved_de(t); + } + + fn a_hash_map() -> HashMap { + let mut map = HashMap::new(); + map.insert("a".to_string(), 1); + map.insert("b".to_string(), 2); + map } #[test] - fn test_frog_desc() { + fn test_complicated_struct() { let seed_range = vec![ (Method::Search, RawValue::default()), (Method::Mine, RawValue::default()), ]; let sk = SecretKey::new_rand(); let pk = sk.public_key(); - let desc = FrogDesc { - frog_id: 1, + let desc = ComplicatedStruct { + id: 1, name: "a frog".to_string(), seed_range, option1: Some(2), @@ -819,8 +853,20 @@ mod test { float: Float(3.0), inf: Float(f32::NEG_INFINITY), nan: Float(f32::NAN), + map: a_hash_map(), }; - test_roundtrip(desc); + test_roundtrip(desc.clone()); + test_roundtrip_dict(desc); + } + + #[test] + fn test_dict_deserialization() { + test_roundtrip_dict(Fancy::B(0, 1)); + test_roundtrip_dict(Fancy::C { + x: 1, + y: vec![2, 3], + }); + test_roundtrip_dict(a_hash_map()); } #[test] From 39d601aea2f2e94d59a43520b033615a1980ef68 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Tue, 9 Sep 2025 12:05:28 -0700 Subject: [PATCH 6/9] Dictionary serializer --- src/middleware/serializer.rs | 411 +++++++++++++++++++++++++++++------ 1 file changed, 350 insertions(+), 61 deletions(-) diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs index e6c2760a..cb603257 100644 --- a/src/middleware/serializer.rs +++ b/src/middleware/serializer.rs @@ -14,8 +14,8 @@ use serde::{ }, forward_to_deserialize_any, ser::{ - SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, - SerializeTupleStruct, SerializeTupleVariant, + Impossible, SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, + SerializeTuple, SerializeTupleStruct, SerializeTupleVariant, }, Serialize, Serializer, }; @@ -50,23 +50,34 @@ pub struct ValueSerializer { state: ValueSerializerState, } +#[derive(Clone, Copy)] +pub struct DictionarySerializer { + container_depth: usize, +} + pub struct ValueSerializeSeq { data: Vec, container_depth: usize, } -pub struct ValueSerializeTupleVariant { +pub struct ValueSerializeTupleVariant(DictionarySerializeTupleVariant); + +pub struct ValueSerializeMap(DictionarySerializeMap); + +pub struct ValueSerializeStructVariant(DictionarySerializeStructVariant); + +pub struct DictionarySerializeTupleVariant { name: &'static str, inner: ValueSerializeSeq, } -pub struct ValueSerializeMap { +pub struct DictionarySerializeMap { kvs: HashMap, next_key: Option, container_depth: usize, } -pub struct ValueSerializeStructVariant { +pub struct DictionarySerializeStructVariant { name: &'static str, inner: ValueSerializeMap, } @@ -78,6 +89,25 @@ impl ValueSerializer { state: ValueSerializerState::Default, } } + + fn dictionary_serializer(self) -> DictionarySerializer { + DictionarySerializer { + container_depth: self.container_depth, + } + } +} + +impl DictionarySerializer { + pub fn new(container_depth: usize) -> Self { + Self { container_depth } + } + + fn value_serializer(self) -> ValueSerializer { + ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + } + } } impl Serializer for ValueSerializer { @@ -211,20 +241,17 @@ impl Serializer for ValueSerializer { fn serialize_newtype_variant( self, - _name: &'static str, - _variant_index: u32, + name: &'static str, + variant_index: u32, variant: &'static str, value: &T, ) -> Result where T: ?Sized + Serialize, { - let ser_value = value.serialize(self)?; - let mut map = HashMap::new(); - map.insert(Key::from(variant), ser_value); - Ok(Value::from( - Dictionary::new(self.container_depth, map).map_err(serde::ser::Error::custom)?, - )) + self.dictionary_serializer() + .serialize_newtype_variant(name, variant_index, variant, value) + .map(Value::from) } fn serialize_some(self, value: &T) -> Result @@ -254,26 +281,20 @@ impl Serializer for ValueSerializer { fn serialize_tuple_variant( self, - _name: &'static str, - _variant_index: u32, + name: &'static str, + variant_index: u32, variant: &'static str, - _len: usize, + len: usize, ) -> Result { - Ok(ValueSerializeTupleVariant { - name: variant, - inner: ValueSerializeSeq { - data: Vec::new(), - container_depth: self.container_depth, - }, - }) + self.dictionary_serializer() + .serialize_tuple_variant(name, variant_index, variant, len) + .map(ValueSerializeTupleVariant) } - fn serialize_map(self, _len: Option) -> Result { - Ok(ValueSerializeMap { - kvs: HashMap::new(), - container_depth: self.container_depth, - next_key: None, - }) + fn serialize_map(self, len: Option) -> Result { + self.dictionary_serializer() + .serialize_map(len) + .map(ValueSerializeMap) } fn serialize_struct( @@ -286,15 +307,14 @@ impl Serializer for ValueSerializer { fn serialize_struct_variant( self, - _name: &'static str, - _variant_index: u32, + name: &'static str, + variant_index: u32, variant: &'static str, len: usize, ) -> Result { - Ok(ValueSerializeStructVariant { - name: variant, - inner: self.serialize_map(Some(len))?, - }) + self.dictionary_serializer() + .serialize_struct_variant(name, variant_index, variant, len) + .map(ValueSerializeStructVariant) } } @@ -355,6 +375,261 @@ impl SerializeTupleVariant for ValueSerializeTupleVariant { type Ok = ::Ok; type Error = ::Error; + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + SerializeTupleVariant::end(self.0).map(Value::from) + } +} + +impl SerializeMap for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_key(key) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_value(value) + } + + fn end(self) -> Result { + SerializeMap::end(self.0).map(Value::from) + } +} + +impl SerializeStruct for ValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeMap::serialize_entry(self, key, value) + } + + fn end(self) -> Result { + SerializeMap::end(self) + } +} + +impl SerializeStructVariant for ValueSerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + SerializeStructVariant::end(self.0).map(Value::from) + } +} + +impl Serializer for DictionarySerializer { + type Ok = Dictionary; + type Error = serde::de::value::Error; + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = DictionarySerializeTupleVariant; + type SerializeMap = DictionarySerializeMap; + type SerializeStruct = DictionarySerializeMap; + type SerializeStructVariant = DictionarySerializeStructVariant; + + fn serialize_bool(self, _v: bool) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_i8(self, _v: i8) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_i16(self, _v: i16) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_i32(self, _v: i32) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_i64(self, _v: i64) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_u8(self, _v: u8) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_u16(self, _v: u16) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_u32(self, _v: u32) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_u64(self, _v: u64) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_f32(self, _v: f32) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_f64(self, _v: f64) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_char(self, _v: char) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_str(self, _v: &str) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_none(self) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_unit(self) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + let ser_value = value.serialize(self.value_serializer())?; + let mut map = HashMap::new(); + map.insert(Key::from(variant), ser_value); + Dictionary::new(self.container_depth, map).map_err(serde::ser::Error::custom) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + Ok(DictionarySerializeTupleVariant { + name: variant, + inner: ValueSerializeSeq { + data: Vec::new(), + container_depth: self.container_depth, + }, + }) + } + + fn serialize_map(self, _len: Option) -> Result { + Ok(DictionarySerializeMap { + kvs: HashMap::new(), + container_depth: self.container_depth, + next_key: None, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + Ok(DictionarySerializeStructVariant { + name: variant, + inner: ValueSerializeMap(self.serialize_map(Some(len))?), + }) + } +} + +impl SerializeTupleVariant for DictionarySerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> where T: ?Sized + Serialize, @@ -368,13 +643,13 @@ impl SerializeTupleVariant for ValueSerializeTupleVariant { let mut map = HashMap::new(); map.insert(Key::new(self.name.to_string()), arr); let dict = Dictionary::new(max_depth, map).map_err(serde::de::Error::custom)?; - Ok(Value::from(dict)) + Ok(dict) } } -impl SerializeMap for ValueSerializeMap { - type Ok = ::Ok; - type Error = ::Error; +impl SerializeMap for DictionarySerializeMap { + type Ok = ::Ok; + type Error = ::Error; fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> where @@ -415,13 +690,13 @@ impl SerializeMap for ValueSerializeMap { fn end(self) -> Result { let dict = Dictionary::new(self.container_depth, self.kvs).map_err(serde::ser::Error::custom)?; - Ok(Value::from(dict)) + Ok(dict) } } -impl SerializeStruct for ValueSerializeMap { - type Ok = ::Ok; - type Error = ::Error; +impl SerializeStruct for DictionarySerializeMap { + type Ok = ::Ok; + type Error = ::Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> where @@ -435,9 +710,9 @@ impl SerializeStruct for ValueSerializeMap { } } -impl SerializeStructVariant for ValueSerializeStructVariant { - type Ok = ::Ok; - type Error = ::Error; +impl SerializeStructVariant for DictionarySerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> where @@ -447,12 +722,12 @@ impl SerializeStructVariant for ValueSerializeStructVariant { } fn end(self) -> Result { - let depth = self.inner.container_depth; + let depth = self.inner.0.container_depth; let value = SerializeMap::end(self.inner)?; let mut kvs = HashMap::new(); kvs.insert(Key::new(self.name.to_string()), value); let dict = Dictionary::new(depth, kvs).map_err(serde::ser::Error::custom)?; - Ok(Value::from(dict)) + Ok(dict) } } @@ -676,7 +951,10 @@ mod test { use crate::{ backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, - middleware::{serializer::ValueSerializer, Params, RawValue, TypedValue, Value}, + middleware::{ + serializer::{DictionarySerializer, ValueSerializer}, + Params, RawValue, TypedValue, Value, + }, }; #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] @@ -778,15 +1056,26 @@ mod test { } fn test_roundtrip_dict(t: T) { + let depth = Params::default().max_depth_mt_containers; + let dict = t.serialize(DictionarySerializer::new(depth)).unwrap(); + let out: T = Deserialize::deserialize(&dict).unwrap(); + assert_eq!(t, out); + } + + fn test_dict_consistency(t: T) { let depth = Params::default().max_depth_mt_containers; let val = t.serialize(ValueSerializer::new(depth)).unwrap(); - match val.typed() { - TypedValue::Dictionary(d) => { - let out: T = Deserialize::deserialize(d).unwrap(); - assert_eq!(t, out); - } - _ => panic!("Expected value to be serialized to a dict"), - } + let dict = t.serialize(DictionarySerializer::new(depth)).unwrap(); + assert_eq!(val, Value::from(dict)); + } + + fn test_dict_consistency_and_roundtrip< + T: Serialize + DeserializeOwned + Eq + Clone + std::fmt::Debug, + >( + t: T, + ) { + test_dict_consistency(t.clone()); + test_roundtrip_dict(t) } fn test_preserved_ser(t: T) @@ -856,17 +1145,17 @@ mod test { map: a_hash_map(), }; test_roundtrip(desc.clone()); - test_roundtrip_dict(desc); + test_dict_consistency_and_roundtrip(desc); } #[test] - fn test_dict_deserialization() { - test_roundtrip_dict(Fancy::B(0, 1)); - test_roundtrip_dict(Fancy::C { + fn test_dict_serialization() { + test_dict_consistency_and_roundtrip(Fancy::B(0, 1)); + test_dict_consistency_and_roundtrip(Fancy::C { x: 1, y: vec![2, 3], }); - test_roundtrip_dict(a_hash_map()); + test_dict_consistency_and_roundtrip(a_hash_map()); } #[test] From 093d086f20c9064ed3aea6fe96a5f06d328c2ff1 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Fri, 12 Sep 2025 15:31:14 -0700 Subject: [PATCH 7/9] flatten option inside struct --- Cargo.toml | 1 + src/middleware/serializer.rs | 551 ++++++++++++++++++++++++++++++----- 2 files changed, 475 insertions(+), 77 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 299dea35..66b064b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ serde_bytes = "0.11" serde_arrays = "0.2.0" sha2 = { version = "0.10.9" } rand_chacha = "0.3.1" +paste = "1.0.15" # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # [patch."https://github.com/0xPARC/plonky2"] diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs index cb603257..476c1610 100644 --- a/src/middleware/serializer.rs +++ b/src/middleware/serializer.rs @@ -4,6 +4,7 @@ use std::{ }; use base64::{prelude::BASE64_STANDARD, Engine}; +use paste::paste; use serde::{ de::{ value::{ @@ -50,6 +51,9 @@ pub struct ValueSerializer { state: ValueSerializerState, } +#[derive(Clone, Copy)] +struct OptionValueSerializer(ValueSerializer); + #[derive(Clone, Copy)] pub struct DictionarySerializer { container_depth: usize, @@ -82,6 +86,11 @@ pub struct DictionarySerializeStructVariant { inner: ValueSerializeMap, } +struct OptionValueSerializeSeq(ValueSerializeSeq); +struct OptionValueSerializeMap(ValueSerializeMap); +struct OptionValueSerializeStructVariant(ValueSerializeStructVariant); +struct OptionValueSerializeTupleVariant(ValueSerializeTupleVariant); + impl ValueSerializer { pub fn new(container_depth: usize) -> Self { Self { @@ -95,6 +104,15 @@ impl ValueSerializer { container_depth: self.container_depth, } } + + fn update_state_newtype(&mut self, name: &'static str) { + match name { + "pod2::RawValue" => self.state = ValueSerializerState::RawValue, + "pod2::Point" => self.state = ValueSerializerState::Point, + "pod2::SecretKey" => self.state = ValueSerializerState::SecretKey, + _ => (), + } + } } impl DictionarySerializer { @@ -230,12 +248,7 @@ impl Serializer for ValueSerializer { where T: ?Sized + Serialize, { - match name { - "pod2::RawValue" => self.state = ValueSerializerState::RawValue, - "pod2::Point" => self.state = ValueSerializerState::Point, - "pod2::SecretKey" => self.state = ValueSerializerState::SecretKey, - _ => (), - } + self.update_state_newtype(name); value.serialize(self) } @@ -418,11 +431,11 @@ impl SerializeStruct for ValueSerializeMap { where T: ?Sized + Serialize, { - SerializeMap::serialize_entry(self, key, value) + SerializeStruct::serialize_field(&mut self.0, key, value) } fn end(self) -> Result { - SerializeMap::end(self) + SerializeStruct::end(self.0).map(Value::from) } } @@ -442,6 +455,33 @@ impl SerializeStructVariant for ValueSerializeStructVariant { } } +macro_rules! serialize_type { + ( bytes ) => { + &[u8] + }; + ( str ) => { + &str + }; + ( unit_struct ) => { + &'static str + }; + ( $name: ident ) => { + $name + }; +} + +macro_rules! map_expected { + ( $($item: ident )* ) => { + $( + paste! { + fn [](self, _: serialize_type!($item)) -> Result { + Err(serde::ser::Error::custom("expected a map")) + } + } + )* + } +} + impl Serializer for DictionarySerializer { type Ok = Dictionary; type Error = serde::de::value::Error; @@ -453,61 +493,7 @@ impl Serializer for DictionarySerializer { type SerializeStruct = DictionarySerializeMap; type SerializeStructVariant = DictionarySerializeStructVariant; - fn serialize_bool(self, _v: bool) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_i8(self, _v: i8) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_i16(self, _v: i16) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_i32(self, _v: i32) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_i64(self, _v: i64) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_u8(self, _v: u8) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_u16(self, _v: u16) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_u32(self, _v: u32) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_u64(self, _v: u64) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_f32(self, _v: f32) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_f64(self, _v: f64) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_char(self, _v: char) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_str(self, _v: &str) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - - fn serialize_bytes(self, _v: &[u8]) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } + map_expected!(bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 bytes str char unit_struct); fn serialize_none(self) -> Result { Err(serde::ser::Error::custom("expected a map")) @@ -524,10 +510,6 @@ impl Serializer for DictionarySerializer { Err(serde::ser::Error::custom("expected a map")) } - fn serialize_unit_struct(self, _name: &'static str) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } - fn serialize_unit_variant( self, _name: &'static str, @@ -647,11 +629,8 @@ impl SerializeTupleVariant for DictionarySerializeTupleVariant { } } -impl SerializeMap for DictionarySerializeMap { - type Ok = ::Ok; - type Error = ::Error; - - fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> +impl DictionarySerializeMap { + fn convert_key(&self, key: &T) -> Result::Error> where T: ?Sized + Serialize, { @@ -660,8 +639,7 @@ impl SerializeMap for DictionarySerializeMap { state: ValueSerializerState::Default, })?; if let TypedValue::String(s) = key_ser.typed() { - self.next_key = Some(Key::new(s.clone())); - Ok(()) + Ok(Key::new(s.clone())) } else { Err(serde::de::Error::invalid_value( Unexpected::Other("non-string key in map"), @@ -669,6 +647,19 @@ impl SerializeMap for DictionarySerializeMap { )) } } +} + +impl SerializeMap for DictionarySerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.next_key = Some(self.convert_key(key)?); + Ok(()) + } fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> where @@ -702,7 +693,16 @@ impl SerializeStruct for DictionarySerializeMap { where T: ?Sized + Serialize, { - SerializeMap::serialize_entry(self, key, value) + //SerializeMap::serialize_entry(self, key, value) + let opt_ser = value.serialize(OptionValueSerializer(ValueSerializer { + container_depth: self.container_depth, + state: ValueSerializerState::Default, + }))?; + if let Some(val_ser) = opt_ser { + let key = self.convert_key(key)?; + self.kvs.insert(key, val_ser); + } + Ok(()) } fn end(self) -> Result { @@ -731,6 +731,264 @@ impl SerializeStructVariant for DictionarySerializeStructVariant { } } +macro_rules! option_serializer_forward { + ( $($item: ident )* ) => { + $( + paste! { + fn [](self, v: serialize_type!($item)) -> Result { + self.0.[](v).map(Some) + } + } + )* + } +} + +impl Serializer for OptionValueSerializer { + type Ok = Option; + type Error = ::Error; + type SerializeSeq = OptionValueSerializeSeq; + type SerializeTuple = OptionValueSerializeSeq; + type SerializeTupleStruct = OptionValueSerializeSeq; + type SerializeMap = OptionValueSerializeMap; + type SerializeStruct = OptionValueSerializeMap; + type SerializeTupleVariant = OptionValueSerializeTupleVariant; + type SerializeStructVariant = OptionValueSerializeStructVariant; + + option_serializer_forward!(bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char bytes str); + + fn serialize_none(self) -> Result { + Ok(None) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self.0).map(Some) + } + + fn serialize_unit(self) -> Result { + self.0.serialize_unit().map(Some) + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + self.0.serialize_unit_struct(name).map(Some) + } + + fn serialize_unit_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result { + self.0 + .serialize_unit_variant(name, variant_index, variant) + .map(Some) + } + + fn serialize_newtype_struct( + mut self, + name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.0.update_state_newtype(name); + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.0 + .serialize_newtype_variant(name, variant_index, variant, value) + .map(Some) + } + + fn serialize_seq(self, len: Option) -> Result { + self.0.serialize_seq(len).map(OptionValueSerializeSeq) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.0.serialize_tuple(len).map(OptionValueSerializeSeq) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_tuple_struct(name, len) + .map(OptionValueSerializeSeq) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_tuple_variant(name, variant_index, variant, len) + .map(OptionValueSerializeTupleVariant) + } + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_struct(name, len) + .map(OptionValueSerializeMap) + } + + fn serialize_map(self, len: Option) -> Result { + self.0.serialize_map(len).map(OptionValueSerializeMap) + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + self.0 + .serialize_struct_variant(name, variant_index, variant, len) + .map(OptionValueSerializeStructVariant) + } +} + +impl SerializeSeq for OptionValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeSeq::serialize_element(&mut self.0, value) + } + + fn end(self) -> Result { + SerializeSeq::end(self.0).map(Some) + } +} + +impl SerializeTuple for OptionValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + SerializeTuple::serialize_element(&mut self.0, value) + } + + fn end(self) -> Result { + SerializeTuple::end(self.0).map(Some) + } +} + +impl SerializeTupleStruct for OptionValueSerializeSeq { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + SerializeTupleStruct::end(self.0).map(Some) + } +} + +impl SerializeMap for OptionValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_key(key) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_value(value) + } + + fn end(self) -> Result { + SerializeMap::end(self.0).map(Some) + } +} + +impl SerializeStruct for OptionValueSerializeMap { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + SerializeStruct::end(self.0).map(Some) + } +} + +impl SerializeStructVariant for OptionValueSerializeStructVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + self.0.end().map(Some) + } +} + +impl SerializeTupleVariant for OptionValueSerializeTupleVariant { + type Ok = ::Ok; + type Error = ::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + self.0.end().map(Some) + } +} + impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for &'a TypedValue { type Deserializer = Self; fn into_deserializer(self) -> Self::Deserializer { @@ -797,7 +1055,21 @@ impl<'de> serde::Deserializer<'de> for &Dictionary { visitor.visit_unit() } - forward_to_deserialize_any! { bool i8 i16 i32 i64 f32 f64 u8 u16 u32 u64 char str bytes byte_buf string option unit unit_struct seq tuple_struct tuple map newtype_struct struct identifier } + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_map(MapDeserializer::new( + self.kvs().iter().map(|(k, v)| (k, StructField(v.typed()))), + )) + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 f32 f64 u8 u16 u32 u64 char str bytes byte_buf string option unit unit_struct seq tuple_struct tuple map newtype_struct identifier } } impl<'de> serde::Deserializer<'de> for &TypedValue { @@ -936,7 +1208,124 @@ impl<'de> serde::Deserializer<'de> for &TypedValue { } } - forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 char str string unit_struct seq tuple_struct tuple map struct identifier ignored_any } + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + TypedValue::Dictionary(d) => d.deserialize_struct(name, fields, visitor), + _ => self.deserialize_any(visitor), + } + } + + forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 char str string unit_struct seq tuple_struct tuple map identifier ignored_any } +} + +struct StructField<'a>(&'a TypedValue); + +impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for StructField<'a> { + type Deserializer = Self; + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +macro_rules! deserialize_forward { + ( $($item: ident )* ) => { + $( + paste! { + fn [](self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.[](visitor) + } + } + )* + } +} + +impl<'de> serde::Deserializer<'de> for StructField<'_> { + type Error = serde::de::value::Error; + + deserialize_forward!(any bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str char string bytes byte_buf unit seq ignored_any map identifier); + + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_unit_struct(name, visitor) + } + + fn deserialize_newtype_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_newtype_struct(name, visitor) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_tuple(len, visitor) + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_struct(name, fields, visitor) + } + + fn deserialize_tuple_struct( + self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_tuple_struct(name, len, visitor) + } + + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.deserialize_enum(name, variants, visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_some(self.0) + } } #[cfg(test)] @@ -1034,6 +1423,10 @@ mod test { seed_range: Vec<(Method, RawValue)>, option1: Option, option2: Option, + vec_opt: Vec>, + optopt1: Option>, + optopt2: Option>, + optopt3: Option>, unit: (), sk: SecretKey, pk: Point, @@ -1125,6 +1518,10 @@ mod test { seed_range, option1: Some(2), option2: None, + optopt1: Some(Some(4)), + optopt2: Some(None), + optopt3: None, + vec_opt: vec![Some(3), None], unit: (), sk, pk, From 50407d1730f7c338c9d8596563828538ce4ec945 Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Fri, 12 Sep 2025 20:11:39 -0700 Subject: [PATCH 8/9] don't use paste --- Cargo.toml | 1 - src/lib.rs | 1 + src/middleware/serializer.rs | 25 +++++++++---------------- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 66b064b4..299dea35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,6 @@ serde_bytes = "0.11" serde_arrays = "0.2.0" sha2 = { version = "0.10.9" } rand_chacha = "0.3.1" -paste = "1.0.15" # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # [patch."https://github.com/0xPARC/plonky2"] diff --git a/src/lib.rs b/src/lib.rs index 4b7de653..76c6047e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::uninlined_format_args)] // TODO: Remove this in another PR #![allow(clippy::manual_repeat_n)] // TODO: Remove this in another PR #![allow(clippy::large_enum_variant)] // TODO: Remove this in another PR +#![feature(macro_metavar_expr_concat)] #![feature(mapped_lock_guards)] pub mod backends; diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs index 476c1610..5d3f743c 100644 --- a/src/middleware/serializer.rs +++ b/src/middleware/serializer.rs @@ -4,7 +4,6 @@ use std::{ }; use base64::{prelude::BASE64_STANDARD, Engine}; -use paste::paste; use serde::{ de::{ value::{ @@ -473,10 +472,8 @@ macro_rules! serialize_type { macro_rules! map_expected { ( $($item: ident )* ) => { $( - paste! { - fn [](self, _: serialize_type!($item)) -> Result { - Err(serde::ser::Error::custom("expected a map")) - } + fn ${concat(serialize_, $item)}(self, _: serialize_type!($item)) -> Result { + Err(serde::ser::Error::custom("expected a map")) } )* } @@ -734,10 +731,8 @@ impl SerializeStructVariant for DictionarySerializeStructVariant { macro_rules! option_serializer_forward { ( $($item: ident )* ) => { $( - paste! { - fn [](self, v: serialize_type!($item)) -> Result { - self.0.[](v).map(Some) - } + fn ${concat(serialize_, $item)}(self, v: serialize_type!($item)) -> Result { + self.0.${concat(serialize_, $item)}(v).map(Some) } )* } @@ -1238,13 +1233,11 @@ impl<'a, 'de> IntoDeserializer<'de, serde::de::value::Error> for StructField<'a> macro_rules! deserialize_forward { ( $($item: ident )* ) => { $( - paste! { - fn [](self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - self.0.[](visitor) - } + fn ${concat(deserialize_, $item)}(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.0.${concat(deserialize_, $item)}(visitor) } )* } From 79f57eaac80029a9a77707612ed1410031109eff Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Sat, 13 Sep 2025 21:47:20 -0700 Subject: [PATCH 9/9] hint for sets --- src/middleware/serialization.rs | 2 +- src/middleware/serializer.rs | 86 +++++++++++++++++++++++++++++---- 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index b309c3c8..3f5ce4e5 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -141,5 +141,5 @@ where { let mut sorted_values: Vec<&Value> = value.iter().collect(); sorted_values.sort_by_key(|v| v.raw()); - serializer.serialize_newtype_struct("Set", &sorted_values) + serializer.serialize_newtype_struct("pod2::Set", &sorted_values) } diff --git a/src/middleware/serializer.rs b/src/middleware/serializer.rs index 5d3f743c..e5764390 100644 --- a/src/middleware/serializer.rs +++ b/src/middleware/serializer.rs @@ -36,12 +36,50 @@ use crate::{ }, }; +/// Indicates that `value` should be serialized to a `Set` rather than +/// an `Array`. +/// +/// Serde regards both arrays and sets as "sequences", so the serializer +/// cannot distinguish between the two. By default, [`ValueSerializer`] will +/// serialize sequences to an `Array`. This function hints to the serializer +/// that it should serialize `value` as a `Set` instead. If `value` is not a sequence, +/// then the hint has no effect. +/// ``` +/// use pod2::middleware::{Key, TypedValue, +/// serializer::{DictionarySerializer, serialize_seq_as_set}}; +/// use std::collections::HashSet; +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct ExampleStruct { +/// arr: HashSet, +/// #[serde(serialize_with = "serialize_seq_as_set")] +/// set: HashSet +/// } +/// +/// let set: HashSet = [2].into_iter().collect(); +/// let ex = ExampleStruct { +/// arr: set.clone(), +/// set +/// }; +/// let d = ex.serialize(DictionarySerializer::new(6)).unwrap(); +/// assert!(matches!(d.get(&Key::from("arr")).unwrap().typed(), TypedValue::Array(_))); +/// assert!(matches!(d.get(&Key::from("set")).unwrap().typed(), TypedValue::Set(_))); +/// ``` +pub fn serialize_seq_as_set( + value: &T, + serializer: S, +) -> Result { + serializer.serialize_newtype_struct("pod2::AsSet", value) +} + #[derive(Clone, Copy)] enum ValueSerializerState { Default, RawValue, Point, SecretKey, + AsSet, } #[derive(Clone, Copy)] @@ -58,8 +96,13 @@ pub struct DictionarySerializer { container_depth: usize, } +enum SeqData { + Array(Vec), + Set(HashSet), +} + pub struct ValueSerializeSeq { - data: Vec, + data: SeqData, container_depth: usize, } @@ -109,6 +152,7 @@ impl ValueSerializer { "pod2::RawValue" => self.state = ValueSerializerState::RawValue, "pod2::Point" => self.state = ValueSerializerState::Point, "pod2::SecretKey" => self.state = ValueSerializerState::SecretKey, + "pod2::AsSet" => self.state = ValueSerializerState::AsSet, _ => (), } } @@ -213,7 +257,10 @@ impl Serializer for ValueSerializer { fn serialize_seq(self, _len: Option) -> Result { Ok(ValueSerializeSeq { - data: Vec::new(), + data: match self.state { + ValueSerializerState::AsSet => SeqData::Set(HashSet::new()), + _ => SeqData::Array(vec![]), + }, container_depth: self.container_depth, }) } @@ -338,16 +385,30 @@ impl SerializeSeq for ValueSerializeSeq { where T: ?Sized + Serialize, { - self.data.push(value.serialize(ValueSerializer { + let val = value.serialize(ValueSerializer { container_depth: self.container_depth, state: ValueSerializerState::Default, - })?); + })?; + match &mut self.data { + SeqData::Set(s) => { + s.insert(val); + } + SeqData::Array(a) => a.push(val), + } Ok(()) } fn end(self) -> Result { - let arr = Array::new(self.container_depth, self.data).map_err(serde::de::Error::custom)?; - Ok(Value::from(arr)) + match self.data { + SeqData::Set(s) => { + let set = Set::new(self.container_depth, s).map_err(serde::de::Error::custom)?; + Ok(Value::from(set)) + } + SeqData::Array(a) => { + let arr = Array::new(self.container_depth, a).map_err(serde::de::Error::custom)?; + Ok(Value::from(arr)) + } + } } } @@ -569,7 +630,7 @@ impl Serializer for DictionarySerializer { Ok(DictionarySerializeTupleVariant { name: variant, inner: ValueSerializeSeq { - data: Vec::new(), + data: SeqData::Array(Vec::new()), container_depth: self.container_depth, }, }) @@ -1324,7 +1385,7 @@ impl<'de> serde::Deserializer<'de> for StructField<'_> { #[cfg(test)] mod test { - use std::collections::HashMap; + use std::collections::{HashMap, HashSet}; use serde::{ de::{DeserializeOwned, Visitor}, @@ -1334,7 +1395,7 @@ mod test { use crate::{ backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}, middleware::{ - serializer::{DictionarySerializer, ValueSerializer}, + serializer::{serialize_seq_as_set, DictionarySerializer, ValueSerializer}, Params, RawValue, TypedValue, Value, }, }; @@ -1432,6 +1493,8 @@ mod test { inf: Float, nan: Float, map: HashMap, + #[serde(serialize_with = "serialize_seq_as_set")] + set: HashSet, } fn test_roundtrip(t: T) { @@ -1497,6 +1560,10 @@ mod test { map } + fn a_hash_set() -> HashSet { + [2, 3].into_iter().collect() + } + #[test] fn test_complicated_struct() { let seed_range = vec![ @@ -1533,6 +1600,7 @@ mod test { inf: Float(f32::NEG_INFINITY), nan: Float(f32::NAN), map: a_hash_map(), + set: a_hash_set(), }; test_roundtrip(desc.clone()); test_dict_consistency_and_roundtrip(desc);