Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 102 additions & 9 deletions vortex-array/src/arrays/decimal/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,21 @@ impl CastKernel for Decimal {
);
};

// Scale changes are not yet supported
if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
// Narrowing the scale (dropping fractional digits) is not supported.
if from_decimal_dtype.scale() > to_decimal_dtype.scale() {
vortex_bail!(
"Casting decimal with scale {} to scale {} not yet implemented",
from_decimal_dtype.scale(),
to_decimal_dtype.scale()
);
}

// Downcasting precision is not yet supported
if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
// The target must retain at least the source's integer digits.
let from_integer_digits =
i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale());
let to_integer_digits =
i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale());
if to_integer_digits < from_integer_digits {
vortex_bail!(
"Downcasting decimal from precision {} to {} not yet implemented",
from_decimal_dtype.precision(),
Expand All @@ -105,6 +109,12 @@ impl CastKernel for Decimal {
.validity()?
.cast_nullability(*to_nullability, array.len(), ctx)?;

// Widening the scale multiplies unscaled values by a power of ten.
if from_decimal_dtype.scale() < to_decimal_dtype.scale() {
let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?;
return Ok(Some(rescaled.into_array()));
}

// If the target needs a wider physical type, upcast the values
let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
let array = if target_values_type > array.values_type() {
Expand All @@ -128,6 +138,73 @@ impl CastKernel for Decimal {
}
}

/// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`),
/// multiplying unscaled values by the corresponding power of ten. The
/// result is stored at the width the target precision requires.
fn rescale_decimal_values(
array: ArrayView<'_, Decimal>,
to: crate::dtype::DecimalDType,
validity: crate::validity::Validity,
) -> VortexResult<DecimalArray> {
let from = array.decimal_dtype();
let scale_up = u32::try_from(to.scale() - from.scale())
.map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?;
let factor = 10i128
.checked_pow(scale_up)
.ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?;

// Gather unscaled values as i128 (i256 sources are unsupported).
let values: Vec<i128> = match array.values_type() {
DecimalType::I8 => array
.buffer::<i8>()
.iter()
.map(|&v| i128::from(v))
.collect(),
DecimalType::I16 => array
.buffer::<i16>()
.iter()
.map(|&v| i128::from(v))
.collect(),
DecimalType::I32 => array
.buffer::<i32>()
.iter()
.map(|&v| i128::from(v))
.collect(),
DecimalType::I64 => array
.buffer::<i64>()
.iter()
.map(|&v| i128::from(v))
.collect(),
DecimalType::I128 => array.buffer::<i128>().iter().copied().collect(),
DecimalType::I256 => vortex_bail!("rescaling i256 decimals is not supported"),
};

let rescaled = values
.into_iter()
.map(|v| {
v.checked_mul(factor)
.ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows i128"))
})
.collect::<VortexResult<Vec<i128>>>()?;

match DecimalType::smallest_decimal_value_type(&to) {
DecimalType::I256 => vortex_bail!("rescaling into i256 decimals is not supported"),
DecimalType::I128 => Ok(DecimalArray::new(Buffer::from_iter(rescaled), to, validity)),
// Narrow storage targets: the values fit by the precision check.
DecimalType::I64 | DecimalType::I32 | DecimalType::I16 | DecimalType::I8 => {
let narrowed = rescaled
.into_iter()
.map(|v| {
i64::try_from(v).map_err(|_| {
vortex_error::vortex_err!("rescaled decimal exceeds target width")
})
})
.collect::<VortexResult<Vec<i64>>>()?;
Ok(DecimalArray::new(Buffer::from_iter(narrowed), to, validity))
}
}
}

/// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping
/// the same precision and scale.
///
Expand Down Expand Up @@ -262,27 +339,43 @@ mod tests {
}

#[test]
fn cast_different_scale_fails() {
fn cast_widening_scale_rescales() {
let array = DecimalArray::new(
buffer![100i32, -250],
DecimalDType::new(10, 2),
Validity::NonNullable,
);

// 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3.
let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
#[expect(deprecated)]
let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal();
assert_eq!(casted.dtype(), &wider);
assert_eq!(casted.buffer::<i64>().as_ref(), &[1000i64, -2500]);
}

#[test]
fn cast_narrowing_scale_fails() {
let array = DecimalArray::new(
buffer![100i32],
DecimalDType::new(10, 2),
Validity::NonNullable,
);

// Try to cast to different scale - not supported
let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
// Dropping fractional digits is not supported.
let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable);
#[expect(deprecated)]
let result = array
.into_array()
.cast(different_dtype)
.cast(narrower)
.and_then(|a| a.to_canonical().map(|c| c.into_array()));

assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Casting decimal with scale 2 to scale 3 not yet implemented")
.contains("Casting decimal with scale 2 to scale 1 not yet implemented")
);
}

Expand Down
92 changes: 92 additions & 0 deletions vortex-array/src/scalar_fn/fns/binary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use vortex_session::registry::CachedId;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::dtype::DType;
use crate::dtype::DecimalDType;
use crate::expr::StatsCatalog;
use crate::expr::and;
use crate::expr::and_collect;
Expand Down Expand Up @@ -46,6 +47,47 @@ pub(crate) use numeric::*;

use crate::scalar::NumericOperator;

/// Output decimal type of an arithmetic `operator` over two operands that
/// have already been coerced to the same decimal type.
///
/// Mirrors the Hive-style rules `arrow-arith` applies at execution time
/// (see `arrow_arith::numeric::decimal_op`), including precision saturation
/// at the physical width's maximum: vortex lowers precisions `<= 38` to
/// Arrow `Decimal128` and wider decimals to `Decimal256`.
fn decimal_arithmetic_dtype(
operator: Operator,
operand: DecimalDType,
) -> VortexResult<DecimalDType> {
let p = u16::from(operand.precision());
let s = i16::from(operand.scale());
let (max_precision, max_scale): (u16, i16) = if p <= 38 { (38, 38) } else { (76, 76) };
let (precision, scale) = match operator {
// scale = max(s, s); precision = max(p - s, p - s) + scale + 1
Operator::Add | Operator::Sub => ((p + 1).min(max_precision), s),
// scale = s + s; precision = p + p + 1
Operator::Mul => {
let scale = s + s;
if scale > max_scale {
vortex_bail!(
"output scale of {operand} {operator} {operand} exceeds the maximum scale \
{max_scale}"
);
}
((p + p + 1).min(max_precision), scale)
}
// scale = min(s + 4, max); precision = p - s + s + scale
Operator::Div => {
let scale = (s + 4).min(max_scale);
(((p + scale.unsigned_abs()).min(max_precision)), scale)
}
_ => vortex_bail!("operator {operator} is not arithmetic"),
};
Ok(DecimalDType::new(
u8::try_from(precision).unwrap_or(u8::MAX),
i8::try_from(scale).unwrap_or(i8::MAX),
))
}

#[derive(Clone)]
pub struct Binary;

Expand Down Expand Up @@ -122,6 +164,15 @@ impl ScalarFnVTable for Binary {
if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
}
if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs)
&& l == r
{
let result = decimal_arithmetic_dtype(*operator, *l)?;
return Ok(DType::Decimal(
result,
lhs.nullability() | rhs.nullability(),
));
}
vortex_bail!(
"incompatible types for arithmetic operation: {} {}",
lhs,
Expand Down Expand Up @@ -332,6 +383,47 @@ mod tests {
use crate::expr::or_collect;
use crate::expr::test_harness;
use crate::scalar::Scalar;

/// The decimal arithmetic dtypes derived at plan time must match what
/// arrow produces at execution time (see `decimal_arithmetic_dtype`).
#[test]
fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> {
use vortex_buffer::buffer;

use crate::Canonical;
use crate::IntoArray;
use crate::arrays::DecimalArray;
use crate::dtype::DecimalDType;
use crate::scalar::DecimalValue;
use crate::scalar_fn::ScalarFnVTableExt;
use crate::validity::Validity;

let dec = DecimalDType::new(15, 2);
let values =
DecimalArray::new(buffer![100i128, 250, 1099], dec, Validity::NonNullable).into_array();
let rhs = lit(Scalar::decimal(
DecimalValue::I128(50),
dec,
Nullability::NonNullable,
));
for op in [Operator::Add, Operator::Sub, Operator::Mul, Operator::Div] {
let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?;
let derived = expr.return_dtype(values.dtype())?;
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let executed = values
.clone()
.apply(&expr)?
.execute::<Canonical>(&mut ctx)?
.into_array();
assert_eq!(
executed.dtype(),
&derived,
"derived dtype diverges from execution for {op}"
);
}
Ok(())
}

#[test]
fn and_collect_balanced() {
let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
Expand Down
Loading