Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions encodings/fsst/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mod tests {
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::FilterArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFn;
use vortex_array::arrays::varbin::builder::VarBinBuilder;
use vortex_array::assert_arrays_eq;
use vortex_array::dtype::DType;
Expand Down Expand Up @@ -230,4 +231,38 @@ mod tests {
assert_arrays_eq!(result, expected);
Ok(())
}

/// The validity of `byte_length(fsst)` is exactly the validity of the FSST input, so computing
/// it must read the FSST validity buffer directly rather than decompressing the codes into a
/// canonical `VarBinView` just to extract the validity mask.
#[test]
fn test_fsst_byte_length_validity() -> VortexResult<()> {
let mut builder = VarBinBuilder::<i32>::with_capacity(5);
builder.append_value(b"hello");
builder.append_null();
builder.append_value("Пуховички");
builder.append_null();
builder.append_value(b"");

let varbin = builder.finish(DType::Utf8(Nullability::Nullable));
let compressor = fsst_train_compressor(&varbin);
let len = varbin.len();
let dtype = varbin.dtype().clone();
let mut ctx = SESSION.create_execution_ctx();
let fsst = fsst_compress(varbin, len, &dtype, &compressor, &mut ctx).into_array();
assert!(fsst.is::<FSST>());

// `apply` keeps `byte_length` lazy as a `ScalarFnArray` wrapping the FSST child, so asking
// for its validity exercises the `ByteLength::validity` rule rather than executing the
// function over canonical data.
let result = fsst.clone().apply(&byte_length(root()))?;
assert!(result.is::<ScalarFn>());

// The result validity matches the FSST's own validity, sourced straight from the array.
let result_validity = result.validity()?.to_array(len);
let fsst_validity = fsst.validity()?.to_array(len);
assert_arrays_eq!(result_validity, fsst_validity);

Ok(())
}
}
80 changes: 77 additions & 3 deletions vortex-array/src/arrays/scalar_fn/vtable/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,22 @@ use crate::arrays::scalar_fn::vtable::FakeEq;
use crate::arrays::scalar_fn::vtable::ScalarFn;
use crate::expr::Expression;
use crate::expr::lit;
use crate::scalar::Scalar;
use crate::scalar_fn::TypedScalarFnInstance;
use crate::scalar_fn::VecExecutionArgs;
use crate::scalar_fn::fns::literal::Literal;
use crate::scalar_fn::fns::root::Root;
use crate::validity::Validity;

