Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 80 additions & 47 deletions vortex-array/src/stats/rewrite/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ impl StatsRewriteRule for BinaryStatsRewrite {

Ok(match operator {
Operator::Eq => {
let left = min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b));
let right = min(rhs).zip(max(lhs)).map(|(a, b)| gt(a, b));
let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b));
let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b));
or_collect(left.into_iter().chain(right))
.map(|value_predicate| with_nan_predicate(ctx, lhs, rhs, value_predicate))
.transpose()?
}
Operator::NotEq => min(lhs)
.zip(max(rhs))
.zip(max(lhs).zip(min(rhs)))
Operator::NotEq => min(lhs, ctx)
.zip(max(rhs, ctx))
.zip(max(lhs, ctx).zip(min(rhs, ctx)))
.map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| {
with_nan_predicate(
ctx,
Expand All @@ -102,20 +102,20 @@ impl StatsRewriteRule for BinaryStatsRewrite {
)
})
.transpose()?,
Operator::Gt => max(lhs)
.zip(min(rhs))
Operator::Gt => max(lhs, ctx)
.zip(min(rhs, ctx))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt_eq(a, b)))
.transpose()?,
Operator::Gte => max(lhs)
.zip(min(rhs))
Operator::Gte => max(lhs, ctx)
.zip(min(rhs, ctx))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt(a, b)))
.transpose()?,
Operator::Lt => min(lhs)
.zip(max(rhs))
Operator::Lt => min(lhs, ctx)
.zip(max(rhs, ctx))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt_eq(a, b)))
.transpose()?,
Operator::Lte => min(lhs)
.zip(max(rhs))
Operator::Lte => min(lhs, ctx)
.zip(max(rhs, ctx))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt(a, b)))
.transpose()?,
Operator::And => {
Expand Down Expand Up @@ -167,17 +167,17 @@ impl StatsRewriteRule for IsNullLegacyStatsRewrite {
fn falsify(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
Ok(null_count(expr.child(0), ctx).map(|null_count| eq(null_count, lit(0u64))))
}

fn satisfy(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0))
Ok(null_count(expr.child(0), ctx)
.map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, []))))
}
}
Expand Down Expand Up @@ -227,18 +227,18 @@ impl StatsRewriteRule for IsNotNullLegacyStatsRewrite {
fn falsify(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0))
Ok(null_count(expr.child(0), ctx)
.map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, []))))
}

fn satisfy(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
Ok(null_count(expr.child(0), ctx).map(|null_count| eq(null_count, lit(0u64))))
}
}

Expand Down Expand Up @@ -287,7 +287,7 @@ impl StatsRewriteRule for LikeStatsRewrite {
fn falsify(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
let like_options = expr.as_::<Like>();
if like_options.negated || like_options.case_insensitive {
Expand All @@ -304,8 +304,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
let source = expr.child(0);
Ok(match LikeVariant::from_str(pattern) {
Some(LikeVariant::Exact(text)) => {
min(source)
.zip(max(source))
min(source, ctx)
.zip(max(source, ctx))
.map(|(source_min, source_max)| {
or(
gt(source_min, lit(text.as_ref())),
Expand All @@ -317,8 +317,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
let Some(successor) = prefix.to_string().increment().ok() else {
return Ok(None);
};
min(source)
.zip(max(source))
min(source, ctx)
.zip(max(source, ctx))
.map(|(source_min, source_max)| {
or(
gt_eq(source_min, lit(successor)),
Expand Down Expand Up @@ -361,10 +361,10 @@ impl StatsRewriteRule for ListContainsStatsRewrite {
return Ok(Some(lit(true)));
}

let Some(value_max) = max(needle) else {
let Some(value_max) = max(needle, ctx) else {
return Ok(None);
};
let Some(value_min) = min(needle) else {
let Some(value_min) = min(needle, ctx) else {
return Ok(None);
};

Expand Down Expand Up @@ -398,10 +398,10 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {

let Some((operator, lhs_stat)) = (match dynamic.operator {
CompareOperator::Eq | CompareOperator::NotEq => None,
CompareOperator::Gt => max(lhs).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)),
CompareOperator::Gte => max(lhs).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)),
CompareOperator::Lt => min(lhs).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)),
CompareOperator::Lte => min(lhs).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)),
CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)),
CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)),
CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)),
CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)),
}) else {
return Ok(None);
};
Expand All @@ -418,16 +418,16 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {
}
}

