diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index 3476c053ecf..2b7316c7d98 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -84,15 +84,15 @@ impl StatsRewriteRule for BinaryStatsRewrite { Ok(match operator { Operator::Eq => { - let left = min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)); - let right = min(rhs).zip(max(lhs)).map(|(a, b)| gt(a, b)); + let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); + let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); or_collect(left.into_iter().chain(right)) .map(|value_predicate| with_nan_predicate(ctx, lhs, rhs, value_predicate)) .transpose()? } - Operator::NotEq => min(lhs) - .zip(max(rhs)) - .zip(max(lhs).zip(min(rhs))) + Operator::NotEq => min(lhs, ctx) + .zip(max(rhs, ctx)) + .zip(max(lhs, ctx).zip(min(rhs, ctx))) .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { with_nan_predicate( ctx, @@ -102,20 +102,20 @@ impl StatsRewriteRule for BinaryStatsRewrite { ) }) .transpose()?, - Operator::Gt => max(lhs) - .zip(min(rhs)) + Operator::Gt => max(lhs, ctx) + .zip(min(rhs, ctx)) .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt_eq(a, b))) .transpose()?, - Operator::Gte => max(lhs) - .zip(min(rhs)) + Operator::Gte => max(lhs, ctx) + .zip(min(rhs, ctx)) .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt(a, b))) .transpose()?, - Operator::Lt => min(lhs) - .zip(max(rhs)) + Operator::Lt => min(lhs, ctx) + .zip(max(rhs, ctx)) .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt_eq(a, b))) .transpose()?, - Operator::Lte => min(lhs) - .zip(max(rhs)) + Operator::Lte => min(lhs, ctx) + .zip(max(rhs, ctx)) .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt(a, b))) .transpose()?, Operator::And => { @@ -167,17 +167,17 @@ impl StatsRewriteRule for IsNullLegacyStatsRewrite { fn falsify( &self, expr: &Expression, - _ctx: &StatsRewriteCtx<'_>, + ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - 
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64)))) + Ok(null_count(expr.child(0), ctx).map(|null_count| eq(null_count, lit(0u64)))) } fn satisfy( &self, expr: &Expression, - _ctx: &StatsRewriteCtx<'_>, + ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - Ok(null_count(expr.child(0)) + Ok(null_count(expr.child(0), ctx) .map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, [])))) } } @@ -227,18 +227,18 @@ impl StatsRewriteRule for IsNotNullLegacyStatsRewrite { fn falsify( &self, expr: &Expression, - _ctx: &StatsRewriteCtx<'_>, + ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - Ok(null_count(expr.child(0)) + Ok(null_count(expr.child(0), ctx) .map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, [])))) } fn satisfy( &self, expr: &Expression, - _ctx: &StatsRewriteCtx<'_>, + ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64)))) + Ok(null_count(expr.child(0), ctx).map(|null_count| eq(null_count, lit(0u64)))) } } @@ -287,7 +287,7 @@ impl StatsRewriteRule for LikeStatsRewrite { fn falsify( &self, expr: &Expression, - _ctx: &StatsRewriteCtx<'_>, + ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { let like_options = expr.as_::(); if like_options.negated || like_options.case_insensitive { @@ -304,8 +304,8 @@ impl StatsRewriteRule for LikeStatsRewrite { let source = expr.child(0); Ok(match LikeVariant::from_str(pattern) { Some(LikeVariant::Exact(text)) => { - min(source) - .zip(max(source)) + min(source, ctx) + .zip(max(source, ctx)) .map(|(source_min, source_max)| { or( gt(source_min, lit(text.as_ref())), @@ -317,8 +317,8 @@ impl StatsRewriteRule for LikeStatsRewrite { let Some(successor) = prefix.to_string().increment().ok() else { return Ok(None); }; - min(source) - .zip(max(source)) + min(source, ctx) + .zip(max(source, ctx)) .map(|(source_min, source_max)| { or( gt_eq(source_min, lit(successor)), @@ -361,10 +361,10 @@ impl StatsRewriteRule for 
ListContainsStatsRewrite { return Ok(Some(lit(true))); } - let Some(value_max) = max(needle) else { + let Some(value_max) = max(needle, ctx) else { return Ok(None); }; - let Some(value_min) = min(needle) else { + let Some(value_min) = min(needle, ctx) else { return Ok(None); }; @@ -398,10 +398,10 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite { let Some((operator, lhs_stat)) = (match dynamic.operator { CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => max(lhs).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)), - CompareOperator::Gte => max(lhs).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)), - CompareOperator::Lt => min(lhs).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)), - CompareOperator::Lte => min(lhs).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)), + CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)), + CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)), + CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)), + CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)), }) else { return Ok(None); }; @@ -418,16 +418,16 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite { } } -fn min(expr: &Expression) -> Option { - stat_expr(expr, Stat::Min) +fn min(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option { + stat_expr(expr, Stat::Min, ctx) } -fn max(expr: &Expression) -> Option { - stat_expr(expr, Stat::Max) +fn max(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option { + stat_expr(expr, Stat::Max, ctx) } -fn null_count(expr: &Expression) -> Option { - stat_expr(expr, Stat::NullCount) +fn null_count(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option { + stat_expr(expr, Stat::NullCount, ctx) } fn all_null(expr: &Expression) -> Expression { @@ -474,7 +474,7 @@ fn has_nans(dtype: &DType) -> bool { matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) } -fn 
stat_expr(expr: &Expression, stat: Stat) -> Option { +fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option { if let Some(literal) = literal_stat(expr, stat) { return Some(literal); } @@ -487,11 +487,18 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option { } if let Some(dtype) = expr.as_opt::() { - return cast_stat(expr.child(0), dtype, stat); - } - - stat.aggregate_fn() - .map(|aggregate_fn| stat_fn(expr.clone(), aggregate_fn)) + return cast_stat(expr.child(0), dtype, stat, ctx); + } + + let aggregate_fn = stat.aggregate_fn()?; + // The aggregate may not support the expression's dtype, e.g. min/max over structs, + // even when the predicate itself is well-typed. Such stats cannot be lowered later, + // so do not reference them in the rewrite. + let input_dtype = ctx.return_dtype(expr).ok()?; + aggregate_fn + .return_dtype(&input_dtype) + .is_some() + .then(|| stat_fn(expr.clone(), aggregate_fn)) } fn with_nan_predicate( @@ -545,10 +552,15 @@ fn literal_stat(expr: &Expression, stat: Stat) -> Option { } } -fn cast_stat(expr: &Expression, dtype: &DType, stat: Stat) -> Option { +fn cast_stat( + expr: &Expression, + dtype: &DType, + stat: Stat, + ctx: &StatsRewriteCtx<'_>, +) -> Option { match stat { - Stat::Min | Stat::Max => stat_expr(expr, stat).map(|stat| cast(stat, dtype.clone())), - Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat), + Stat::Min | Stat::Max => stat_expr(expr, stat, ctx).map(|stat| cast(stat, dtype.clone())), + Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat, ctx), Stat::NullCount | Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted => None, } } @@ -626,11 +638,19 @@ mod tests { ("f", DType::Primitive(PType::F32, Nullability::NonNullable)), ("s", DType::Utf8(Nullability::NonNullable)), ("t", DType::Utf8(Nullability::NonNullable)), + ("n", nested_struct_dtype()), ]), Nullability::NonNullable, ) } + fn nested_struct_dtype() -> DType { + 
DType::Struct( + StructFields::from_iter([("x", DType::Primitive(PType::F32, Nullability::Nullable))]), + Nullability::NonNullable, + ) + } + fn falsify(expr: &Expression) -> VortexResult> { expr.falsify(&test_scope(), &SESSION) } @@ -848,6 +868,19 @@ mod tests { Ok(()) } + #[test] + fn skips_falsifier_when_min_max_unsupported_for_dtype() -> VortexResult<()> { + // Struct comparisons are valid predicates, but min/max aggregates do not + // support struct inputs, so no stats-backed falsifier should be produced. + let struct_scalar = Scalar::struct_( + nested_struct_dtype(), + vec![Scalar::primitive(1.0f32, Nullability::Nullable)], + ); + assert_eq!(falsify(&lt_eq(col("n"), lit(struct_scalar.clone())))?, None); + assert_eq!(falsify(&eq(col("n"), lit(struct_scalar)))?, None); + Ok(()) + } + #[test] fn forwards_min_max_through_safe_cast() -> VortexResult<()> { let dtype = DType::Primitive(PType::I64, Nullability::NonNullable);