Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions encodings/datetime-parts/src/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,11 @@ mod test {
&mut ctx,
)?;

assert!(
date_times
.as_array()
.validity()?
.mask_eq(&validity, &mut ctx)?
);
assert!(date_times.as_array().validity()?.mask_eq(
&validity,
milliseconds.len(),
&mut ctx
)?);

let dtype = date_times.dtype().clone();
let parts = DateTimePartsParts {
Expand All @@ -163,7 +162,11 @@ mod test {
.execute::<PrimitiveArray>(&mut ctx)?;

assert_arrays_eq!(primitive_values, milliseconds);
assert!(primitive_values.validity()?.mask_eq(&validity, &mut ctx)?);
assert!(primitive_values.validity()?.mask_eq(
&validity,
primitive_values.len(),
&mut ctx
)?);
Ok(())
}
}
2 changes: 1 addition & 1 deletion encodings/datetime-parts/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ mod tests {
days_prim
.validity()
.vortex_expect("days validity should be derivable")
.mask_eq(&validity, &mut ctx)
.mask_eq(&validity, days_prim.len(), &mut ctx)
.unwrap()
);
let seconds_prim = seconds.execute::<PrimitiveArray>(&mut ctx).unwrap();
Expand Down
1 change: 1 addition & 0 deletions encodings/pco/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ fn test_validity_and_multiple_chunks_and_pages() {
.unwrap()
.mask_eq(
&Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
primitive.len(),
&mut ctx,
)
.unwrap()
Expand Down
3 changes: 2 additions & 1 deletion encodings/zstd/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn test_zstd_with_validity_and_multi_frame() {
decompressed
.validity()
.unwrap()
.mask_eq(&array.validity().unwrap(), &mut ctx)
.mask_eq(&array.validity().unwrap(), decompressed.len(), &mut ctx)
.unwrap()
);

Expand All @@ -106,6 +106,7 @@ fn test_zstd_with_validity_and_multi_frame() {
.unwrap()
.mask_eq(
&Validity::Array(BoolArray::from_iter(vec![false, true, false]).into_array()),
primitive.len(),
&mut ctx
)
.unwrap()
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/masked/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ fn test_masked_child_preserves_length(#[case] validity: Validity) {
array
.validity()
.vortex_expect("masked validity should be derivable")
.mask_eq(&validity, &mut ctx)
.mask_eq(&validity, array.len(), &mut ctx)
.unwrap(),
);
}
10 changes: 5 additions & 5 deletions vortex-array/src/builders/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,11 @@ mod tests {
#[expect(deprecated)]
let into_canon = chunk.to_bool();

assert!(
canon_into
.validity()?
.mask_eq(&into_canon.validity()?, &mut ctx)?
);
assert!(canon_into.validity()?.mask_eq(
&into_canon.validity()?,
canon_into.len(),
&mut ctx
)?);
assert_eq!(canon_into.to_bit_buffer(), into_canon.to_bit_buffer());
Ok(())
}
Expand Down
1 change: 1 addition & 0 deletions vortex-array/src/builders/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ mod tests {
&expected
.validity()
.vortex_expect("list validity should be derivable"),
actual.len(),
&mut ctx,
)
.unwrap(),
Expand Down
75 changes: 62 additions & 13 deletions vortex-array/src/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,25 @@ impl Validity {
}
}

/// Compare two Validity values of the same length by executing them into masks if necessary.
pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
/// Compare the logical masks of two Validity values of the given length, executing them
/// into [`Mask`]s if necessary.
///
/// This compares *masks*, not variants: [`Validity::NonNullable`] equals
/// [`Validity::AllValid`], and a [`Validity::Array`] that resolves to all-true equals both.
pub fn mask_eq(
&self,
other: &Validity,
length: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<bool> {
match (self, other) {
(Validity::NonNullable, Validity::NonNullable) => Ok(true),
(Validity::AllValid, Validity::AllValid) => Ok(true),
(Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
(Validity::Array(a), Validity::Array(b)) => {
let a = a.clone().execute::<Mask>(ctx)?;
let b = b.clone().execute::<Mask>(ctx)?;
Ok(a == b)
}
_ => Ok(false),
// Fast paths that avoid executing: constant variants with known-equal masks.
(
Validity::NonNullable | Validity::AllValid,
Validity::NonNullable | Validity::AllValid,
)
| (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
_ => Ok(self.execute_mask(length, ctx)? == other.execute_mask(length, ctx)?),
}
}

Expand Down Expand Up @@ -703,7 +710,7 @@ mod tests {
validity
.patch(len, 0, &indices, &patches, &mut ctx,)
.unwrap()
.mask_eq(&expected, &mut ctx)
.mask_eq(&expected, len, &mut ctx)
.unwrap()
);
}
Expand Down Expand Up @@ -768,8 +775,50 @@ mod tests {
validity
.take(&indices)
.unwrap()
.mask_eq(&expected, &mut ctx)
.mask_eq(&expected, indices.len(), &mut ctx)
.unwrap()
);
}

#[rstest]
// Mixed constant variants with equal masks.
#[case(Validity::NonNullable, Validity::AllValid, true)]
#[case(Validity::AllValid, Validity::NonNullable, true)]
#[case(Validity::AllValid, Validity::AllInvalid, false)]
#[case(Validity::NonNullable, Validity::AllInvalid, false)]
// An array that resolves to a constant mask must equal the constant variant.
#[case(
Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
Validity::AllValid,
true
)]
#[case(
Validity::NonNullable,
Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
true
)]
#[case(
Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
Validity::AllInvalid,
true
)]
#[case(
Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
Validity::AllValid,
false
)]
#[case(
Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
Validity::AllInvalid,
false
)]
fn mask_eq_mixed_variants(
#[case] lhs: Validity,
#[case] rhs: Validity,
#[case] expected: bool,
) -> vortex_error::VortexResult<()> {
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert_eq!(lhs.mask_eq(&rhs, 3, &mut ctx)?, expected);
Ok(())
}
}
10 changes: 5 additions & 5 deletions vortex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,11 @@ mod test {
let mut ctx = LEGACY_SESSION.create_execution_ctx();

let recovered_primitive = recovered_array.execute::<PrimitiveArray>(&mut ctx)?;
assert!(
recovered_primitive
.validity()?
.mask_eq(&array.validity()?, &mut ctx)?
);
assert!(recovered_primitive.validity()?.mask_eq(
&array.validity()?,
array.len(),
&mut ctx
)?);
assert_eq!(
recovered_primitive.to_buffer::<u64>(),
array.to_buffer::<u64>()
Expand Down
Loading