From 71235d08bc956f768a7529cb0e77d29b5a17a700 Mon Sep 17 00:00:00 2001 From: Wouter Devriendt Date: Sun, 12 Apr 2026 14:02:28 -0700 Subject: [PATCH] [Learning] Fix #179415: Use wrapping cast instead of saturated cast for MPS sum reduction MPS backend's sum reduction was using MPSGraph's castTensor which performs saturated casting (clamps to type range), causing incorrect results for integer overflow. For example, summing uint8 values [255, 2, 1, 5, 3, 6] returned 255 (saturated) instead of 16 (wrapping). This adds modular arithmetic (floor-modulo) before the final cast for small integer types (uint8, int8, int16), matching CPU's wrapping behavior: - Unsigned: value - floor(value / range) * range - Signed: shift to unsigned, modulo, shift back Fixes pytorch/pytorch#179415 --- .../ATen/native/mps/operations/ReduceOps.mm | 65 ++++++++++++++++++- test/test_mps.py | 14 ++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 2ec070c84be7e..ce2ee1df8d7aa 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -291,9 +291,70 @@ static void reduction_out_mps(const Tensor& input_t, } } + // Apply wrapping (modular) arithmetic for small integer output types to + // match CPU behavior. MPSGraph's castTensor uses saturated casting (clamps + // to type range), but CPU preserves only the least significant bits + // (C-standard wrapping for unsigned types, two's complement for signed). + // We compute: wrapped = value - floor(value / range) * range + // which is equivalent to Python's floor-modulo (value % range). MPSGraphTensor* outputTensor = castOutputTensor; - if (getMPSDataType(output_t) != [castOutputTensor dataType]) { - outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type()); + ScalarType outScalarType = output_t.scalar_type(); + if (getMPSDataType(output_t) != [castOutputTensor dataType] && + isIntegralType(outScalarType, /*includeBool=*/false)) { + double typeRange = 0; + double typeMin = 0; + if (outScalarType == kByte) { // uint8 + typeRange = 256.0; + typeMin = 0.0; + } else if (outScalarType == kChar) { // int8 + typeRange = 256.0; + typeMin = -128.0; + } else if (outScalarType == kShort) { // int16 + typeRange = 65536.0; + typeMin = -32768.0; + } + + if (typeRange > 0) { + MPSDataType floatType = [castOutputTensor dataType]; + MPSGraphTensor* rangeTensor = [mpsGraph constantWithScalar:typeRange dataType:floatType]; + + if (typeMin == 0) { + // Unsigned: floor(value / range) * range gives the multiple to subtract + MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:castOutputTensor + secondaryTensor:rangeTensor + name:nil]; + MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil]; + MPSGraphTensor* multTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor + secondaryTensor:rangeTensor + name:nil]; + castOutputTensor = [mpsGraph subtractionWithPrimaryTensor:castOutputTensor + secondaryTensor:multTensor + name:nil]; + } else { + // Signed: shift to unsigned range [0, range), modulo, shift back + MPSGraphTensor* shiftTensor = [mpsGraph constantWithScalar:-typeMin dataType:floatType]; + MPSGraphTensor* shifted = [mpsGraph additionWithPrimaryTensor:castOutputTensor + secondaryTensor:shiftTensor + name:nil]; + MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:shifted + secondaryTensor:rangeTensor + name:nil]; + MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil]; + MPSGraphTensor* multTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor + secondaryTensor:rangeTensor + name:nil]; + MPSGraphTensor* wrapped = [mpsGraph subtractionWithPrimaryTensor:shifted + secondaryTensor:multTensor + name:nil]; + castOutputTensor = [mpsGraph subtractionWithPrimaryTensor:wrapped + secondaryTensor:shiftTensor + name:nil]; + } + } + + outputTensor = castMPSTensor(mpsGraph, castOutputTensor, outScalarType); + } else if (getMPSDataType(output_t) != [castOutputTensor dataType]) { + outputTensor = castMPSTensor(mpsGraph, castOutputTensor, outScalarType); } newCachedGraph->inputTensor_ = inputTensor; diff --git a/test/test_mps.py b/test/test_mps.py index 54ae778cecbeb..140b71d806b57 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -5451,6 +5451,20 @@ def helper(n, c, h, w, dtype=torch.float32): self.assertEqual(x.numel(), 8) self.assertEqual(x.max().item(), 30.0) + # Regression test for https://github.com/pytorch/pytorch/issues/179415 + # MPS sum should use wrapping (not saturated) cast for integer overflow + def test_sum_integer_overflow_wrapping(self): + # uint8: [-1, 2, 1, 5, 3, 6] as uint8 = [255, 2, 1, 5, 3, 6], sum = 272, wrapped = 16 + example = [[-1, 2, 1], [5, 3, 6]] + cpu_x = torch.tensor(example, dtype=torch.uint8, device="cpu") + mps_x = torch.tensor(example, dtype=torch.uint8, device="mps") + self.assertEqual(mps_x.sum(dtype=torch.uint8), cpu_x.sum(dtype=torch.uint8)) + + # int8: [127, 1] should wrap to -128, not saturate at 127 + cpu_y = torch.tensor([127, 1], dtype=torch.int8, device="cpu") + mps_y = torch.tensor([127, 1], dtype=torch.int8, device="mps") + self.assertEqual(mps_y.sum(dtype=torch.int8), cpu_y.sum(dtype=torch.int8)) + # Test forward prod def test_prod(self): def helper(shape, dtype=torch.float32):