diff --git a/avro/src/error.rs b/avro/src/error.rs index a3e2cf05..bdb20552 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -579,6 +579,12 @@ pub enum Details { #[error("Cannot convert a slice to Uuid: {0}")] UuidFromSlice(#[source] uuid::Error), + + #[error("Expected String for Map key when serializing a flattened struct")] + MapFieldExpectedString, + + #[error("No key for value when serializing a map")] + MapNoKey, } #[derive(thiserror::Error, PartialEq)] diff --git a/avro/src/lib.rs b/avro/src/lib.rs index 853722d6..f75c5a3a 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -945,14 +945,12 @@ mod bigdecimal; mod bytes; mod codec; -mod de; mod decimal; mod decode; mod duration; mod encode; mod reader; -mod ser; -mod ser_schema; +mod serde; mod writer; pub mod error; @@ -979,7 +977,6 @@ pub use codec::xz::XzSettings; #[cfg(feature = "zstandard")] pub use codec::zstandard::ZstandardSettings; pub use codec::{Codec, DeflateSettings}; -pub use de::from_value; pub use decimal::Decimal; pub use duration::{Days, Duration, Millis, Months}; pub use error::Error; @@ -988,7 +985,7 @@ pub use reader::{ from_avro_datum_reader_schemata, from_avro_datum_schemata, read_marker, }; pub use schema::{AvroSchema, Schema}; -pub use ser::to_value; +pub use serde::{de::from_value, ser::to_value}; pub use uuid::Uuid; pub use writer::{ GenericSingleObjectWriter, SpecificSingleObjectWriter, Writer, WriterBuilder, to_avro_datum, diff --git a/avro/src/de.rs b/avro/src/serde/de.rs similarity index 100% rename from avro/src/de.rs rename to avro/src/serde/de.rs diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs new file mode 100644 index 00000000..509d2e5c --- /dev/null +++ b/avro/src/serde/mod.rs @@ -0,0 +1,4 @@ +pub mod de; +pub mod ser; +pub mod ser_schema; +mod util; diff --git a/avro/src/ser.rs b/avro/src/serde/ser.rs similarity index 99% rename from avro/src/ser.rs rename to avro/src/serde/ser.rs index 1bc90755..d78f5017 100644 --- a/avro/src/ser.rs +++ b/avro/src/serde/ser.rs @@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> { /// Interpret a serializeable instance as a `Value`. /// /// This conversion can fail if the value is not valid as per the Avro specification. -/// e.g: HashMap with non-string keys +/// e.g: `HashMap` with non-string keys. +/// +/// This function does not work if `S` has any fields (recursively) that have the `#[serde(flatten)]` +/// attribute. Please use [`Writer::append_ser`] if that's the case. +/// +/// [`Writer::append_ser`]: crate::Writer::append_ser pub fn to_value(value: S) -> Result { let mut serializer = Serializer::default(); value.serialize(&mut serializer) diff --git a/avro/src/ser_schema.rs b/avro/src/serde/ser_schema.rs similarity index 96% rename from avro/src/ser_schema.rs rename to avro/src/serde/ser_schema.rs index f9ee2fc0..02a65bca 100644 --- a/avro/src/ser_schema.rs +++ b/avro/src/serde/ser_schema.rs @@ -23,9 +23,10 @@ use crate::{ encode::{encode_int, encode_long}, error::{Details, Error}, schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema}, + serde::util::StringSerializer, }; use bigdecimal::BigDecimal; -use serde::ser; +use serde::{Serialize, ser}; use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write, str::FromStr}; const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024; @@ -251,6 +252,8 @@ pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> { record_schema: &'s RecordSchema, /// Fields we received in the wrong order field_cache: HashMap>, + /// The current field name when serializing from a map (for `flatten` support). + map_field_name: Option, field_position: usize, bytes_written: usize, } @@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> { ser, record_schema, field_cache: HashMap::new(), + map_field_name: None, field_position: 0, bytes_written: 0, } @@ -352,6 +356,11 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> { "There should be no more unwritten fields at this point: {:?}", self.field_cache ); + debug_assert!( + self.map_field_name.is_none(), + "There should be no field name at this point: field {:?}", + self.map_field_name + ); Ok(self.bytes_written) } } @@ -371,17 +380,14 @@ impl ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_, .and_then(|idx| self.record_schema.fields.get(*idx)); match record_field { - Some(field) => { - // self.item_count += 1; - self.serialize_next_field(field, value).map_err(|e| { - Details::SerializeRecordFieldWithSchema { - field_name: key.to_string(), - record_schema: Schema::Record(self.record_schema.clone()), - error: Box::new(e), - } - .into() - }) - } + Some(field) => self.serialize_next_field(field, value).map_err(|e| { + Details::SerializeRecordFieldWithSchema { + field_name: key.to_string(), + record_schema: Schema::Record(self.record_schema.clone()), + error: Box::new(e), + } + .into() + }), None => Err(Details::FieldName(String::from(key)).into()), } } @@ -420,6 +426,53 @@ impl ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_, } } +/// This implementation is used to support `#[serde(flatten)]` as that uses SerializeMap instead of SerializeStruct. +impl ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, W> { + type Ok = usize; + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let name = key.serialize(StringSerializer)?; + let old = self.map_field_name.replace(name); + debug_assert!( + old.is_none(), + "Expected a value instead of a key: old key: {old:?}, new key: {:?}", + self.map_field_name + ); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let key = self.map_field_name.take().ok_or(Details::MapNoKey)?; + let record_field = self + .record_schema + .lookup + .get(&key) + .and_then(|idx| self.record_schema.fields.get(*idx)); + match record_field { + Some(field) => self.serialize_next_field(field, value).map_err(|e| { + Details::SerializeRecordFieldWithSchema { + field_name: key.to_string(), + record_schema: Schema::Record(self.record_schema.clone()), + error: Box::new(e), + } + .into() + }), + None => Err(Details::FieldName(key).into()), + } + } + + fn end(self) -> Result { + self.end() + } +} + impl ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'_, '_, W> { type Ok = usize; type Error = Error; @@ -436,6 +489,46 @@ impl ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<' } } +/// Map serializer that switches between Struct or Map. +/// +/// This exists because when `#[serde(flatten)]` is used, struct fields are serialized as a map. +pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> { + Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>), + Map(SchemaAwareWriteSerializeMap<'a, 's, W>), +} + +impl ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, '_, W> { + type Ok = usize; + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + match self { + Self::Struct(s) => s.serialize_key(key), + Self::Map(s) => s.serialize_key(key), + } + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + match self { + Self::Struct(s) => s.serialize_value(value), + Self::Map(s) => s.serialize_value(value), + } + } + + fn end(self) -> Result { + match self { + Self::Struct(s) => s.end(), + Self::Map(s) => s.end(), + } + } +} + /// The tuple struct serializer for [`SchemaAwareWriteSerializer`]. /// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, record, or big-decimal. /// When serializing to a record, fields must be provided in the correct order, since no names are provided. @@ -1499,7 +1592,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { &'a mut self, len: Option, schema: &'s Schema, - ) -> Result, Error> { + ) -> Result, Error> { let create_error = |cause: String| { let len_str = len .map(|l| format!("{l}")) @@ -1513,15 +1606,17 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { }; match schema { - Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new( - self, - map_schema.types.as_ref(), - len, + Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Map( + SchemaAwareWriteSerializeMap::new(self, map_schema.types.as_ref(), len), )), + Schema::Ref { name: ref_name } => { + let ref_schema = self.get_ref_schema(ref_name)?; + self.serialize_map_with_schema(len, ref_schema) + } Schema::Union(union_schema) => { for (i, variant_schema) in union_schema.schemas.iter().enumerate() { match variant_schema { - Schema::Map(_) => { + Schema::Map(_) | Schema::Record(_) | Schema::Ref { .. } => { encode_int(i as i32, &mut *self.writer)?; return self.serialize_map_with_schema(len, variant_schema); } @@ -1532,6 +1627,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { "Expected a Map schema in {union_schema:?}" ))) } + Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Struct( + SchemaAwareWriteSerializeStruct::new(self, record_schema), + )), _ => Err(create_error(format!( "Expected Map or Union schema. Got: {schema}" ))), @@ -1630,7 +1728,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>; type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>; type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>; - type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>; + type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>; type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>; type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>; diff --git a/avro/src/serde/util.rs b/avro/src/serde/util.rs new file mode 100644 index 00000000..55ea2ead --- /dev/null +++ b/avro/src/serde/util.rs @@ -0,0 +1,300 @@ +use crate::{Error, error::Details}; +use serde::{ + Serialize, Serializer, + ser::{ + SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, + }, +}; + +/// Serialize a `T: Serialize` as a `String`. +/// +/// An error will be returned if any other function than [`Serializer::serialize_str`] is called. +pub struct StringSerializer; + +impl Serializer for StringSerializer { + type Ok = String; + type Error = Error; + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + fn serialize_bool(self, _v: bool) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i8(self, _v: i8) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i16(self, _v: i16) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i32(self, _v: i32) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i64(self, _v: i64) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u8(self, _v: u8) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u16(self, _v: u16) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u32(self, _v: u32) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u64(self, _v: u64) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_f32(self, _v: f32) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_f64(self, _v: f64) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_char(self, _v: char) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(v.to_string()) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_none(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_unit(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeSeq for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_element(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeTuple for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_element(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeTupleStruct for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeTupleVariant for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeMap for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_key(&mut self, _key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_value(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeStruct for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field(&mut self, _key: &'static str, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeStructVariant for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field(&mut self, _key: &'static str, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result { + Err(Details::MapFieldExpectedString.into()) + } +} diff --git a/avro/src/types.rs b/avro/src/types.rs index 4448eef2..5a54c3f2 100644 --- a/avro/src/types.rs +++ b/avro/src/types.rs @@ -2701,7 +2701,7 @@ Field with name '"b"' is not a member of the map items"#, #[test] fn test_avro_3460_validation_with_refs_real_struct() -> TestResult { - use crate::ser::Serializer; + use crate::serde::ser::Serializer; use serde::Serialize; #[derive(Serialize, Clone)] @@ -2790,7 +2790,7 @@ Field with name '"b"' is not a member of the map items"#, } fn avro_3674_with_or_without_namespace(with_namespace: bool) -> TestResult { - use crate::ser::Serializer; + use crate::serde::ser::Serializer; use serde::Serialize; let schema_str = r#" @@ -2883,7 +2883,7 @@ Field with name '"b"' is not a member of the map items"#, } fn avro_3688_schema_resolution_panic(set_field_b: bool) -> TestResult { - use crate::ser::Serializer; + use crate::serde::ser::Serializer; use serde::{Deserialize, Serialize}; let schema_str = r#"{ diff --git a/avro/src/writer.rs b/avro/src/writer.rs index 3b62b168..a1ae239d 100644 --- a/avro/src/writer.rs +++ b/avro/src/writer.rs @@ -22,7 +22,7 @@ use crate::{ error::Details, headers::{HeaderBuilder, RabinFingerprintHeader}, schema::{AvroSchema, Name, ResolvedOwnedSchema, ResolvedSchema, Schema}, - ser_schema::SchemaAwareWriteSerializer, + serde::ser_schema::SchemaAwareWriteSerializer, types::Value, }; use serde::Serialize; diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs index c447e58b..2b225fb9 100644 --- a/avro_derive/src/lib.rs +++ b/avro_derive/src/lib.rs @@ -39,6 +39,8 @@ struct FieldOptions { rename: Option, #[darling(default)] skip: Option, + #[darling(default)] + flatten: Option, } #[derive(darling::FromAttributes)] @@ -142,26 +144,46 @@ fn get_data_struct_schema_def( let mut record_field_exprs = vec![]; match s.fields { syn::Fields::Named(ref a) => { - let mut index: usize = 0; for field in a.named.iter() { - let mut name = field.ident.as_ref().unwrap().to_string(); // we know everything has a name + let mut name = field + .ident + .as_ref() + .expect("Field must have a name") + .to_string(); if let Some(raw_name) = name.strip_prefix("r#") { name = raw_name.to_string(); } let field_attrs = - FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?; + FieldOptions::from_attributes(&field.attrs).map_err(darling_to_syn)?; let doc = preserve_optional(field_attrs.doc.or_else(|| extract_outer_doc(&field.attrs))); match (field_attrs.rename, rename_all) { (Some(rename), _) => { name = rename; } - (None, rename_all) if !matches!(rename_all, RenameRule::None) => { + (None, rename_all) if rename_all != RenameRule::None => { name = rename_all.apply_to_field(&name); } _ => {} } - if let Some(true) = field_attrs.skip { + if Some(true) == field_attrs.skip { + continue; + } else if Some(true) == field_attrs.flatten { + // Inline the fields of the child record at runtime, as we don't have access to + // the schema here. + let flatten_ty = &field.ty; + record_field_exprs.push(quote! { + if let ::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema { fields, .. }) = #flatten_ty::get_schema() { + for mut field in fields { + field.position = schema_fields.len(); + schema_fields.push(field) + } + } else { + panic!("Can only flatten RecordSchema, got {:?}", #flatten_ty::get_schema()) + } + }); + + // Don't add this field as it's been replaced by the child record fields continue; } let default_value = match field_attrs.default { @@ -181,20 +203,18 @@ fn get_data_struct_schema_def( }; let aliases = preserve_vec(field_attrs.alias); let schema_expr = type_to_schema_expr(&field.ty)?; - let position = index; record_field_exprs.push(quote! { - apache_avro::schema::RecordField { - name: #name.to_string(), - doc: #doc, - default: #default_value, - aliases: #aliases, - schema: #schema_expr, - order: apache_avro::schema::RecordFieldOrder::Ascending, - position: #position, - custom_attributes: Default::default(), - } + schema_fields.push(::apache_avro::schema::RecordField { + name: #name.to_string(), + doc: #doc, + default: #default_value, + aliases: #aliases, + schema: #schema_expr, + order: ::apache_avro::schema::RecordFieldOrder::Ascending, + position: schema_fields.len(), + custom_attributes: Default::default(), + }); }); - index += 1; } } syn::Fields::Unnamed(_) => { @@ -212,8 +232,14 @@ fn get_data_struct_schema_def( } let record_doc = preserve_optional(record_doc); let record_aliases = preserve_vec(aliases); + // When flatten is involved, there will be more but we don't know how many. This optimises for + // the most common case where there is no flatten. + let minimum_fields = record_field_exprs.len(); Ok(quote! { - let schema_fields = vec![#(#record_field_exprs),*]; + let mut schema_fields = Vec::with_capacity(#minimum_fields); + #(#record_field_exprs)* + let schema_field_set: ::std::collections::HashSet<_> = schema_fields.iter().map(|rf| &rf.name).collect(); + assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate field names found: {schema_fields:?}"); let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]); let lookup: std::collections::BTreeMap = schema_fields .iter() @@ -683,7 +709,7 @@ mod tests { match syn::parse2::(test_struct) { Ok(mut input) => { let schema_res = derive_avro_schema(&mut input); - let expected_token_stream = r#"let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , }] ;"#; + let expected_token_stream = r#"let mut schema_fields = Vec :: with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : schema_fields . len () , custom_attributes : Default :: default () , }) ;"#; let schema_token_stream = schema_res.unwrap().to_string(); assert!(schema_token_stream.contains(expected_token_stream)); } @@ -725,7 +751,7 @@ mod tests { match syn::parse2::(test_struct) { Ok(mut input) => { let schema_res = derive_avro_schema(&mut input); - let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , } , apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 1usize , custom_attributes : Default :: default () , }] ;"#; + let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let mut schema_fields = Vec :: with_capacity (2usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : schema_fields . len () , custom_attributes : Default :: default () , }) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : schema_fields . len () , custom_attributes : Default :: default () , }) ;"#; let schema_token_stream = schema_res.unwrap().to_string(); assert!(schema_token_stream.contains(expected_token_stream)); } @@ -769,7 +795,7 @@ mod tests { match syn::parse2::(test_struct) { Ok(mut input) => { let schema_res = derive_avro_schema(&mut input); - let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , } , apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 1usize , custom_attributes : Default :: default () , }] ;"#; + let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let mut schema_fields = Vec :: with_capacity (2usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : schema_fields . len () , custom_attributes : Default :: default () , }) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : apache_avro :: schema :: Schema :: Int , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : schema_fields . len () , custom_attributes : Default :: default () , }) ;"#; let schema_token_stream = schema_res.unwrap().to_string(); assert!(schema_token_stream.contains(expected_token_stream)); } diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs index 8d92c57a..6972e9a5 100644 --- a/avro_derive/tests/derive.rs +++ b/avro_derive/tests/derive.rs @@ -1686,4 +1686,168 @@ mod test_derive { panic!("Unexpected schema type for Foo") } } + + #[test] + fn avro_rs_247_serde_flatten_support() { + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Nested { + a: bool, + } + + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Foo { + #[serde(flatten)] + #[avro(flatten)] + nested: Nested, + b: i32, + } + + let schema = r#" + { + "type":"record", + "name":"Foo", + "fields": [ + { + "name":"a", + "type":"boolean" + }, + { + "name":"b", + "type":"int" + } + ] + } + "#; + + let schema = Schema::parse_str(schema).unwrap(); + assert_eq!(schema, Foo::get_schema()); + + serde_assert(Foo { + nested: Nested { a: true }, + b: 321, + }); + } + + #[test] + fn avro_rs_247_serde_nested_flatten_support() { + use apache_avro::AvroSchema; + use serde::{Deserialize, Serialize}; + + #[derive(AvroSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] + pub struct NestedFoo { + one: u32, + } + + #[derive(AvroSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] + pub struct Foo { + #[serde(flatten)] + #[avro(flatten)] + nested_foo: NestedFoo, + } + + #[derive(AvroSchema, Debug, Clone, PartialEq, Serialize, Deserialize)] + struct Bar { + foo: Foo, + two: u32, + } + + let schema = r#" + { + "type":"record", + "name":"Bar", + "fields": [ + { + "name":"foo", + "type": { + "type": "record", + "name": "Foo", + "fields": [ + { + "name": "one", + "type": "long" + } + ] + } + }, + { + "name":"two", + "type":"long" + } + ] + } + "#; + + let schema = Schema::parse_str(schema).unwrap(); + assert_eq!(schema, Bar::get_schema()); + + serde_assert(Bar { + foo: Foo { + nested_foo: NestedFoo { one: 42 }, + }, + two: 2, + }); + } + + #[test] + #[should_panic(expected = "Duplicate field names found")] + fn avro_rs_247_serde_flatten_support_duplicate_field_name() { + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Nested { + a: i32, + } + + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Foo { + #[serde(flatten)] + #[avro(flatten)] + nested: Nested, + a: i32, + } + + Foo::get_schema(); + } + + #[test] + fn avro_rs_247_serde_flatten_support_with_skip() { + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Nested { + a: bool, + #[serde(skip)] + #[avro(skip)] + c: f64, + } + + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Foo { + #[serde(flatten)] + #[avro(flatten)] + nested: Nested, + b: i32, + } + + let schema = r#" + { + "type":"record", + "name":"Foo", + "fields": [ + { + "name":"a", + "type":"boolean" + }, + { + "name":"b", + "type":"int" + } + ] + } + "#; + + let schema = Schema::parse_str(schema).unwrap(); + assert_eq!(schema, Foo::get_schema()); + + serde_assert(Foo { + nested: Nested { a: true, c: 0.0 }, + b: 321, + }); + } }