From 2c37139dc805334a461d7b4d0dff26824af73f8c Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 10 Jun 2026 07:13:16 -0700 Subject: [PATCH] vortex-array: decimal arithmetic dtypes and widening rescale cast Binary::return_dtype derives Hive-style result types for decimal + - * / (mirroring arrow-arith's execution-time rules, with precision saturation at the physical width), instead of erroring at plan time. The decimal cast kernel gains widening rescale so expression coercion can align decimal operand types. Needed by vortex-engine for TPC-H Q1/Q6 decimal expressions. Co-Authored-By: Claude Fable 5 Signed-off-by: Nicholas Gates --- .../src/arrays/decimal/compute/cast.rs | 111 ++++++++++++++++-- vortex-array/src/scalar_fn/fns/binary/mod.rs | 92 +++++++++++++++ 2 files changed, 194 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index 432313d3cb6..7ed4d287fbb 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -77,8 +77,8 @@ impl CastKernel for Decimal { ); }; - // Scale changes are not yet supported - if from_decimal_dtype.scale() != to_decimal_dtype.scale() { + // Narrowing the scale (dropping fractional digits) is not supported. + if from_decimal_dtype.scale() > to_decimal_dtype.scale() { vortex_bail!( "Casting decimal with scale {} to scale {} not yet implemented", from_decimal_dtype.scale(), @@ -86,8 +86,12 @@ impl CastKernel for Decimal { ); } - // Downcasting precision is not yet supported - if to_decimal_dtype.precision() < from_decimal_dtype.precision() { + // The target must retain at least the source's integer digits. + let from_integer_digits = + i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale()); + let to_integer_digits = + i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale()); + if to_integer_digits < from_integer_digits { vortex_bail!( "Downcasting decimal from precision {} to {} not yet implemented", from_decimal_dtype.precision(), @@ -105,6 +109,12 @@ impl CastKernel for Decimal { .validity()? .cast_nullability(*to_nullability, array.len(), ctx)?; + // Widening the scale multiplies unscaled values by a power of ten. + if from_decimal_dtype.scale() < to_decimal_dtype.scale() { + let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?; + return Ok(Some(rescaled.into_array())); + } + // If the target needs a wider physical type, upcast the values let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype); let array = if target_values_type > array.values_type() { @@ -128,6 +138,73 @@ impl CastKernel for Decimal { } } +/// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`), +/// multiplying unscaled values by the corresponding power of ten. The +/// result is stored at the width the target precision requires. +fn rescale_decimal_values( + array: ArrayView<'_, Decimal>, + to: crate::dtype::DecimalDType, + validity: crate::validity::Validity, +) -> VortexResult { + let from = array.decimal_dtype(); + let scale_up = u32::try_from(to.scale() - from.scale()) + .map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?; + let factor = 10i128 + .checked_pow(scale_up) + .ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?; + + // Gather unscaled values as i128 (i256 sources are unsupported). + let values: Vec = match array.values_type() { + DecimalType::I8 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I16 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I32 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I64 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I128 => array.buffer::().iter().copied().collect(), + DecimalType::I256 => vortex_bail!("rescaling i256 decimals is not supported"), + }; + + let rescaled = values + .into_iter() + .map(|v| { + v.checked_mul(factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows i128")) + }) + .collect::>>()?; + + match DecimalType::smallest_decimal_value_type(&to) { + DecimalType::I256 => vortex_bail!("rescaling into i256 decimals is not supported"), + DecimalType::I128 => Ok(DecimalArray::new(Buffer::from_iter(rescaled), to, validity)), + // Narrow storage targets: the values fit by the precision check. + DecimalType::I64 | DecimalType::I32 | DecimalType::I16 | DecimalType::I8 => { + let narrowed = rescaled + .into_iter() + .map(|v| { + i64::try_from(v).map_err(|_| { + vortex_error::vortex_err!("rescaled decimal exceeds target width") + }) + }) + .collect::>>()?; + Ok(DecimalArray::new(Buffer::from_iter(narrowed), to, validity)) + } + } +} + /// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping /// the same precision and scale. /// @@ -262,19 +339,35 @@ mod tests { } #[test] - fn cast_different_scale_fails() { + fn cast_widening_scale_rescales() { + let array = DecimalArray::new( + buffer![100i32, -250], + DecimalDType::new(10, 2), + Validity::NonNullable, + ); + + // 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3. + let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + #[expect(deprecated)] + let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal(); + assert_eq!(casted.dtype(), &wider); + assert_eq!(casted.buffer::().as_ref(), &[1000i64, -2500]); + } + + #[test] + fn cast_narrowing_scale_fails() { let array = DecimalArray::new( buffer![100i32], DecimalDType::new(10, 2), Validity::NonNullable, ); - // Try to cast to different scale - not supported - let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + // Dropping fractional digits is not supported. + let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable); #[expect(deprecated)] let result = array .into_array() - .cast(different_dtype) + .cast(narrower) .and_then(|a| a.to_canonical().map(|c| c.into_array())); assert!(result.is_err()); @@ -282,7 +375,7 @@ mod tests { result .unwrap_err() .to_string() - .contains("Casting decimal with scale 2 to scale 3 not yet implemented") + .contains("Casting decimal with scale 2 to scale 1 not yet implemented") ); } diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..1b0a62a4ac3 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,6 +17,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; +use crate::dtype::DecimalDType; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -46,6 +47,47 @@ pub(crate) use numeric::*; use crate::scalar::NumericOperator; +/// Output decimal type of an arithmetic `operator` over two operands that +/// have already been coerced to the same decimal type. +/// +/// Mirrors the Hive-style rules `arrow-arith` applies at execution time +/// (see `arrow_arith::numeric::decimal_op`), including precision saturation +/// at the physical width's maximum: vortex lowers precisions `<= 38` to +/// Arrow `Decimal128` and wider decimals to `Decimal256`. +fn decimal_arithmetic_dtype( + operator: Operator, + operand: DecimalDType, +) -> VortexResult { + let p = u16::from(operand.precision()); + let s = i16::from(operand.scale()); + let (max_precision, max_scale): (u16, i16) = if p <= 38 { (38, 38) } else { (76, 76) }; + let (precision, scale) = match operator { + // scale = max(s, s); precision = max(p - s, p - s) + scale + 1 + Operator::Add | Operator::Sub => ((p + 1).min(max_precision), s), + // scale = s + s; precision = p + p + 1 + Operator::Mul => { + let scale = s + s; + if scale > max_scale { + vortex_bail!( + "output scale of {operand} {operator} {operand} exceeds the maximum scale \ + {max_scale}" + ); + } + ((p + p + 1).min(max_precision), scale) + } + // scale = min(s + 4, max); precision = p - s + s + scale + Operator::Div => { + let scale = (s + 4).min(max_scale); + (((p + scale.unsigned_abs()).min(max_precision)), scale) + } + _ => vortex_bail!("operator {operator} is not arithmetic"), + }; + Ok(DecimalDType::new( + u8::try_from(precision).unwrap_or(u8::MAX), + i8::try_from(scale).unwrap_or(i8::MAX), + )) +} + #[derive(Clone)] pub struct Binary; @@ -122,6 +164,15 @@ impl ScalarFnVTable for Binary { if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } + if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) + && l == r + { + let result = decimal_arithmetic_dtype(*operator, *l)?; + return Ok(DType::Decimal( + result, + lhs.nullability() | rhs.nullability(), + )); + } vortex_bail!( "incompatible types for arithmetic operation: {} {}", lhs, @@ -332,6 +383,47 @@ mod tests { use crate::expr::or_collect; use crate::expr::test_harness; use crate::scalar::Scalar; + + /// The decimal arithmetic dtypes derived at plan time must match what + /// arrow produces at execution time (see `decimal_arithmetic_dtype`). + #[test] + fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> { + use vortex_buffer::buffer; + + use crate::Canonical; + use crate::IntoArray; + use crate::arrays::DecimalArray; + use crate::dtype::DecimalDType; + use crate::scalar::DecimalValue; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::validity::Validity; + + let dec = DecimalDType::new(15, 2); + let values = + DecimalArray::new(buffer![100i128, 250, 1099], dec, Validity::NonNullable).into_array(); + let rhs = lit(Scalar::decimal( + DecimalValue::I128(50), + dec, + Nullability::NonNullable, + )); + for op in [Operator::Add, Operator::Sub, Operator::Mul, Operator::Div] { + let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?; + let derived = expr.return_dtype(values.dtype())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let executed = values + .clone() + .apply(&expr)? + .execute::(&mut ctx)? + .into_array(); + assert_eq!( + executed.dtype(), + &derived, + "derived dtype diverges from execution for {op}" + ); + } + Ok(()) + } + #[test] fn and_collect_balanced() { let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];