diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index 3476c053ecf..db0ee7ffc31 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -130,6 +130,61 @@ impl StatsRewriteRule for BinaryStatsRewrite { Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, }) } + + fn satisfy( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + let operator = expr.as_::(); + let lhs = expr.child(0); + let rhs = expr.child(1); + + // Min/max stats may be truncated to outward bounds (stored min ≤ + // true min, stored max ≥ true max), which keeps every comparison + // below conservative: a bound that proves the predicate proves it + // for the true extrema too. + let value_predicate = match operator { + // Both value ranges pinch to the same single point. + Operator::Eq => min(lhs).zip(max(lhs)).zip(min(rhs).zip(max(rhs))).map( + |((min_lhs, max_lhs), (min_rhs, max_rhs))| { + and(lt_eq(max_lhs, min_rhs), gt_eq(min_lhs, max_rhs)) + }, + ), + // The value ranges are disjoint. + Operator::NotEq => max(lhs).zip(min(rhs)).zip(min(lhs).zip(max(rhs))).map( + |((max_lhs, min_rhs), (min_lhs, max_rhs))| { + or(lt(max_lhs, min_rhs), gt(min_lhs, max_rhs)) + }, + ), + Operator::Gt => min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)), + Operator::Gte => min(lhs).zip(max(rhs)).map(|(a, b)| gt_eq(a, b)), + Operator::Lt => max(lhs).zip(min(rhs)).map(|(a, b)| lt(a, b)), + Operator::Lte => max(lhs).zip(min(rhs)).map(|(a, b)| lt_eq(a, b)), + Operator::And => { + return Ok(match (ctx.satisfy(lhs)?, ctx.satisfy(rhs)?) { + (Some(lhs), Some(rhs)) => Some(and(lhs, rhs)), + _ => None, + }); + } + Operator::Or => { + let lhs_satisfier = ctx.satisfy(lhs)?; + let rhs_satisfier = ctx.satisfy(rhs)?; + return Ok(or_collect(lhs_satisfier.into_iter().chain(rhs_satisfier))); + } + Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => return Ok(None), + }; + value_predicate + .map(|value_predicate| { + // Satisfaction must prove more than the values: a NaN + // operand makes every comparison false, and a null operand + // makes it null — neither is `true`, so the rewrite is only + // sound over rows proven non-NaN and non-null. + let guarded = with_nan_predicate(ctx, lhs, rhs, value_predicate)?; + with_all_non_null_predicate(ctx, [lhs, rhs], guarded) + }) + .transpose() + } } #[derive(Debug)] @@ -154,6 +209,21 @@ impl StatsRewriteRule for BetweenStatsRewrite { let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); ctx.falsify(&and(lhs, rhs)) } + + fn satisfy( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + let options = expr.as_::(); + let arr = expr.child(0).clone(); + let lower = expr.child(1).clone(); + let upper = expr.child(2).clone(); + + let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]); + let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); + ctx.satisfy(&and(lhs, rhs)) + } } #[derive(Debug)] @@ -503,6 +573,27 @@ fn with_nan_predicate( with_all_non_nan_predicate(ctx, [lhs, rhs], value_predicate) } +// Satisfaction rewrites prove a predicate true for *every* row, but min/max +// stats describe non-null values only: a null operand row evaluates the +// comparison to null, not true. Guard each nullable operand with an +// all-non-null check; non-nullable operands need none. +fn with_all_non_null_predicate<'a>( + ctx: &StatsRewriteCtx<'_>, + exprs: impl IntoIterator, + value_predicate: Expression, +) -> VortexResult { + let mut null_checks = Vec::new(); + for expr in exprs { + if ctx.return_dtype(expr)?.is_nullable() { + null_checks.push(all_non_null(expr)); + } + } + Ok(match and_collect(null_checks) { + Some(null_check) => and(null_check, value_predicate), + None => value_predicate, + }) +} + fn with_all_non_nan_predicate<'a>( ctx: &StatsRewriteCtx<'_>, exprs: impl IntoIterator, @@ -626,6 +717,7 @@ mod tests { ("f", DType::Primitive(PType::F32, Nullability::NonNullable)), ("s", DType::Utf8(Nullability::NonNullable)), ("t", DType::Utf8(Nullability::NonNullable)), + ("n", DType::Primitive(PType::I32, Nullability::Nullable)), ]), Nullability::NonNullable, ) @@ -671,6 +763,96 @@ mod tests { Ok(()) } + #[test] + fn rewrites_comparison_satisfier() -> VortexResult<()> { + // Non-nullable integer: the value condition alone proves all-true. + let expr = lt(col("a"), lit(10)); + assert_eq!( + satisfy(&expr)?, + Some(lt(stat(col("a"), Stat::Max), lit(10))) + ); + + let expr = gt_eq(col("a"), lit(10)); + assert_eq!( + satisfy(&expr)?, + Some(gt_eq(stat(col("a"), Stat::Min), lit(10))) + ); + + // Column-to-column comparison uses both sides' stats. + let expr = lt(col("a"), col("b")); + assert_eq!( + satisfy(&expr)?, + Some(lt(stat(col("a"), Stat::Max), stat(col("b"), Stat::Min))) + ); + + // Floats must also prove no NaNs: a NaN row never satisfies a + // comparison. + let expr = gt(col("f"), lit(1.0f32)); + assert_eq!( + satisfy(&expr)?, + Some(and( + nan_free(col("f")), + gt(stat(col("f"), Stat::Min), lit(1.0f32)) + )) + ); + + // Nullable operands must also prove no nulls: a null row evaluates + // the comparison to null, not true. + let expr = lt(col("n"), lit(10)); + assert_eq!( + satisfy(&expr)?, + Some(and( + all_non_null(&col("n")), + lt(stat(col("n"), Stat::Max), lit(10)) + )) + ); + Ok(()) + } + + #[test] + fn rewrites_boolean_satisfiers() -> VortexResult<()> { + // Conjunctions require both satisfiers; disjunctions accept either. + let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50))); + assert_eq!( + satisfy(&expr)?, + Some(and( + gt(stat(col("a"), Stat::Min), lit(10)), + lt(stat(col("a"), Stat::Max), lit(50)), + )) + ); + + let expr = or(gt(col("a"), lit(10)), lt(col("a"), lit(0))); + assert_eq!( + satisfy(&expr)?, + Some(or( + gt(stat(col("a"), Stat::Min), lit(10)), + lt(stat(col("a"), Stat::Max), lit(0)), + )) + ); + Ok(()) + } + + #[test] + fn rewrites_between_satisfier() -> VortexResult<()> { + let expr = between( + col("a"), + lit(10), + lit(50), + BetweenOptions { + lower_strict: StrictComparison::NonStrict, + upper_strict: StrictComparison::NonStrict, + }, + ); + assert_eq!( + satisfy(&expr)?, + Some(and( + lt_eq(lit(10), stat(col("a"), Stat::Min)), + lt_eq(stat(col("a"), Stat::Max), lit(50)), + )) + ); + Ok(()) + } + #[test] fn rewrites_boolean_falsifiers() -> VortexResult<()> { let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));