/// Execute an expression tree recursively.
/// Execute a validity expression tree recursively.
///
/// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals.
///
/// Evaluation is eager on purpose: it lets a validity expression that is in fact fully valid (or
/// fully invalid) fold down to a constant, which the caller then collapses into the most specific
/// [`Validity`] variant. That keeps `Validity::no_nulls` (and friends) precise. A lazy tree would
/// hide such constants behind an unevaluated `ScalarFn` array, leaving `no_nulls` conservatively
/// false even when there are provably no nulls.
fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult<ArrayRef> {
let mut ctx = LEGACY_SESSION.create_execution_ctx();

Expand All @@ -50,8 +57,29 @@ fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult<ArrayRef> {
Ok(expr.scalar_fn().execute(&args, &mut ctx)?.into_array())
}

/// Collapse a constant boolean validity scalar into the most specific [`Validity`] variant.
///
/// Returns `None` if the scalar is not a definite boolean (e.g. a null), in which case the caller
/// should keep the validity as an array.
fn constant_validity(scalar: &Scalar) -> Option<Validity> {
scalar.as_bool().value().map(|valid| {
if valid {
Validity::AllValid
} else {
Validity::AllInvalid
}
})
}

impl ValidityVTable<ScalarFn> for ScalarFn {
fn validity(array: ArrayView<'_, ScalarFn>) -> VortexResult<Validity> {
// A non-nullable result dtype guarantees there are no nulls, so we can skip building and
// evaluating any validity expression entirely. This also keeps downstream `no_nulls`
// fast-paths intact instead of handing them a constant-true validity array.
if !array.dtype().is_nullable() {
return Ok(Validity::NonNullable);
}

let inputs: Vec<_> = array
.iter_children()
.map(|child| {
Expand All @@ -68,7 +96,53 @@ impl ValidityVTable<ScalarFn> for ScalarFn {
let expr = Expression::try_new(array.scalar_fn().clone(), inputs)?;
let validity_expr = array.scalar_fn().validity(&expr)?;

// Execute the validity expression. All leaves are ArrayExpr nodes.
Ok(Validity::Array(execute_expr(&validity_expr, array.len())?))
// A literal validity expression collapses to a constant validity without evaluating
// anything (e.g. functions whose result is always valid return `lit(true)`).
if let Some(scalar) = validity_expr.as_opt::<Literal>()
&& let Some(validity) = constant_validity(scalar)
{
return Ok(validity);
}

// Otherwise evaluate the validity expression. All leaves are ArrayExpr or Literal nodes.
let validity_array = execute_expr(&validity_expr, array.len())?;

// Collapse a constant result into the most specific variant so that fast-paths keying off
// `Validity::NonNullable | AllValid | AllInvalid` are not defeated by a constant array.
if let Some(scalar) = validity_array.as_constant()
&& let Some(validity) = constant_validity(&scalar)
{
return Ok(validity);
}

Ok(Validity::Array(validity_array))
}
}

#[cfg(test)]
mod tests {
use vortex_error::VortexResult;

use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::arrays::BoolArray;
use crate::arrays::StructArray;
use crate::expr::col;
use crate::expr::eq;

#[test]
fn compound_validity_evaluates_correctly() -> VortexResult<()> {
let a = BoolArray::from_iter([Some(true), None, Some(false)]).into_array();
let b = BoolArray::from_iter([Some(true), Some(true), None]).into_array();
let struct_arr = StructArray::from_fields(&[("a", a), ("b", b)])?.into_array();

// `a == b` has validity `and(valid(a), valid(b))`: null in either operand is null.
let result = struct_arr.apply(&eq(col("a"), col("b")))?;
let mut ctx = LEGACY_SESSION.create_execution_ctx();
assert!(result.is_valid(0, &mut ctx)?);
assert!(!result.is_valid(1, &mut ctx)?);
assert!(!result.is_valid(2, &mut ctx)?);
Ok(())
}
}
56 changes: 56 additions & 0 deletions vortex-array/src/scalar_fn/fns/byte_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::arrays::varbinview::VarBinViewArrayExt;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::expr::Expression;
use crate::kernel::ExecuteParentKernel;
use crate::scalar::Scalar;
use crate::scalar_fn::Arity;
Expand Down Expand Up @@ -122,6 +123,19 @@ impl ScalarFnVTable for ByteLength {
}
}

fn validity(
&self,
_options: &Self::Options,
expression: &Expression,
) -> VortexResult<Option<Expression>> {
// byte_length is null-preserving: the result is null exactly when the input is null, so
// the result's validity is the input's validity. Returning the input's validity directly
// avoids computing any byte lengths just to extract the validity mask. For encodings like
// FSST this means the validity buffer is read straight from the array (via is_not_null)
// instead of decompressing the codes into a canonical VarBinView.
Ok(Some(expression.child(0).validity()?))
}

fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
false
}
Expand Down Expand Up @@ -183,6 +197,7 @@ mod tests {
use crate::expr::byte_length;
use crate::expr::root;
use crate::scalar::Scalar;
use crate::validity::Validity;

#[rstest]
#[case(VarBinArray::from_strs(vec!["hello", "world", ""]).into_array(), vec![5u64, 5, 0])]
Expand Down Expand Up @@ -245,4 +260,45 @@ mod tests {
let expr = byte_length(root());
assert_eq!(expr.to_string(), "vortex.byte_length($)");
}

#[test]
fn test_validity_pushes_down_to_input() -> VortexResult<()> {
// byte_length is null-preserving, so its validity must push down to the input's validity
// rather than wrapping the whole `byte_length` call in `is_not_null`. The latter would
// force the function to be evaluated (and any compressed input decompressed) just to read
// the validity mask.
let validity = byte_length(root()).validity()?;
assert_eq!(validity.to_string(), "is_not_null($)");
Ok(())
}

#[test]
fn test_non_nullable_validity_is_non_nullable() -> VortexResult<()> {
// A non-nullable result dtype short-circuits to `NonNullable` without evaluating any
// validity expression.
let array = VarBinViewArray::from_iter_str(["a", "bb"]).into_array();
let result = array.apply(&byte_length(root()))?;
assert!(matches!(result.validity()?, Validity::NonNullable));
Ok(())
}

#[test]
fn test_nullable_all_valid_validity_collapses_to_all_valid() -> VortexResult<()> {
// A non-null nullable constant is all-valid, which should collapse to `AllValid` rather
// than a constant-true validity array.
let array =
ConstantArray::new(Scalar::utf8("hello", Nullability::Nullable), 3).into_array();
let result = array.apply(&byte_length(root()))?;
assert!(matches!(result.validity()?, Validity::AllValid));
Ok(())
}

#[test]
fn test_null_constant_validity_collapses_to_all_invalid() -> VortexResult<()> {
let null_scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
let array = ConstantArray::new(null_scalar, 2).into_array();
let result = array.apply(&byte_length(root()))?;
assert!(matches!(result.validity()?, Validity::AllInvalid));
Ok(())
}
}
Loading