Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vortex-array/src/expr/pruning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod relation;

pub use pruning_expr::RequiredStats;
pub use pruning_expr::checked_pruning_expr;
pub use pruning_expr::checked_pruning_expr_with_session;
pub use pruning_expr::field_path_stat_field_name;
pub use relation::Relation;

Expand Down
175 changes: 175 additions & 0 deletions vortex-array/src/expr/pruning/pruning_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,36 @@ use std::cell::RefCell;
use std::iter;

use itertools::Itertools;
use vortex_error::VortexResult;
use vortex_session::VortexSession;
use vortex_utils::aliases::hash_map::HashMap;

use super::relation::Relation;
use crate::aggregate_fn::fns::all_nan::AllNan;
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
use crate::aggregate_fn::fns::all_non_null::AllNonNull;
use crate::aggregate_fn::fns::all_null::AllNull;
use crate::aggregate_fn::fns::nan_count::NanCount;
use crate::dtype::DType;
use crate::dtype::Field;
use crate::dtype::FieldName;
use crate::dtype::FieldPath;
use crate::dtype::FieldPathSet;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
use crate::expr::analysis::referenced_field_paths;
use crate::expr::eq;
use crate::expr::get_item;
use crate::expr::lit;
use crate::expr::root;
use crate::expr::stats::Stat;
use crate::expr::traversal::NodeExt;
use crate::expr::traversal::Transformed;
use crate::scalar::Scalar;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::ScalarFnVTableExt;
use crate::scalar_fn::fns::stat::StatFn;
use crate::scalar_fn::internal::row_count::RowCount;

pub type RequiredStats = Relation<FieldPath, Stat>;

Expand Down Expand Up @@ -113,6 +131,163 @@ pub fn checked_pruning_expr(
Some((expr, relation))
}

/// Build a pruning expression using session-registered stats rewrite rules.
///
/// `expr` is first rewritten into a falsifier via `Expression::falsify` for the given
/// `scope` and `session`; the falsifier is then lowered so that it refers to the same
/// stats-table fields as [`checked_pruning_expr`].
///
/// # Returns
///
/// * `Ok(Some((predicate, required_stats)))` when a falsifier exists and every stat it asks
///   for could be lowered against `available_stats`.
/// * `Ok(None)` when no falsifier could be produced, or when a rewrite asks for a stat that
///   is not present in `available_stats`.
///
/// # Errors
///
/// Propagates any error raised while falsifying `expr` or while lowering the falsifier.
pub fn checked_pruning_expr_with_session(
    expr: &Expression,
    scope: &DType,
    available_stats: &FieldPathSet,
    session: &VortexSession,
) -> VortexResult<Option<(Expression, RequiredStats)>> {
    match expr.falsify(scope, session)? {
        Some(falsifier) => lower_stat_fns(falsifier, scope, available_stats),
        // Nothing can be proven from stats for this expression.
        None => Ok(None),
    }
}

/// Replace every [`StatFn`] node in `predicate` with its stats-table lowering.
///
/// Nodes that are not a [`StatFn`] are left untouched. Each [`StatFn`] node is handed to
/// `lower_stat_fn`, which also records the stats it relies on in the returned
/// [`RequiredStats`] relation.
///
/// Returns `Ok(None)` if at least one [`StatFn`] node could not be lowered; otherwise returns
/// the rewritten predicate together with the stats it requires.
fn lower_stat_fns(
    predicate: Expression,
    scope: &DType,
    available_stats: &FieldPathSet,
) -> VortexResult<Option<(Expression, RequiredStats)>> {
    let mut required_stats = Relation::new();
    let mut missing_stat = false;

    let lowered = predicate
        .transform_down(|node| {
            if !node.is::<StatFn>() {
                return Ok(Transformed::no(node));
            }

            match lower_stat_fn(&node, scope, available_stats, &mut required_stats)? {
                Some(replacement) => Ok(Transformed::yes(replacement)),
                None => {
                    // Remember the failure and substitute a typed null placeholder so the
                    // traversal can finish; the overall result is discarded below.
                    missing_stat = true;
                    let dtype = node.return_dtype(scope)?;
                    Ok(Transformed::yes(null_expr(dtype)))
                }
            }
        })?
        .into_inner();

    // A single unlowerable stat invalidates the whole pruning predicate.
    Ok((!missing_stat).then_some((lowered, required_stats)))
}

