Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions vortex-array/src/stats/rewrite/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,61 @@ impl StatsRewriteRule for BinaryStatsRewrite {
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
})
}

/// Builds a stats-only expression that, when true, proves the binary
/// predicate `expr` holds for every row.
///
/// Comparison operators are proven from the operands' min/max bounds and are
/// then wrapped in the NaN and non-null guards. `And` needs a satisfier for
/// both children, `Or` combines whichever child satisfiers exist, and the
/// arithmetic operators have no satisfier. Returns `Ok(None)` when no proof
/// can be constructed.
fn satisfy(
    &self,
    expr: &Expression,
    ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
    let operator = expr.as_::<Binary>();
    let lhs = expr.child(0);
    let rhs = expr.child(1);

    // Stored min/max may be truncated outward (stored min ≤ true min, stored
    // max ≥ true max). Every check below therefore stays conservative: if
    // the stored bounds prove the predicate, the true extrema prove it too.
    let bounds_proof = match operator {
        // Equality holds everywhere only when both ranges collapse onto the
        // same single value.
        Operator::Eq => match (min(lhs), max(lhs), min(rhs), max(rhs)) {
            (Some(lhs_min), Some(lhs_max), Some(rhs_min), Some(rhs_max)) => {
                Some(and(lt_eq(lhs_max, rhs_min), gt_eq(lhs_min, rhs_max)))
            }
            _ => None,
        },
        // Inequality holds everywhere when the two ranges do not overlap.
        Operator::NotEq => match (max(lhs), min(rhs), min(lhs), max(rhs)) {
            (Some(lhs_max), Some(rhs_min), Some(lhs_min), Some(rhs_max)) => {
                Some(or(lt(lhs_max, rhs_min), gt(lhs_min, rhs_max)))
            }
            _ => None,
        },
        // Ordering comparisons: the worst-case lhs bound must still beat the
        // worst-case rhs bound.
        Operator::Gt => match (min(lhs), max(rhs)) {
            (Some(lhs_min), Some(rhs_max)) => Some(gt(lhs_min, rhs_max)),
            _ => None,
        },
        Operator::Gte => match (min(lhs), max(rhs)) {
            (Some(lhs_min), Some(rhs_max)) => Some(gt_eq(lhs_min, rhs_max)),
            _ => None,
        },
        Operator::Lt => match (max(lhs), min(rhs)) {
            (Some(lhs_max), Some(rhs_min)) => Some(lt(lhs_max, rhs_min)),
            _ => None,
        },
        Operator::Lte => match (max(lhs), min(rhs)) {
            (Some(lhs_max), Some(rhs_min)) => Some(lt_eq(lhs_max, rhs_min)),
            _ => None,
        },
        // A conjunction is proven only when both children are proven.
        Operator::And => {
            let left = ctx.satisfy(lhs)?;
            let right = ctx.satisfy(rhs)?;
            return Ok(left.zip(right).map(|(left, right)| and(left, right)));
        }
        // A disjunction is proven by whichever child satisfiers exist.
        Operator::Or => {
            let left = ctx.satisfy(lhs)?;
            let right = ctx.satisfy(rhs)?;
            return Ok(or_collect(left.into_iter().chain(right)));
        }
        Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => return Ok(None),
    };

    match bounds_proof {
        Some(bounds_proof) => {
            // Proving the value bounds is not enough: a NaN operand makes
            // every comparison false and a null operand makes it null.
            // Neither is `true`, so the proof is only sound once the rows
            // are also shown to be NaN-free and non-null.
            let nan_guarded = with_nan_predicate(ctx, lhs, rhs, bounds_proof)?;
            with_all_non_null_predicate(ctx, [lhs, rhs], nan_guarded).map(Some)
        }
        None => Ok(None),
    }
}
}

#[derive(Debug)]
Expand All @@ -154,6 +209,21 @@ impl StatsRewriteRule for BetweenStatsRewrite {
let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
ctx.falsify(&and(lhs, rhs))
}

