diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 2ec070c84be7e..ce2ee1df8d7aa 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -291,9 +291,69 @@ static void reduction_out_mps(const Tensor& input_t,
         }
       }
 
+      // Narrowing the reduced value to uint8/int8/int16 must wrap (modular
+      // arithmetic) to match the CPU result, which keeps only the least
+      // significant bits. MPSGraph's cast saturates instead (clamps to the
+      // destination range), so the value is first folded into that range with
+      // a floor-modulo:
+      //   wrapped = (v - min) - floor((v - min) / range) * range + min
+      // where range is the number of representable values and min is the
+      // smallest one (0 for unsigned types).
       MPSGraphTensor* outputTensor = castOutputTensor;
+      const ScalarType outScalarType = output_t.scalar_type();
       if (getMPSDataType(output_t) != [castOutputTensor dataType]) {
-        outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
+        // typeRange stays 0 for destination types that need no wrapping here.
+        double typeRange = 0.0;
+        double typeMin = 0.0;
+        if (outScalarType == kByte) { // uint8
+          typeRange = 256.0;
+        } else if (outScalarType == kChar) { // int8
+          typeRange = 256.0;
+          typeMin = -128.0;
+        } else if (outScalarType == kShort) { // int16
+          typeRange = 65536.0;
+          typeMin = -32768.0;
+        }
+
+        MPSGraphTensor* wrappedTensor = castOutputTensor;
+        if (typeRange > 0) {
+          // Always do the modulo in Float32. The reduced tensor's type is
+          // whatever the earlier part of this function chose: in Float16 the
+          // constant 65536 overflows to +inf, and for an integer type division
+          // is not guaranteed to round toward negative infinity. Float32 is
+          // exact for magnitudes below 2^24.
+          if ([wrappedTensor dataType] != MPSDataTypeFloat32) {
+            wrappedTensor = castMPSTensor(mpsGraph, wrappedTensor, kFloat);
+          }
+          MPSGraphTensor* rangeTensor = [mpsGraph constantWithScalar:typeRange dataType:MPSDataTypeFloat32];
+
+          // Signed types: shift into [0, range) first, shift back afterwards.
+          MPSGraphTensor* shiftTensor = nil;
+          if (typeMin != 0) {
+            shiftTensor = [mpsGraph constantWithScalar:-typeMin dataType:MPSDataTypeFloat32];
+            wrappedTensor = [mpsGraph additionWithPrimaryTensor:wrappedTensor
+                                                secondaryTensor:shiftTensor
+                                                           name:nil];
+          }
+
+          // floor(value / range) * range is the multiple of range to remove.
+          MPSGraphTensor* quotientTensor = [mpsGraph divisionWithPrimaryTensor:wrappedTensor
+                                                               secondaryTensor:rangeTensor
+                                                                          name:nil];
+          MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:quotientTensor name:nil];
+          MPSGraphTensor* multipleTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
+                                                                     secondaryTensor:rangeTensor
+                                                                                name:nil];
+          wrappedTensor = [mpsGraph subtractionWithPrimaryTensor:wrappedTensor
+                                                 secondaryTensor:multipleTensor
+                                                            name:nil];
+          if (shiftTensor) {
+            wrappedTensor = [mpsGraph subtractionWithPrimaryTensor:wrappedTensor
+                                                   secondaryTensor:shiftTensor
+                                                              name:nil];
+          }
+        }
+
+        outputTensor = castMPSTensor(mpsGraph, wrappedTensor, outScalarType);
       }
 
       newCachedGraph->inputTensor_ = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index 54ae778cecbeb..140b71d806b57 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5451,6 +5451,29 @@ def helper(n, c, h, w, dtype=torch.float32):
         self.assertEqual(x.numel(), 8)
         self.assertEqual(x.max().item(), 30.0)
 
+    # Regression test for https://github.com/pytorch/pytorch/issues/179415
+    # Narrowing integer reductions on MPS must wrap like the CPU, not saturate
+    def test_sum_integer_overflow_wrapping(self):
+        cases = [
+            # uint8: -1 is stored as 255; 255 + 2 + 1 + 5 + 3 + 6 = 272 -> 272 % 256 = 16
+            ([[-1, 2, 1], [5, 3, 6]], torch.uint8),
+            # int8: 127 + 1 wraps to -128 instead of saturating at 127
+            ([127, 1], torch.int8),
+            # int8: -128 - 1 wraps to 127 instead of saturating at -128
+            ([-128, -1], torch.int8),
+            # int16: 32767 + 1 wraps to -32768 instead of saturating at 32767
+            ([32767, 1], torch.int16),
+        ]
+        for values, dtype in cases:
+            cpu_x = torch.tensor(values, dtype=dtype, device="cpu")
+            mps_x = torch.tensor(values, dtype=dtype, device="mps")
+            self.assertEqual(mps_x.sum(dtype=dtype), cpu_x.sum(dtype=dtype))
+
+        # float16 input narrowed to int16: nothing overflows, so the value must survive as-is
+        cpu_h = torch.tensor([1.0, 2.0], dtype=torch.float16, device="cpu")
+        mps_h = torch.tensor([1.0, 2.0], dtype=torch.float16, device="mps")
+        self.assertEqual(mps_h.sum(dtype=torch.int16), cpu_h.sum(dtype=torch.int16))
+
     # Test forward prod
     def test_prod(self):
         def helper(shape, dtype=torch.float32):