fn lower_stat_fn(
expr: &Expression,
scope: &DType,
available_stats: &FieldPathSet,
required_stats: &mut RequiredStats,
) -> VortexResult<Option<Expression>> {
let options = expr.as_::<StatFn>();
let aggregate_fn = options.aggregate_fn();
let input = expr.child(0);
let input_dtype = input.return_dtype(scope)?;

if aggregate_fn.is::<AllNan>() {
if !has_nans(&input_dtype) {
return Ok(Some(lit(false)));
}
return lower_stat_ref(
input,
Stat::NaNCount,
scope,
available_stats,
required_stats,
)
.map(|stat| stat.map(|stat| eq(stat, row_count_expr())));
}

if aggregate_fn.is::<AllNonNan>() {
if !has_nans(&input_dtype) {
return Ok(Some(lit(true)));
}
return lower_stat_ref(
input,
Stat::NaNCount,
scope,
available_stats,
required_stats,
)
.map(|stat| stat.map(|stat| eq(stat, lit(0u64))));
}

if aggregate_fn.is::<NanCount>() && !has_nans(&input_dtype) {
return Ok(Some(lit(0u64)));
}

if aggregate_fn.is::<AllNull>() {
return lower_stat_ref(
input,
Stat::NullCount,
scope,
available_stats,
required_stats,
)
.map(|stat| stat.map(|stat| eq(stat, row_count_expr())));
}

if aggregate_fn.is::<AllNonNull>() {
return lower_stat_ref(
input,
Stat::NullCount,
scope,
available_stats,
required_stats,
)
.map(|stat| stat.map(|stat| eq(stat, lit(0u64))));
}

let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else {
return Ok(None);
};

lower_stat_ref(input, stat, scope, available_stats, required_stats)
}

/// Resolve `stat` of `input` to a field reference into the stats table.
///
/// `input` must reference exactly one field path within `scope`, and the path formed by
/// appending the stat's name to that field path must be contained in `available_stats`.
/// When both hold, the `(field_path, stat)` pair is recorded in `required_stats` and a
/// `get_item` on the root expression, keyed by `field_path_stat_field_name`, is returned.
/// Otherwise `Ok(None)` is returned and `required_stats` is left unchanged.
///
/// NOTE(review): the lookup keys only on the single referenced field path, so this presumably
/// expects `input` to be a plain field reference rather than a computed expression over one
/// field; confirm that callers guarantee this.
fn lower_stat_ref(
    input: &Expression,
    stat: Stat,
    scope: &DType,
    available_stats: &FieldPathSet,
    required_stats: &mut RequiredStats,
) -> VortexResult<Option<Expression>> {
    let field_paths = referenced_field_paths(input, scope)?;
    let field_path = match field_paths.iter().exactly_one() {
        Ok(only_path) => only_path,
        // Zero or several referenced fields: no single stat column to point at.
        Err(_) => return Ok(None),
    };

    let stat_path = field_path.clone().push(stat.name());
    if available_stats.contains(&stat_path) {
        required_stats.insert(field_path.clone(), stat);
        let stat_field = field_path_stat_field_name(field_path, stat);
        Ok(Some(get_item(stat_field, root())))
    } else {
        Ok(None)
    }
}

/// Build a [`RowCount`] scalar-function expression with empty options and no children.
///
/// Used as the right-hand side when comparing a NaN or null count against the number of
/// rows covered by the stats.
fn row_count_expr() -> Expression {
RowCount.new_expr(EmptyOptions, [])
}

/// Build a literal null expression whose type is the nullable form of `dtype`.
fn null_expr(dtype: DType) -> Expression {
    let nullable = dtype.as_nullable();
    lit(Scalar::null(nullable))
}

/// Report whether values of `dtype` can be NaN, i.e. whether it is a floating-point
/// primitive type. Every other dtype yields `false`.
fn has_nans(dtype: &DType) -> bool {
    match dtype {
        DType::Primitive(ptype, _) => ptype.is_float(),
        _ => false,
    }
}

