From e64f2a2a1658de94089946268ee6ef7abfb05a14 Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Wed, 10 Jun 2026 18:02:53 +0100 Subject: [PATCH] Pushdown some expressions to Dict layout reader Signed-off-by: Mikhail Kot --- vortex-array/src/scalar_fn/mod.rs | 15 +++ vortex-layout/src/layouts/dict/reader.rs | 137 +++++++++++++++++++++-- 2 files changed, 145 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/scalar_fn/mod.rs b/vortex-array/src/scalar_fn/mod.rs index 590ccb44224..710d3a95b48 100644 --- a/vortex-array/src/scalar_fn/mod.rs +++ b/vortex-array/src/scalar_fn/mod.rs @@ -48,3 +48,18 @@ mod sealed { /// This can be the **only** implementor for [`super::typed::DynScalarFn`]. impl Sealed for TypedScalarFnInstance {} } + +/* + * A scalar function has a negative cost if applying it to an array and + * canonicalizing is cheaper than canonicalizing an array and applying it. + * + * Examples of negative cost expressions are byte_length() and get_item() since + * they don't depend on input size. 
+ * + * An example of a non-negative cost expression is like() + */ +pub fn is_negative_cost(id: ScalarFnId) -> bool { + id == Id::new_static("vortex.byte_length") + || id == Id::new_static("vortex.get_item") + || id == Id::new_static("vortex.literal") +} diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 002b4b1e902..d64a2e47d59 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -19,8 +19,15 @@ use vortex_array::arrays::SharedArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::expr::Expression; +use vortex_array::expr::is_root; +use vortex_array::expr::label_is_fallible; +use vortex_array::expr::label_null_sensitive; use vortex_array::expr::root; +use vortex_array::expr::traversal::NodeExt; +use vortex_array::expr::traversal::Transformed; +use vortex_array::expr::traversal::TraversalOrder; use vortex_array::optimizer::ArrayOptimizer; +use vortex_array::scalar_fn::is_negative_cost; use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -100,10 +107,7 @@ impl DictReader { ) .vortex_expect("must construct dict values array evaluation") .map_err(Arc::new) - .map(move |array| { - let array = array?; - Ok(SharedArray::new(array).into_array()) - }) + .map(move |array| Ok(SharedArray::new(array?).into_array())) .boxed() .shared() }) @@ -155,6 +159,49 @@ impl DictReader { } } +fn references_root(expr: &Expression) -> bool { + is_root(expr) || expr.children().iter().any(references_root) +} + +/// Split expression into two parts: +/// +/// The first element is the optional outer part that we want to apply to the +/// array after canonicalizing. +/// The second element is the optional inner part that we want to apply to the +/// array before canonicalizing. +/// +/// We want to push to array only if expression has a negative cost, is +/// infallible and null-insensitive. 
+fn split_expression_for_pushdown(expr: Expression) -> (Option, Option) { + let labelled_expr = expr.clone(); + let fallible = label_is_fallible(&labelled_expr); + let null_sensitive = label_null_sensitive(&labelled_expr); + let mut inner: Option = None; + + let outer = expr + .transform_down(|node| { + if is_negative_cost(node.id()) + && references_root(&node) + && !fallible.get(&node).copied().unwrap_or(true) + && !null_sensitive.get(&node).copied().unwrap_or(true) + { + inner = Some(node); + Ok(Transformed { + value: root(), + changed: true, + order: TraversalOrder::Skip, + }) + } else { + Ok(Transformed::no(node)) + } + }) + .vortex_expect("infallible") + .into_inner(); + + let outer = (!is_root(&outer)).then_some(outer); + (outer, inner) +} + impl LayoutReader for DictReader { fn name(&self) -> &Arc { &self.name @@ -229,13 +276,18 @@ impl LayoutReader for DictReader { mask: MaskFuture, ) -> VortexResult>> { // TODO: fix up expr partitioning with fallible & null sensitive annotations - let values_eval = self.values_array(); let codes_eval = self .codes .projection_evaluation(row_range, &root(), mask) .map_err(|err| err.with_context("While evaluating projection on codes"))?; - let expr = expr.clone(); + let (expr_outer, expr_inner) = split_expression_for_pushdown(expr.clone()); + + let values_eval = if let Some(inner) = expr_inner { + self.values_eval(inner) + } else { + self.values_array() + }; let all_values_referenced = self.layout.has_all_values_referenced(); Ok(async move { let (values, codes) = try_join!(values_eval.map_err(VortexError::from), codes_eval)?; @@ -252,7 +304,11 @@ impl LayoutReader for DictReader { .into_array() .optimize()?; - array.apply(&expr) + if let Some(expr) = expr_outer { + array.apply(&expr) + } else { + Ok(array) + } } .boxed()) } @@ -281,11 +337,20 @@ mod tests { use vortex_array::dtype::FieldName; use vortex_array::dtype::FieldNames; use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use 
vortex_array::expr::Expression; + use vortex_array::expr::byte_length; + use vortex_array::expr::cast; use vortex_array::expr::eq; use vortex_array::expr::is_not_null; + use vortex_array::expr::is_root; + use vortex_array::expr::like; use vortex_array::expr::lit; use vortex_array::expr::pack; use vortex_array::expr::root; + use vortex_array::expr::traversal::NodeExt; + use vortex_array::expr::traversal::Transformed; + use vortex_array::expr::traversal::TraversalOrder; use vortex_array::scalar_fn::session::ScalarFnSession; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; @@ -296,6 +361,7 @@ mod tests { use vortex_io::session::RuntimeSessionExt; use vortex_session::VortexSession; + use super::split_expression_for_pushdown; use crate::LayoutId; use crate::LayoutRef; use crate::LayoutStrategy; @@ -542,4 +608,61 @@ mod tests { assert_arrays_eq!(actual_canonical, expected); }) } + + fn join_split_expr(initial: &Expression, outer: Option, inner: Option) { + let outer_expr = outer.unwrap_or_else(root); + let inner_expr = inner.unwrap_or_else(root); + let expected = outer_expr + .transform_down(|node| { + if !is_root(&node) { + return Ok(Transformed::no(node)); + } + Ok(Transformed { + value: inner_expr.clone(), + changed: true, + order: TraversalOrder::Skip, + }) + }) + .vortex_expect("infallible"); + assert_eq!(&expected.into_inner(), initial); + } + + #[test] + fn split_expr_cast_root() { + let (outer, inner) = split_expression_for_pushdown(root()); + assert_eq!(outer, None); + assert_eq!(inner, None); // Applying root to array is useless work + } + + #[test] + fn split_expr_partial_pushdown() { + let dtype = DType::Primitive(PType::U64, Nullability::NonNullable); + let expr = cast(byte_length(root()), dtype.clone()); + let (outer, inner) = split_expression_for_pushdown(expr.clone()); + // [0] = cast([1], dtype) + // [1] = byte_length(root) + assert_eq!(outer, Some(cast(root(), dtype))); + assert_eq!(inner, Some(byte_length(root()))); + 
join_split_expr(&expr, outer, inner); + } + + #[test] + fn split_expr_full_pushdown() { + let expr = byte_length(root()); + let (outer, inner) = split_expression_for_pushdown(expr.clone()); + assert_eq!(outer, None); + assert_eq!(inner, Some(byte_length(root()))); + join_split_expr(&expr, outer, inner); + } + + #[test] + fn split_expr_no_pushdown() { + // We can push down lit(), but if we replace + // lit() with root(), the semantics change. + let expr = like(root(), lit(1u64)); + let (outer, inner) = split_expression_for_pushdown(expr.clone()); + assert_eq!(outer, Some(expr.clone())); + assert_eq!(inner, None); + join_split_expr(&expr, outer, inner); + } }