fn min(expr: &Expression) -> Option<Expression> {
stat_expr(expr, Stat::Min)
fn min(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
stat_expr(expr, Stat::Min, ctx)
}

fn max(expr: &Expression) -> Option<Expression> {
stat_expr(expr, Stat::Max)
fn max(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
stat_expr(expr, Stat::Max, ctx)
}

fn null_count(expr: &Expression) -> Option<Expression> {
stat_expr(expr, Stat::NullCount)
fn null_count(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
stat_expr(expr, Stat::NullCount, ctx)
}

fn all_null(expr: &Expression) -> Expression {
Expand Down Expand Up @@ -474,7 +474,7 @@ fn has_nans(dtype: &DType) -> bool {
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
}

fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
if let Some(literal) = literal_stat(expr, stat) {
return Some(literal);
}
Expand All @@ -487,11 +487,18 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
}

if let Some(dtype) = expr.as_opt::<Cast>() {
return cast_stat(expr.child(0), dtype, stat);
}

stat.aggregate_fn()
.map(|aggregate_fn| stat_fn(expr.clone(), aggregate_fn))
return cast_stat(expr.child(0), dtype, stat, ctx);
}

let aggregate_fn = stat.aggregate_fn()?;
// The aggregate may not support the expression's dtype, e.g. min/max over structs,
// even when the predicate itself is well-typed. Such stats cannot be lowered later,
// so do not reference them in the rewrite.
let input_dtype = ctx.return_dtype(expr).ok()?;
aggregate_fn
.return_dtype(&input_dtype)
.is_some()
.then(|| stat_fn(expr.clone(), aggregate_fn))
}

fn with_nan_predicate(
Expand Down Expand Up @@ -545,10 +552,15 @@ fn literal_stat(expr: &Expression, stat: Stat) -> Option<Expression> {
}
}

fn cast_stat(expr: &Expression, dtype: &DType, stat: Stat) -> Option<Expression> {
fn cast_stat(
expr: &Expression,
dtype: &DType,
stat: Stat,
ctx: &StatsRewriteCtx<'_>,
) -> Option<Expression> {
match stat {
Stat::Min | Stat::Max => stat_expr(expr, stat).map(|stat| cast(stat, dtype.clone())),
Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat),
Stat::Min | Stat::Max => stat_expr(expr, stat, ctx).map(|stat| cast(stat, dtype.clone())),
Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat, ctx),
Stat::NullCount | Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted => None,
}
}
Expand Down Expand Up @@ -626,11 +638,19 @@ mod tests {
("f", DType::Primitive(PType::F32, Nullability::NonNullable)),
("s", DType::Utf8(Nullability::NonNullable)),
("t", DType::Utf8(Nullability::NonNullable)),
("n", nested_struct_dtype()),
]),
Nullability::NonNullable,
)
}

fn nested_struct_dtype() -> DType {
DType::Struct(
StructFields::from_iter([("x", DType::Primitive(PType::F32, Nullability::Nullable))]),
Nullability::NonNullable,
)
}

fn falsify(expr: &Expression) -> VortexResult<Option<Expression>> {
expr.falsify(&test_scope(), &SESSION)
}
Expand Down Expand Up @@ -848,6 +868,19 @@ mod tests {
Ok(())
}

#[test]
fn skips_falsifier_when_min_max_unsupported_for_dtype() -> VortexResult<()> {
// Struct comparisons are valid predicates, but min/max aggregates do not
// support struct inputs, so no stats-backed falsifier should be produced.
let struct_scalar = Scalar::struct_(
nested_struct_dtype(),
vec![Scalar::primitive(1.0f32, Nullability::Nullable)],
);
assert_eq!(falsify(&lt_eq(col("n"), lit(struct_scalar.clone())))?, None);
assert_eq!(falsify(&eq(col("n"), lit(struct_scalar)))?, None);
Ok(())
}

#[test]
fn forwards_min_max_through_safe_cast() -> VortexResult<()> {
let dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
Expand Down
Loading