#[cfg(test)]
mod tests {
use rstest::fixture;
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub use stats_set::*;
mod array;
pub mod expr;
pub mod flatbuffers;
pub(crate) mod rewrite;
pub mod rewrite;
pub mod session;
mod stats_set;

Expand Down
17 changes: 8 additions & 9 deletions vortex-array/src/stats/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@ mod builtins;
pub(crate) use builtins::register_builtins;

/// Shared reference to a stats rewrite rule.
pub(crate) type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;
pub type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;

/// A plugin-provided rule that rewrites predicates into stats-backed proof expressions.
///
/// A falsifier evaluates to `true` only when the original predicate is definitely false for the
/// current stats scope. A satisfier evaluates to `true` only when the original predicate is
/// definitely true for the current stats scope. Returning `None` means the rule cannot prove
/// anything for the expression.
#[allow(dead_code)]
pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static {
pub trait StatsRewriteRule: Debug + Send + Sync + 'static {
/// The scalar function ID this rule applies to.
fn scalar_fn_id(&self) -> ScalarFnId;

Expand Down Expand Up @@ -58,35 +57,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static {
}

/// Context passed to stats rewrite rules.
pub(crate) struct StatsRewriteCtx<'a> {
pub struct StatsRewriteCtx<'a> {
session: &'a VortexSession,
scope: &'a DType,
}

impl<'a> StatsRewriteCtx<'a> {
/// Create a rewrite context for `session`.
pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
Self { session, scope }
}

/// Returns the session that owns the rewrite registry.
pub(crate) fn session(&self) -> &'a VortexSession {
pub fn session(&self) -> &'a VortexSession {
self.session
}

/// Return the dtype of `expr` within this rewrite scope.
pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
pub fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
expr.return_dtype(self.scope)
}

/// Rewrite `expr` into a stats-backed falsifier.
pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
pub fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
self.ensure_predicate(expr)?;
rewrite(expr, self, StatsRewriteRule::falsify)
}

/// Rewrite `expr` into a stats-backed satisfier.
pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
pub fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
self.ensure_predicate(expr)?;
rewrite(expr, self, StatsRewriteRule::satisfy)
}
Expand Down
8 changes: 3 additions & 5 deletions vortex-array/src/stats/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ impl Default for StatsSession {

impl StatsSession {
/// Register a stats rewrite rule.
#[allow(dead_code)]
pub(crate) fn register_rewrite<R: StatsRewriteRule>(&self, rule: R) {
pub fn register_rewrite<R: StatsRewriteRule>(&self, rule: R) {
self.register_rewrite_ref(Arc::new(rule));
}

/// Register a shared stats rewrite rule.
#[allow(dead_code)]
pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) {
pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) {
let mut rules = self.rewrite_rules.write();
let rule_id = rule.scalar_fn_id();
let mut updated_rules = rules
Expand Down Expand Up @@ -75,7 +73,7 @@ impl SessionVar for StatsSession {
}

/// Extension trait for accessing stats session data.
pub(crate) trait StatsSessionExt: SessionExt {
pub trait StatsSessionExt: SessionExt {
/// Returns the stats session state.
fn stats(&self) -> Ref<'_, StatsSession> {
self.get::<StatsSession>()
Expand Down
6 changes: 4 additions & 2 deletions vortex-file/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask;
use vortex_array::dtype::FieldPath;
use vortex_array::dtype::FieldPathSet;
use vortex_array::expr::Expression;
use vortex_array::expr::pruning::checked_pruning_expr;
use vortex_array::expr::pruning::checked_pruning_expr_with_session;
use vortex_array::scalar_fn::internal::row_count::substitute_row_count;
use vortex_error::VortexResult;
use vortex_layout::LayoutReader;
Expand Down Expand Up @@ -217,7 +217,9 @@ impl VortexFile {
}),
);

let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else {
let Some((predicate, required_stats)) =
checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)?
else {
return Ok(false);
};

Expand Down
Loading
Loading