diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index 432313d3cb6..7ed4d287fbb 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -77,8 +77,8 @@ impl CastKernel for Decimal { ); }; - // Scale changes are not yet supported - if from_decimal_dtype.scale() != to_decimal_dtype.scale() { + // Narrowing the scale (dropping fractional digits) is not supported. + if from_decimal_dtype.scale() > to_decimal_dtype.scale() { vortex_bail!( "Casting decimal with scale {} to scale {} not yet implemented", from_decimal_dtype.scale(), @@ -86,8 +86,12 @@ impl CastKernel for Decimal { ); } - // Downcasting precision is not yet supported - if to_decimal_dtype.precision() < from_decimal_dtype.precision() { + // The target must retain at least the source's integer digits. + let from_integer_digits = + i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale()); + let to_integer_digits = + i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale()); + if to_integer_digits < from_integer_digits { vortex_bail!( "Downcasting decimal from precision {} to {} not yet implemented", from_decimal_dtype.precision(), @@ -105,6 +109,12 @@ impl CastKernel for Decimal { .validity()? .cast_nullability(*to_nullability, array.len(), ctx)?; + // Widening the scale multiplies unscaled values by a power of ten. + if from_decimal_dtype.scale() < to_decimal_dtype.scale() { + let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?; + return Ok(Some(rescaled.into_array())); + } + // If the target needs a wider physical type, upcast the values let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype); let array = if target_values_type > array.values_type() { @@ -128,6 +138,73 @@ impl CastKernel for Decimal { } } +/// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`), +/// multiplying unscaled values by the corresponding power of ten. The +/// result is stored at the width the target precision requires. +fn rescale_decimal_values( + array: ArrayView<'_, Decimal>, + to: crate::dtype::DecimalDType, + validity: crate::validity::Validity, +) -> VortexResult { + let from = array.decimal_dtype(); + let scale_up = u32::try_from(to.scale() - from.scale()) + .map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?; + let factor = 10i128 + .checked_pow(scale_up) + .ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?; + + // Gather unscaled values as i128 (i256 sources are unsupported). + let values: Vec = match array.values_type() { + DecimalType::I8 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I16 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I32 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I64 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I128 => array.buffer::().iter().copied().collect(), + DecimalType::I256 => vortex_bail!("rescaling i256 decimals is not supported"), + }; + + let rescaled = values + .into_iter() + .map(|v| { + v.checked_mul(factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows i128")) + }) + .collect::>>()?; + + match DecimalType::smallest_decimal_value_type(&to) { + DecimalType::I256 => vortex_bail!("rescaling into i256 decimals is not supported"), + DecimalType::I128 => Ok(DecimalArray::new(Buffer::from_iter(rescaled), to, validity)), + // Narrow storage targets: the values fit by the precision check. + DecimalType::I64 | DecimalType::I32 | DecimalType::I16 | DecimalType::I8 => { + let narrowed = rescaled + .into_iter() + .map(|v| { + i64::try_from(v).map_err(|_| { + vortex_error::vortex_err!("rescaled decimal exceeds target width") + }) + }) + .collect::>>()?; + Ok(DecimalArray::new(Buffer::from_iter(narrowed), to, validity)) + } + } +} + /// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping /// the same precision and scale. /// @@ -262,19 +339,35 @@ mod tests { } #[test] - fn cast_different_scale_fails() { + fn cast_widening_scale_rescales() { + let array = DecimalArray::new( + buffer![100i32, -250], + DecimalDType::new(10, 2), + Validity::NonNullable, + ); + + // 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3. + let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + #[expect(deprecated)] + let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal(); + assert_eq!(casted.dtype(), &wider); + assert_eq!(casted.buffer::().as_ref(), &[1000i64, -2500]); + } + + #[test] + fn cast_narrowing_scale_fails() { let array = DecimalArray::new( buffer![100i32], DecimalDType::new(10, 2), Validity::NonNullable, ); - // Try to cast to different scale - not supported - let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + // Dropping fractional digits is not supported. + let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable); #[expect(deprecated)] let result = array .into_array() - .cast(different_dtype) + .cast(narrower) .and_then(|a| a.to_canonical().map(|c| c.into_array())); assert!(result.is_err()); @@ -282,7 +375,7 @@ mod tests { result .unwrap_err() .to_string() - .contains("Casting decimal with scale 2 to scale 3 not yet implemented") + .contains("Casting decimal with scale 2 to scale 1 not yet implemented") ); } diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..1b0a62a4ac3 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,6 +17,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; +use crate::dtype::DecimalDType; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -46,6 +47,47 @@ pub(crate) use numeric::*; use crate::scalar::NumericOperator; +/// Output decimal type of an arithmetic `operator` over two operands that +/// have already been coerced to the same decimal type. +/// +/// Mirrors the Hive-style rules `arrow-arith` applies at execution time +/// (see `arrow_arith::numeric::decimal_op`), including precision saturation +/// at the physical width's maximum: vortex lowers precisions `<= 38` to +/// Arrow `Decimal128` and wider decimals to `Decimal256`. +fn decimal_arithmetic_dtype( + operator: Operator, + operand: DecimalDType, +) -> VortexResult { + let p = u16::from(operand.precision()); + let s = i16::from(operand.scale()); + let (max_precision, max_scale): (u16, i16) = if p <= 38 { (38, 38) } else { (76, 76) }; + let (precision, scale) = match operator { + // scale = max(s, s); precision = max(p - s, p - s) + scale + 1 + Operator::Add | Operator::Sub => ((p + 1).min(max_precision), s), + // scale = s + s; precision = p + p + 1 + Operator::Mul => { + let scale = s + s; + if scale > max_scale { + vortex_bail!( + "output scale of {operand} {operator} {operand} exceeds the maximum scale \ + {max_scale}" + ); + } + ((p + p + 1).min(max_precision), scale) + } + // scale = min(s + 4, max); precision = p - s + s + scale + Operator::Div => { + let scale = (s + 4).min(max_scale); + (((p + scale.unsigned_abs()).min(max_precision)), scale) + } + _ => vortex_bail!("operator {operator} is not arithmetic"), + }; + Ok(DecimalDType::new( + u8::try_from(precision).unwrap_or(u8::MAX), + i8::try_from(scale).unwrap_or(i8::MAX), + )) +} + #[derive(Clone)] pub struct Binary; @@ -122,6 +164,15 @@ impl ScalarFnVTable for Binary { if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } + if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) + && l == r + { + let result = decimal_arithmetic_dtype(*operator, *l)?; + return Ok(DType::Decimal( + result, + lhs.nullability() | rhs.nullability(), + )); + } vortex_bail!( "incompatible types for arithmetic operation: {} {}", lhs, @@ -332,6 +383,47 @@ mod tests { use crate::expr::or_collect; use crate::expr::test_harness; use crate::scalar::Scalar; + + /// The decimal arithmetic dtypes derived at plan time must match what + /// arrow produces at execution time (see `decimal_arithmetic_dtype`). + #[test] + fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> { + use vortex_buffer::buffer; + + use crate::Canonical; + use crate::IntoArray; + use crate::arrays::DecimalArray; + use crate::dtype::DecimalDType; + use crate::scalar::DecimalValue; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::validity::Validity; + + let dec = DecimalDType::new(15, 2); + let values = + DecimalArray::new(buffer![100i128, 250, 1099], dec, Validity::NonNullable).into_array(); + let rhs = lit(Scalar::decimal( + DecimalValue::I128(50), + dec, + Nullability::NonNullable, + )); + for op in [Operator::Add, Operator::Sub, Operator::Mul, Operator::Div] { + let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?; + let derived = expr.return_dtype(values.dtype())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let executed = values + .clone() + .apply(&expr)? + .execute::(&mut ctx)? + .into_array(); + assert_eq!( + executed.dtype(), + &derived, + "derived dtype diverges from execution for {op}" + ); + } + Ok(()) + } + #[test] fn and_collect_balanced() { let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];