diff --git a/encodings/fsst/src/kernel.rs b/encodings/fsst/src/kernel.rs
index 386da3b0f60..80ef8a1d282 100644
--- a/encodings/fsst/src/kernel.rs
+++ b/encodings/fsst/src/kernel.rs
@@ -30,6 +30,7 @@ mod tests {
     use vortex_array::VortexSessionExecute;
     use vortex_array::arrays::FilterArray;
     use vortex_array::arrays::PrimitiveArray;
+    use vortex_array::arrays::ScalarFn;
     use vortex_array::arrays::varbin::builder::VarBinBuilder;
     use vortex_array::assert_arrays_eq;
     use vortex_array::dtype::DType;
@@ -230,4 +231,38 @@ mod tests {
         assert_arrays_eq!(result, expected);
         Ok(())
     }
+
+    /// The validity of `byte_length(fsst)` is exactly the validity of the FSST input, so computing
+    /// it must read the FSST validity buffer directly rather than decompressing the codes into a
+    /// canonical `VarBinView` just to extract the validity mask.
+    #[test]
+    fn test_fsst_byte_length_validity() -> VortexResult<()> {
+        let mut builder = VarBinBuilder::<i32>::with_capacity(5);
+        builder.append_value(b"hello");
+        builder.append_null();
+        builder.append_value("Пуховички");
+        builder.append_null();
+        builder.append_value(b"");
+
+        let varbin = builder.finish(DType::Utf8(Nullability::Nullable));
+        let compressor = fsst_train_compressor(&varbin);
+        let len = varbin.len();
+        let dtype = varbin.dtype().clone();
+        let mut ctx = SESSION.create_execution_ctx();
+        let fsst = fsst_compress(varbin, len, &dtype, &compressor, &mut ctx).into_array();
+        assert!(fsst.is::<FSST>());
+
+        // `apply` keeps `byte_length` lazy as a `ScalarFnArray` wrapping the FSST child, so asking
+        // for its validity exercises the `ByteLength::validity` rule rather than executing the
+        // function over canonical data.
+        let result = fsst.clone().apply(&byte_length(root()))?;
+        assert!(result.is::<ScalarFn>());
+
+        // The result validity matches the FSST's own validity, sourced straight from the array.
+        let result_validity = result.validity()?.to_array(len);
+        let fsst_validity = fsst.validity()?.to_array(len);
+        assert_arrays_eq!(result_validity, fsst_validity);
+
+        Ok(())
+    }
 }
diff --git a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs
index 2ac376155e3..df72e6a6793 100644
--- a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs
+++ b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs
@@ -15,15 +15,22 @@ use crate::arrays::scalar_fn::vtable::FakeEq;
 use crate::arrays::scalar_fn::vtable::ScalarFn;
 use crate::expr::Expression;
 use crate::expr::lit;
+use crate::scalar::Scalar;
 use crate::scalar_fn::TypedScalarFnInstance;
 use crate::scalar_fn::VecExecutionArgs;
 use crate::scalar_fn::fns::literal::Literal;
 use crate::scalar_fn::fns::root::Root;
 use crate::validity::Validity;
 
-/// Execute an expression tree recursively.
+/// Execute a validity expression tree recursively.
 ///
 /// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals.
