From 4813becca03150ba6d0a649c38d6e69cf9068634 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Wed, 10 Jun 2026 12:50:22 +0000 Subject: [PATCH] Fix Validity::mask_eq semantics for mixed variants mask_eq previously returned false for any mixed-variant pairing without executing, e.g. comparing a Validity::Array that resolves to all-true against Validity::AllValid. As validity arrays become lazy, unresolved Array variants frequently hold constant masks, making this silently wrong rather than merely conservative. mask_eq now takes the validity length and compares logical masks: constant variants keep their no-execute fast paths, and any pairing involving an Array executes both sides via execute_mask and compares the resulting Masks. NonNullable and AllValid now compare equal, since their logical masks are identical. https://claude.ai/code/session_01VPQ7dfZtijfrsjAipwXvEj Signed-off-by: Joe Isaacs --- encodings/datetime-parts/src/canonical.rs | 17 ++--- encodings/datetime-parts/src/compress.rs | 2 +- encodings/pco/src/tests.rs | 1 + encodings/zstd/src/test.rs | 3 +- vortex-array/src/arrays/masked/tests.rs | 2 +- vortex-array/src/builders/bool.rs | 10 +-- vortex-array/src/builders/list.rs | 1 + vortex-array/src/validity.rs | 75 +++++++++++++++++++---- vortex/src/lib.rs | 10 +-- 9 files changed, 88 insertions(+), 33 deletions(-) diff --git a/encodings/datetime-parts/src/canonical.rs b/encodings/datetime-parts/src/canonical.rs index 78dc8e689c1..91211f84acc 100644 --- a/encodings/datetime-parts/src/canonical.rs +++ b/encodings/datetime-parts/src/canonical.rs @@ -143,12 +143,11 @@ mod test { &mut ctx, )?; - assert!( - date_times - .as_array() - .validity()? - .mask_eq(&validity, &mut ctx)? - ); + assert!(date_times.as_array().validity()?.mask_eq( + &validity, + milliseconds.len(), + &mut ctx + )?); let dtype = date_times.dtype().clone(); let parts = DateTimePartsParts { @@ -163,7 +162,11 @@ mod test { .execute::(&mut ctx)?; assert_arrays_eq!(primitive_values, milliseconds); - assert!(primitive_values.validity()?.mask_eq(&validity, &mut ctx)?); + assert!(primitive_values.validity()?.mask_eq( + &validity, + primitive_values.len(), + &mut ctx + )?); Ok(()) } } diff --git a/encodings/datetime-parts/src/compress.rs b/encodings/datetime-parts/src/compress.rs index 676e7bbfd43..6caf3c9f643 100644 --- a/encodings/datetime-parts/src/compress.rs +++ b/encodings/datetime-parts/src/compress.rs @@ -103,7 +103,7 @@ mod tests { days_prim .validity() .vortex_expect("days validity should be derivable") - .mask_eq(&validity, &mut ctx) + .mask_eq(&validity, days_prim.len(), &mut ctx) .unwrap() ); let seconds_prim = seconds.execute::(&mut ctx).unwrap(); diff --git a/encodings/pco/src/tests.rs b/encodings/pco/src/tests.rs index 6e5b4f9841b..20fec7dd9e0 100644 --- a/encodings/pco/src/tests.rs +++ b/encodings/pco/src/tests.rs @@ -149,6 +149,7 @@ fn test_validity_and_multiple_chunks_and_pages() { .unwrap() .mask_eq( &Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()), + primitive.len(), &mut ctx, ) .unwrap() diff --git a/encodings/zstd/src/test.rs b/encodings/zstd/src/test.rs index 7ed22886b82..094c5ef60cc 100644 --- a/encodings/zstd/src/test.rs +++ b/encodings/zstd/src/test.rs @@ -89,7 +89,7 @@ fn test_zstd_with_validity_and_multi_frame() { decompressed .validity() .unwrap() - .mask_eq(&array.validity().unwrap(), &mut ctx) + .mask_eq(&array.validity().unwrap(), decompressed.len(), &mut ctx) .unwrap() ); @@ -106,6 +106,7 @@ fn test_zstd_with_validity_and_multi_frame() { .unwrap() .mask_eq( &Validity::Array(BoolArray::from_iter(vec![false, true, false]).into_array()), + primitive.len(), &mut ctx ) .unwrap() diff --git a/vortex-array/src/arrays/masked/tests.rs b/vortex-array/src/arrays/masked/tests.rs index 2721ba9b519..b26d101eb33 100644 --- a/vortex-array/src/arrays/masked/tests.rs +++ b/vortex-array/src/arrays/masked/tests.rs @@ -134,7 +134,7 @@ fn test_masked_child_preserves_length(#[case] validity: Validity) { array .validity() .vortex_expect("masked validity should be derivable") - .mask_eq(&validity, &mut ctx) + .mask_eq(&validity, array.len(), &mut ctx) .unwrap(), ); } diff --git a/vortex-array/src/builders/bool.rs b/vortex-array/src/builders/bool.rs index fdae7984844..e829a58f6ec 100644 --- a/vortex-array/src/builders/bool.rs +++ b/vortex-array/src/builders/bool.rs @@ -209,11 +209,11 @@ mod tests { #[expect(deprecated)] let into_canon = chunk.to_bool(); - assert!( - canon_into - .validity()? - .mask_eq(&into_canon.validity()?, &mut ctx)? - ); + assert!(canon_into.validity()?.mask_eq( + &into_canon.validity()?, + canon_into.len(), + &mut ctx + )?); assert_eq!(canon_into.to_bit_buffer(), into_canon.to_bit_buffer()); Ok(()) } diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs index c7b506f228a..ac97230daa8 100644 --- a/vortex-array/src/builders/list.rs +++ b/vortex-array/src/builders/list.rs @@ -490,6 +490,7 @@ mod tests { &expected .validity() .vortex_expect("list validity should be derivable"), + actual.len(), &mut ctx, ) .unwrap(), diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index f3a77b4759e..73b5b08d5cb 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -263,18 +263,25 @@ impl Validity { } } - /// Compare two Validity values of the same length by executing them into masks if necessary. - pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult { + /// Compare the logical masks of two Validity values of the given length, executing them + /// into [`Mask`]s if necessary. + /// + /// This compares *masks*, not variants: [`Validity::NonNullable`] equals + /// [`Validity::AllValid`], and a [`Validity::Array`] that resolves to all-true equals both. + pub fn mask_eq( + &self, + other: &Validity, + length: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { match (self, other) { - (Validity::NonNullable, Validity::NonNullable) => Ok(true), - (Validity::AllValid, Validity::AllValid) => Ok(true), - (Validity::AllInvalid, Validity::AllInvalid) => Ok(true), - (Validity::Array(a), Validity::Array(b)) => { - let a = a.clone().execute::(ctx)?; - let b = b.clone().execute::(ctx)?; - Ok(a == b) - } - _ => Ok(false), + // Fast paths that avoid executing: constant variants with known-equal masks. + ( + Validity::NonNullable | Validity::AllValid, + Validity::NonNullable | Validity::AllValid, + ) + | (Validity::AllInvalid, Validity::AllInvalid) => Ok(true), + _ => Ok(self.execute_mask(length, ctx)? == other.execute_mask(length, ctx)?), } } @@ -703,7 +710,7 @@ mod tests { validity .patch(len, 0, &indices, &patches, &mut ctx,) .unwrap() - .mask_eq(&expected, &mut ctx) + .mask_eq(&expected, len, &mut ctx) .unwrap() ); } @@ -768,8 +775,50 @@ mod tests { validity .take(&indices) .unwrap() - .mask_eq(&expected, &mut ctx) + .mask_eq(&expected, indices.len(), &mut ctx) .unwrap() ); } + + #[rstest] + // Mixed constant variants with equal masks. + #[case(Validity::NonNullable, Validity::AllValid, true)] + #[case(Validity::AllValid, Validity::NonNullable, true)] + #[case(Validity::AllValid, Validity::AllInvalid, false)] + #[case(Validity::NonNullable, Validity::AllInvalid, false)] + // An array that resolves to a constant mask must equal the constant variant. + #[case( + Validity::Array(BoolArray::from_iter([true, true, true]).into_array()), + Validity::AllValid, + true + )] + #[case( + Validity::NonNullable, + Validity::Array(BoolArray::from_iter([true, true, true]).into_array()), + true + )] + #[case( + Validity::Array(BoolArray::from_iter([false, false, false]).into_array()), + Validity::AllInvalid, + true + )] + #[case( + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + Validity::AllValid, + false + )] + #[case( + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + Validity::AllInvalid, + false + )] + fn mask_eq_mixed_variants( + #[case] lhs: Validity, + #[case] rhs: Validity, + #[case] expected: bool, + ) -> vortex_error::VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(lhs.mask_eq(&rhs, 3, &mut ctx)?, expected); + Ok(()) + } } diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index 3f7016be38f..baf0d0ae761 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -345,11 +345,11 @@ mod test { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let recovered_primitive = recovered_array.execute::(&mut ctx)?; - assert!( - recovered_primitive - .validity()? - .mask_eq(&array.validity()?, &mut ctx)? - ); + assert!(recovered_primitive.validity()?.mask_eq( + &array.validity()?, + array.len(), + &mut ctx + )?); assert_eq!( recovered_primitive.to_buffer::(), array.to_buffer::()