/// Builds a satisfier for a `Between` expression by rewriting it as the
/// conjunction `lower <op> value AND value <op> upper` and delegating to the
/// context, where each `<op>` comes from the corresponding strictness option.
fn satisfy(
    &self,
    expr: &Expression,
    ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
    let options = expr.as_::<Between>();
    // Children are laid out as (value, lower bound, upper bound).
    let (value, lower_bound, upper_bound) = (
        expr.child(0).clone(),
        expr.child(1).clone(),
        expr.child(2).clone(),
    );

    let lower_check = Binary.new_expr(
        options.lower_strict.to_operator(),
        [lower_bound, value.clone()],
    );
    let upper_check = Binary.new_expr(options.upper_strict.to_operator(), [value, upper_bound]);

    let conjunction = and(lower_check, upper_check);
    ctx.satisfy(&conjunction)
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -503,6 +573,27 @@ fn with_nan_predicate(
with_all_non_nan_predicate(ctx, [lhs, rhs], value_predicate)
}

// A satisfaction rewrite has to prove the predicate true for *every* row,
// yet min/max stats only describe the non-null values: on a row where an
// operand is null the comparison evaluates to null rather than true. Each
// nullable operand is therefore guarded by an all-non-null check, while
// non-nullable operands contribute no guard. When no operand is nullable the
// value predicate is returned unchanged.
fn with_all_non_null_predicate<'a>(
    ctx: &StatsRewriteCtx<'_>,
    exprs: impl IntoIterator<Item = &'a Expression>,
    value_predicate: Expression,
) -> VortexResult<Expression> {
    // Resolve each operand's dtype in order, stopping at the first error,
    // and keep a guard only for the nullable operands.
    let null_checks = exprs
        .into_iter()
        .map(|operand| -> VortexResult<Option<_>> {
            let nullable = ctx.return_dtype(operand)?.is_nullable();
            Ok(nullable.then(|| all_non_null(operand)))
        })
        .collect::<VortexResult<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>();

    Ok(if let Some(null_check) = and_collect(null_checks) {
        and(null_check, value_predicate)
    } else {
        value_predicate
    })
}

fn with_all_non_nan_predicate<'a>(
ctx: &StatsRewriteCtx<'_>,
exprs: impl IntoIterator<Item = &'a Expression>,
Expand Down Expand Up @@ -626,6 +717,7 @@ mod tests {
("f", DType::Primitive(PType::F32, Nullability::NonNullable)),
("s", DType::Utf8(Nullability::NonNullable)),
("t", DType::Utf8(Nullability::NonNullable)),
("n", DType::Primitive(PType::I32, Nullability::Nullable)),
]),
Nullability::NonNullable,
)
Expand Down Expand Up @@ -671,6 +763,96 @@ mod tests {
Ok(())
}

#[test]
fn rewrites_comparison_satisfier() -> VortexResult<()> {
    // A non-nullable integer column needs no guard: the bound check on its
    // own proves the predicate for every row.
    let predicate = lt(col("a"), lit(10));
    let expected = lt(stat(col("a"), Stat::Max), lit(10));
    assert_eq!(satisfy(&predicate)?, Some(expected));

    let predicate = gt_eq(col("a"), lit(10));
    let expected = gt_eq(stat(col("a"), Stat::Min), lit(10));
    assert_eq!(satisfy(&predicate)?, Some(expected));

    // Comparing two columns draws on the stats of both sides.
    let predicate = lt(col("a"), col("b"));
    let expected = lt(stat(col("a"), Stat::Max), stat(col("b"), Stat::Min));
    assert_eq!(satisfy(&predicate)?, Some(expected));

    // Float columns additionally have to be NaN-free, because a NaN row
    // never satisfies a comparison.
    let predicate = gt(col("f"), lit(1.0f32));
    let expected = and(
        nan_free(col("f")),
        gt(stat(col("f"), Stat::Min), lit(1.0f32)),
    );
    assert_eq!(satisfy(&predicate)?, Some(expected));

    // Nullable columns additionally have to be null-free, because a null
    // row evaluates the comparison to null instead of true.
    let predicate = lt(col("n"), lit(10));
    let expected = and(
        all_non_null(&col("n")),
        lt(stat(col("n"), Stat::Max), lit(10)),
    );
    assert_eq!(satisfy(&predicate)?, Some(expected));

    Ok(())
}

#[test]
fn rewrites_boolean_satisfiers() -> VortexResult<()> {
    // AND: both child satisfiers are required and are joined with `and`.
    let conjunction = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));
    let expected = and(
        gt(stat(col("a"), Stat::Min), lit(10)),
        lt(stat(col("a"), Stat::Max), lit(50)),
    );
    assert_eq!(satisfy(&conjunction)?, Some(expected));

    // OR: either child satisfier is enough, so they are joined with `or`.
    let disjunction = or(gt(col("a"), lit(10)), lt(col("a"), lit(0)));
    let expected = or(
        gt(stat(col("a"), Stat::Min), lit(10)),
        lt(stat(col("a"), Stat::Max), lit(0)),
    );
    assert_eq!(satisfy(&disjunction)?, Some(expected));

    Ok(())
}

#[test]
fn rewrites_between_satisfier() -> VortexResult<()> {
    // An inclusive `10 <= a <= 50` is proven when 10 <= min(a) and
    // max(a) <= 50.
    let inclusive = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };
    let predicate = between(col("a"), lit(10), lit(50), inclusive);

    let expected = and(
        lt_eq(lit(10), stat(col("a"), Stat::Min)),
        lt_eq(stat(col("a"), Stat::Max), lit(50)),
    );
    assert_eq!(satisfy(&predicate)?, Some(expected));

    Ok(())
}

#[test]
fn rewrites_boolean_falsifiers() -> VortexResult<()> {
let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));
Expand Down
Loading