+///
+/// Evaluation is eager on purpose: it lets a validity expression that is in fact fully valid (or
+/// fully invalid) fold down to a constant, which the caller then collapses into the most specific
+/// [`Validity`] variant. That keeps `Validity::no_nulls` (and friends) precise. A lazy tree would
+/// hide such constants behind an unevaluated `ScalarFn` array, leaving `no_nulls` conservatively
+/// false even when there are provably no nulls.
 fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult<ArrayRef> {
     let mut ctx = LEGACY_SESSION.create_execution_ctx();
 
@@ -50,8 +57,29 @@ fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult<ArrayRef> {
     Ok(expr.scalar_fn().execute(&args, &mut ctx)?.into_array())
 }
 
+/// Collapse a constant boolean validity scalar into the most specific [`Validity`] variant.
+///
+/// Returns `None` if the scalar is not a definite boolean (e.g. a null), in which case the caller
+/// should keep the validity as an array.
+fn constant_validity(scalar: &Scalar) -> Option<Validity> {
+    scalar.as_bool().value().map(|valid| {
+        if valid {
+            Validity::AllValid
+        } else {
+            Validity::AllInvalid
+        }
+    })
+}
+
 impl ValidityVTable<ScalarFn> for ScalarFn {
     fn validity(array: ArrayView<'_, ScalarFn>) -> VortexResult<Validity> {
+        // A non-nullable result dtype guarantees there are no nulls, so we can skip building and
+        // evaluating any validity expression entirely. This also keeps downstream `no_nulls`
+        // fast-paths intact instead of handing them a constant-true validity array.
+        if !array.dtype().is_nullable() {
+            return Ok(Validity::NonNullable);
+        }
+
         let inputs: Vec<_> = array
             .iter_children()
             .map(|child| {
@@ -68,7 +96,53 @@ impl ValidityVTable<ScalarFn> for ScalarFn {
         let expr = Expression::try_new(array.scalar_fn().clone(), inputs)?;
         let validity_expr = array.scalar_fn().validity(&expr)?;
 
-        // Execute the validity expression. All leaves are ArrayExpr nodes.
-        Ok(Validity::Array(execute_expr(&validity_expr, array.len())?))
+        // A literal validity expression collapses to a constant validity without evaluating
+        // anything (e.g. functions whose result is always valid return `lit(true)`).
+        if let Some(scalar) = validity_expr.as_opt::<Literal>()
+            && let Some(validity) = constant_validity(scalar)
+        {
+            return Ok(validity);
+        }
+
+        // Otherwise evaluate the validity expression. All leaves are ArrayExpr or Literal nodes.
+        let validity_array = execute_expr(&validity_expr, array.len())?;
+
+        // Collapse a constant result into the most specific variant so that fast-paths keying off
+        // `Validity::NonNullable | AllValid | AllInvalid` are not defeated by a constant array.
+        if let Some(scalar) = validity_array.as_constant()
+            && let Some(validity) = constant_validity(&scalar)
+        {
+            return Ok(validity);
+        }
+
+        Ok(Validity::Array(validity_array))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex_error::VortexResult;
+
+    use crate::IntoArray;
+    use crate::LEGACY_SESSION;
+    use crate::VortexSessionExecute;
+    use crate::arrays::BoolArray;
+    use crate::arrays::StructArray;
+    use crate::expr::col;
+    use crate::expr::eq;
+
+    #[test]
+    fn compound_validity_evaluates_correctly() -> VortexResult<()> {
+        let a = BoolArray::from_iter([Some(true), None, Some(false)]).into_array();
+        let b = BoolArray::from_iter([Some(true), Some(true), None]).into_array();
+        let struct_arr = StructArray::from_fields(&[("a", a), ("b", b)])?.into_array();
+
+        // `a == b` has validity `and(valid(a), valid(b))`: null in either operand is null.
+        let result = struct_arr.apply(&eq(col("a"), col("b")))?;
+        let mut ctx = LEGACY_SESSION.create_execution_ctx();
+        assert!(result.is_valid(0, &mut ctx)?);
+        assert!(!result.is_valid(1, &mut ctx)?);
+        assert!(!result.is_valid(2, &mut ctx)?);
+        Ok(())
     }
 }
diff --git a/vortex-array/src/scalar_fn/fns/byte_length.rs b/vortex-array/src/scalar_fn/fns/byte_length.rs
index 13a4f3158b5..72c0bcb47be 100644
--- a/vortex-array/src/scalar_fn/fns/byte_length.rs
+++ b/vortex-array/src/scalar_fn/fns/byte_length.rs
@@ -24,6 +24,7 @@ use crate::arrays::varbinview::VarBinViewArrayExt;
 use crate::dtype::DType;
 use crate::dtype::Nullability;
 use crate::dtype::PType;
+use crate::expr::Expression;
 use crate::kernel::ExecuteParentKernel;
 use crate::scalar::Scalar;
 use crate::scalar_fn::Arity;
@@ -122,6 +123,19 @@ impl ScalarFnVTable for ByteLength {
         }
     }
 
+    fn validity(
+        &self,
+        _options: &Self::Options,
+        expression: &Expression,
+    ) -> VortexResult<Option<Expression>> {
+        // byte_length is null-preserving: the result is null exactly when the input is null, so
+        // the result's validity is the input's validity. Returning the input's validity directly
+        // avoids computing any byte lengths just to extract the validity mask. For encodings like
+        // FSST this means the validity buffer is read straight from the array (via is_not_null)
+        // instead of decompressing the codes into a canonical VarBinView.
+        Ok(Some(expression.child(0).validity()?))
+    }
+
     fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
         false
     }
@@ -183,6 +197,7 @@ mod tests {
     use crate::expr::byte_length;
     use crate::expr::root;
     use crate::scalar::Scalar;
+    use crate::validity::Validity;
 
     #[rstest]
     #[case(VarBinArray::from_strs(vec!["hello", "world", ""]).into_array(), vec![5u64, 5, 0])]
@@ -245,4 +260,45 @@ mod tests {
         let expr = byte_length(root());
         assert_eq!(expr.to_string(), "vortex.byte_length($)");
     }
+
+    #[test]
+    fn test_validity_pushes_down_to_input() -> VortexResult<()> {
+        // byte_length is null-preserving, so its validity must push down to the input's validity
+        // rather than wrapping the whole `byte_length` call in `is_not_null`. The latter would
+        // force the function to be evaluated (and any compressed input decompressed) just to read
+        // the validity mask.
+        let validity = byte_length(root()).validity()?;
+        assert_eq!(validity.to_string(), "is_not_null($)");
+        Ok(())
+    }
+
+    #[test]
+    fn test_non_nullable_validity_is_non_nullable() -> VortexResult<()> {
+        // A non-nullable result dtype short-circuits to `NonNullable` without evaluating any
+        // validity expression.
+        let array = VarBinViewArray::from_iter_str(["a", "bb"]).into_array();
+        let result = array.apply(&byte_length(root()))?;
+        assert!(matches!(result.validity()?, Validity::NonNullable));
+        Ok(())
+    }
+
+    #[test]
+    fn test_nullable_all_valid_validity_collapses_to_all_valid() -> VortexResult<()> {
+        // A non-null nullable constant is all-valid, which should collapse to `AllValid` rather
+        // than a constant-true validity array.
+        let array =
+            ConstantArray::new(Scalar::utf8("hello", Nullability::Nullable), 3).into_array();
+        let result = array.apply(&byte_length(root()))?;
+        assert!(matches!(result.validity()?, Validity::AllValid));
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_constant_validity_collapses_to_all_invalid() -> VortexResult<()> {
+        let null_scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
+        let array = ConstantArray::new(null_scalar, 2).into_array();
+        let result = array.apply(&byte_length(root()))?;
+        assert!(matches!(result.validity()?, Validity::AllInvalid));
+        Ok(())
+    }
 }