Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions avro/src/serde/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,13 @@ where
named_schemas: &mut HashSet<Name>,
enclosing_namespace: NamespaceRef,
) -> Schema {
let variants = vec![
Schema::Null,
T::get_schema_in_ctxt(named_schemas, enclosing_namespace),
];

let variants = match T::get_schema_in_ctxt(named_schemas, enclosing_namespace) {
Schema::Union(union) => vec![Schema::Null]
.into_iter()
.chain(union.schemas)
.collect(),
schema => vec![Schema::Null, schema],
};
Schema::Union(
UnionSchema::new(variants).expect("Option<T> must produce a valid (non-nested) union"),
)
Expand Down Expand Up @@ -970,11 +972,12 @@ mod tests {
use apache_avro_test_helper::TestResult;

use crate::{
AvroSchema, Schema,
AvroSchema, AvroSchemaComponent, Schema,
reader::datum::GenericDatumReader,
schema::{FixedSchema, Name},
schema::{FixedSchema, Name, NamespaceRef, UnionSchema},
writer::datum::GenericDatumWriter,
};
use std::collections::HashSet;

#[test]
fn avro_rs_401_str() -> TestResult {
Expand Down Expand Up @@ -1090,9 +1093,7 @@ mod tests {
}

#[test]
#[should_panic(
expected = "Option<T> must produce a valid (non-nested) union: Error { details: Unions may not directly contain a union }"
)]
#[should_panic(expected = "Unions cannot contain duplicate types, found at least two Null")]
fn avro_rs_489_option_option() {
<Option<Option<i32>>>::get_schema();
}
Expand Down Expand Up @@ -1226,4 +1227,32 @@ mod tests {
);
Ok(())
}

#[test]
fn test_nullable_complex_union() -> TestResult {
let schema = Schema::parse_str(r#"["null", "int", "string"]"#)?;

#[allow(dead_code)]
enum MyUnion {
Int(i32),
String(String),
}

impl AvroSchemaComponent for MyUnion {
fn get_schema_in_ctxt(
named_schemas: &mut HashSet<Name>,
enclosing_namespace: NamespaceRef,
) -> Schema {
let int_schema = i32::get_schema_in_ctxt(named_schemas, enclosing_namespace);
let string_schema = String::get_schema_in_ctxt(named_schemas, enclosing_namespace);
Schema::Union(
UnionSchema::new(vec![int_schema, string_schema]).expect("Union must be valid"),
)
}
}

assert_eq!(schema, Option::<MyUnion>::get_schema());

Ok(())
}
}
30 changes: 24 additions & 6 deletions avro/src/serde/deser_schema/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,28 @@ pub struct UnionEnumDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
reader: &'r mut R,
variants: &'s [Schema],
config: Config<'s, S>,
branch_index: Option<usize>,
}

impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionEnumDeserializer<'s, 'r, R, S> {
pub fn new(reader: &'r mut R, schema: &'s UnionSchema, config: Config<'s, S>) -> Self {
pub fn new(
reader: &'r mut R,
schema: &'s UnionSchema,
config: Config<'s, S>,
branch_index: Option<usize>,
) -> Self {
Self {
reader,
variants: schema.variants(),
config,
branch_index,
}
}

fn get_variant_index(&self, branch_index: usize) -> usize {
match self.branch_index {
Some(null_index) if branch_index >= null_index => branch_index - 1,
_ => branch_index,
}
}
}
Expand All @@ -124,20 +138,24 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> EnumAccess<'de>
where
V: DeserializeSeed<'de>,
{
let index = zag_i32(self.reader)?;
let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
let index = match self.branch_index {
Some(index) => index,
None => {
let index = zag_i32(self.reader)?;
usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?
}
};
let schema = self.variants.get(index).ok_or(Details::GetUnionVariant {
index: index as i64,
num_variants: self.variants.len(),
})?;

let variant_index = self.get_variant_index(index);
Ok((
seed.deserialize(IdentifierDeserializer::index(index as u32))?,
seed.deserialize(IdentifierDeserializer::index(variant_index as u32))?,
UnionVariantAccess::new(schema, self.reader, self.config)?,
))
}
}

pub struct UnionVariantAccess<'s, 'r, R: Read, S: Borrow<Schema>> {
schema: &'s Schema,
reader: &'r mut R,
Expand Down
36 changes: 27 additions & 9 deletions avro/src/serde/deser_schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@ mod tuple;

use block::BlockDeserializer;
use enums::PlainEnumDeserializer;
use enums::UnionEnumDeserializer;
use record::RecordDeserializer;
use tuple::{ManyTupleDeserializer, OneTupleDeserializer};

use crate::serde::deser_schema::enums::UnionEnumDeserializer;

/// Configure the deserializer.
#[derive(Debug)]
pub struct Config<'s, S: Borrow<Schema>> {
Expand Down Expand Up @@ -79,6 +78,7 @@ pub struct SchemaAwareDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
/// This schema is guaranteed to not be a [`Schema::Ref`].
schema: &'s Schema,
config: Config<'s, S>,
branch_index: Option<usize>,
}

impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> {
Expand All @@ -96,12 +96,14 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> {
reader,
schema,
config,
branch_index: None,
})
} else {
Ok(Self {
reader,
schema,
config,
branch_index: None,
})
}
}
Expand Down Expand Up @@ -129,12 +131,22 @@ impl<'s, 'r, R: Read, S: Borrow<Schema>> SchemaAwareDeserializer<'s, 'r, R, S> {
Ok(self)
}

fn with_branch_index(mut self, branch_index: usize) -> Self {
self.branch_index = Some(branch_index);
self
}

/// Read the union and create a new deserializer with the existing reader and config.
///
/// This will resolve the read schema if it is a reference.
fn with_union(self, schema: &'s UnionSchema) -> Result<Self, Error> {
let index = zag_i32(self.reader)?;
let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
let index = match self.branch_index {
Some(index) => index,
None => {
let index = zag_i32(self.reader)?;
usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?
}
};
let variant = schema.get_variant(index)?;
self.with_different_schema(variant)
}
Expand Down Expand Up @@ -524,7 +536,6 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
V: Visitor<'de>,
{
if let Schema::Union(union) = self.schema
&& union.variants().len() == 2
&& union.is_nullable()
{
let index = zag_i32(self.reader)?;
Expand All @@ -533,7 +544,11 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
if let Schema::Null = schema {
visitor.visit_none()
} else {
visitor.visit_some(self.with_different_schema(schema)?)
if union.variants().len() == 2 {
visitor.visit_some(self.with_different_schema(schema)?)
} else {
visitor.visit_some(self.with_branch_index(index))
}
}
} else {
Err(self.error("option", "Expected Schema::Union([Schema::Null, _])"))
Expand Down Expand Up @@ -708,9 +723,12 @@ impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> Deserializer<'de>
Schema::Enum(schema) => {
visitor.visit_enum(PlainEnumDeserializer::new(self.reader, schema))
}
Schema::Union(union) => {
visitor.visit_enum(UnionEnumDeserializer::new(self.reader, union, self.config))
}
Schema::Union(union) => visitor.visit_enum(UnionEnumDeserializer::new(
self.reader,
union,
self.config,
self.branch_index,
)),
_ => Err(self.error("enum", "Expected Schema::Enum | Schema::Union")),
}
}
Expand Down
Loading