diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index 7c20508b7a8..bbcfa5942a0 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -6,6 +6,7 @@ mod relation; pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; +pub use pruning_expr::checked_pruning_expr_with_session; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 00d29fbcf99..54c9666c283 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -5,18 +5,36 @@ use std::cell::RefCell; use std::iter; use itertools::Itertools; +use vortex_error::VortexResult; +use vortex_session::VortexSession; use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; +use crate::aggregate_fn::fns::all_nan::AllNan; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; +use crate::aggregate_fn::fns::all_non_null::AllNonNull; +use crate::aggregate_fn::fns::all_null::AllNull; +use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; use crate::expr::StatsCatalog; +use crate::expr::analysis::referenced_field_paths; +use crate::expr::eq; use crate::expr::get_item; +use crate::expr::lit; use crate::expr::root; use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::stat::StatFn; +use crate::scalar_fn::internal::row_count::RowCount; pub type RequiredStats = Relation; @@ -113,6 +131,163 @@ pub fn checked_pruning_expr( Some((expr, relation)) } +/// Build a pruning expression using session-registered stats rewrite rules. +/// +/// The returned expression is lowered to the same stats-table field references as +/// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in +/// `available_stats`, this returns `Ok(None)`. +pub fn checked_pruning_expr_with_session( + expr: &Expression, + scope: &DType, + available_stats: &FieldPathSet, + session: &VortexSession, +) -> VortexResult> { + let Some(predicate) = expr.falsify(scope, session)? else { + return Ok(None); + }; + + lower_stat_fns(predicate, scope, available_stats) +} + +fn lower_stat_fns( + predicate: Expression, + scope: &DType, + available_stats: &FieldPathSet, +) -> VortexResult> { + let mut required_stats = Relation::new(); + let mut missing_stat = false; + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + if let Some(lowered) = + lower_stat_fn(&expr, scope, available_stats, &mut required_stats)? + { + return Ok(Transformed::yes(lowered)); + } + + missing_stat = true; + let dtype = expr.return_dtype(scope)?; + Ok(Transformed::yes(null_expr(dtype))) + })? + .into_inner(); + + if missing_stat { + return Ok(None); + } + + Ok(Some((lowered, required_stats))) +} + +fn lower_stat_fn( + expr: &Expression, + scope: &DType, + available_stats: &FieldPathSet, + required_stats: &mut RequiredStats, +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(scope)?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(false))); + } + return lower_stat_ref( + input, + Stat::NaNCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(true))); + } + return lower_stat_ref( + input, + Stat::NaNCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(Some(lit(0u64))); + } + + if aggregate_fn.is::() { + return lower_stat_ref( + input, + Stat::NullCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + return lower_stat_ref( + input, + Stat::NullCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + + lower_stat_ref(input, stat, scope, available_stats, required_stats) +} + +fn lower_stat_ref( + input: &Expression, + stat: Stat, + scope: &DType, + available_stats: &FieldPathSet, + required_stats: &mut RequiredStats, +) -> VortexResult> { + let field_paths = referenced_field_paths(input, scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + let stat_path = field_path.clone().push(stat.name()); + if !available_stats.contains(&stat_path) { + return Ok(None); + } + + required_stats.insert(field_path.clone(), stat); + Ok(Some(get_item( + field_path_stat_field_name(field_path, stat), + root(), + ))) +} + +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + #[cfg(test)] mod tests { use rstest::fixture; diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index ceb085e0815..3d4cfeb6111 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -21,7 +21,7 @@ pub use stats_set::*; mod array; pub mod expr; pub mod flatbuffers; -pub(crate) mod rewrite; +pub mod rewrite; pub mod session; mod stats_set; diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index bf342a95cdd..c01829f7c90 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -21,7 +21,7 @@ mod builtins; pub(crate) use builtins::register_builtins; /// Shared reference to a stats rewrite rule. -pub(crate) type StatsRewriteRuleRef = Arc; +pub type StatsRewriteRuleRef = Arc; /// A plugin-provided rule that rewrites predicates into stats-backed proof expressions. /// @@ -29,8 +29,7 @@ pub(crate) type StatsRewriteRuleRef = Arc; /// current stats scope. A satisfier evaluates to `true` only when the original predicate is /// definitely true for the current stats scope. Returning `None` means the rule cannot prove /// anything for the expression. -#[allow(dead_code)] -pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { +pub trait StatsRewriteRule: Debug + Send + Sync + 'static { /// The scalar function ID this rule applies to. fn scalar_fn_id(&self) -> ScalarFnId; @@ -58,35 +57,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { } /// Context passed to stats rewrite rules. -pub(crate) struct StatsRewriteCtx<'a> { +pub struct StatsRewriteCtx<'a> { session: &'a VortexSession, scope: &'a DType, } impl<'a> StatsRewriteCtx<'a> { /// Create a rewrite context for `session`. - pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self { + pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self { Self { session, scope } } /// Returns the session that owns the rewrite registry. - pub(crate) fn session(&self) -> &'a VortexSession { + pub fn session(&self) -> &'a VortexSession { self.session } /// Return the dtype of `expr` within this rewrite scope. - pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult { + pub fn return_dtype(&self, expr: &Expression) -> VortexResult { expr.return_dtype(self.scope) } /// Rewrite `expr` into a stats-backed falsifier. - pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult> { + pub fn falsify(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::falsify) } /// Rewrite `expr` into a stats-backed satisfier. - pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult> { + pub fn satisfy(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::satisfy) } diff --git a/vortex-array/src/stats/session.rs b/vortex-array/src/stats/session.rs index 2d4325b2cd7..91eae4a4fa9 100644 --- a/vortex-array/src/stats/session.rs +++ b/vortex-array/src/stats/session.rs @@ -37,14 +37,12 @@ impl Default for StatsSession { impl StatsSession { /// Register a stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite(&self, rule: R) { + pub fn register_rewrite(&self, rule: R) { self.register_rewrite_ref(Arc::new(rule)); } /// Register a shared stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { + pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { let mut rules = self.rewrite_rules.write(); let rule_id = rule.scalar_fn_id(); let mut updated_rules = rules @@ -75,7 +73,7 @@ impl SessionVar for StatsSession { } /// Extension trait for accessing stats session data. -pub(crate) trait StatsSessionExt: SessionExt { +pub trait StatsSessionExt: SessionExt { /// Returns the stats session state. fn stats(&self) -> Ref<'_, StatsSession> { self.get::() diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index ded986f6210..225d18b561a 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr; +use vortex_array::expr::pruning::checked_pruning_expr_with_session; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; @@ -217,7 +217,9 @@ impl VortexFile { }), ); - let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else { + let Some((predicate, required_stats)) = + checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)? + else { return Ok(false); }; diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 0121c12b07d..22c92e817d6 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,22 +10,37 @@ use std::ops::Range; use std::sync::Arc; +use itertools::Itertools; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::all_nan::AllNan; +use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; +use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; +use vortex_array::aggregate_fn::fns::all_null::AllNull; +use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; +use vortex_array::dtype::Nullability; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; +use vortex_array::expr::analysis::referenced_field_paths; +use vortex_array::expr::eq; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; +use vortex_array::expr::traversal::NodeExt; +use vortex_array::expr::traversal::Transformed; use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::EmptyOptions; +use vortex_array::scalar_fn::ScalarFnVTableExt; use vortex_array::scalar_fn::fns::literal::Literal; +use vortex_array::scalar_fn::fns::stat::StatFn; +use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; @@ -83,10 +98,11 @@ impl FileStatsLayoutReader { /// Row-count placeholders are resolved against the full file row count, /// independent of the requested row range. fn evaluate_file_stats(&self, expr: &Expression) -> VortexResult { - let Some(pruning_expr) = expr.stat_falsification(self) else { + let Some(pruning_expr) = expr.falsify(self.child.dtype(), &self.session)? else { // If there is no pruning expression, we can't prune. return Ok(false); }; + let pruning_expr = self.lower_stats(pruning_expr)?; // Given how we implemented the StatsCatalog, we know the expression must be all literals // or row_count placeholders. We can therefore optimize with a null scope since there are @@ -115,11 +131,101 @@ impl FileStatsLayoutReader { Ok(result.as_bool().value() == Some(true)) } + fn lower_stats(&self, predicate: Expression) -> VortexResult { + predicate + .transform_down(|expr| { + if expr.is::() { + return self.lower_stat_fn(expr).map(Transformed::yes); + } + + Ok(Transformed::no(expr)) + }) + .map(Transformed::into_inner) + } + + fn lower_stat_fn(&self, expr: Expression) -> VortexResult { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(self.child.dtype())?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(lit(false)); + } + return Ok(self + .stat_ref(input, Stat::NaNCount)? + .map(|stat| eq(stat, row_count_expr())) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(lit(true)); + } + return Ok(self + .stat_ref(input, Stat::NaNCount)? + .map(|stat| eq(stat, lit(0u64))) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(lit(0u64)); + } + + if aggregate_fn.is::() { + return Ok(self + .stat_ref(input, Stat::NullCount)? + .map(|stat| eq(stat, row_count_expr())) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() { + return Ok(self + .stat_ref(input, Stat::NullCount)? + .map(|stat| eq(stat, lit(0u64))) + .unwrap_or_else(null_bool_expr)); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(null_expr(expr.return_dtype(self.child.dtype())?)); + }; + + let return_dtype = expr.return_dtype(self.child.dtype())?; + Ok(self + .stat_ref(input, stat)? + .unwrap_or_else(|| null_expr(return_dtype))) + } + + fn stat_ref(&self, input: &Expression, stat: Stat) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.child.dtype())?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + Ok(self.stats_ref(field_path, stat)) + } + pub fn file_stats(&self) -> &FileStatistics { &self.file_stats } } +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn null_bool_expr() -> Expression { + null_expr(DType::Bool(Nullability::NonNullable)) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. impl StatsCatalog for FileStatsLayoutReader { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option {