Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,70 @@ static void reduction_out_mps(const Tensor& input_t,
}
}

// Apply wrapping (modular) arithmetic for small integer output types to
// match CPU behavior. MPSGraph's castTensor uses saturated casting (clamps
// to type range), but CPU preserves only the least significant bits
// (C-standard wrapping for unsigned types, two's complement for signed).
// We compute: wrapped = value - floor(value / range) * range
// which is equivalent to Python's floor-modulo (value % range).
MPSGraphTensor* outputTensor = castOutputTensor;
if (getMPSDataType(output_t) != [castOutputTensor dataType]) {
outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
ScalarType outScalarType = output_t.scalar_type();
if (getMPSDataType(output_t) != [castOutputTensor dataType] &&
isIntegralType(outScalarType, /*includeBool=*/false)) {
double typeRange = 0;
double typeMin = 0;
if (outScalarType == kByte) { // uint8
typeRange = 256.0;
typeMin = 0.0;
} else if (outScalarType == kChar) { // int8
typeRange = 256.0;
typeMin = -128.0;
} else if (outScalarType == kShort) { // int16
typeRange = 65536.0;
typeMin = -32768.0;
}

if (typeRange > 0) {
MPSDataType floatType = [castOutputTensor dataType];
MPSGraphTensor* rangeTensor = [mpsGraph constantWithScalar:typeRange dataType:floatType];

if (typeMin == 0) {
// Unsigned: floor(value / range) * range gives the multiple to subtract
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:castOutputTensor
secondaryTensor:rangeTensor
name:nil];
MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
MPSGraphTensor* multTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
secondaryTensor:rangeTensor
name:nil];
castOutputTensor = [mpsGraph subtractionWithPrimaryTensor:castOutputTensor
secondaryTensor:multTensor
name:nil];
} else {
// Signed: shift to unsigned range [0, range), modulo, shift back
MPSGraphTensor* shiftTensor = [mpsGraph constantWithScalar:-typeMin dataType:floatType];
MPSGraphTensor* shifted = [mpsGraph additionWithPrimaryTensor:castOutputTensor
secondaryTensor:shiftTensor
name:nil];
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:shifted
secondaryTensor:rangeTensor
name:nil];
MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
MPSGraphTensor* multTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
secondaryTensor:rangeTensor
name:nil];
MPSGraphTensor* wrapped = [mpsGraph subtractionWithPrimaryTensor:shifted
secondaryTensor:multTensor
name:nil];
castOutputTensor = [mpsGraph subtractionWithPrimaryTensor:wrapped
secondaryTensor:shiftTensor
name:nil];
}
}

outputTensor = castMPSTensor(mpsGraph, castOutputTensor, outScalarType);
} else if (getMPSDataType(output_t) != [castOutputTensor dataType]) {
outputTensor = castMPSTensor(mpsGraph, castOutputTensor, outScalarType);
}

newCachedGraph->inputTensor_ = inputTensor;
Expand Down
14 changes: 14 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5451,6 +5451,20 @@ def helper(n, c, h, w, dtype=torch.float32):
self.assertEqual(x.numel(), 8)
self.assertEqual(x.max().item(), 30.0)

# Regression test for https://github.com/pytorch/pytorch/issues/179415
# MPS sum should wrap on integer overflow (wrapping cast, not a saturating one)
def test_sum_integer_overflow_wrapping(self):
    """Check that integer ``sum`` on MPS matches CPU when the result overflows.

    Each case sums a small tensor whose true total lies outside the range of
    the requested output dtype and compares the MPS result against the CPU
    result for the same input. The cases cover uint8, int8 and int16, and
    for the signed types both directions (past the maximum and past the
    minimum), so that a result clamped to the type range is caught on
    either side.
    """
    cases = [
        # uint8: -1 is stored as 255, so the values are [255, 2, 1, 5, 3, 6];
        # their sum is 272, which should wrap to 16.
        (torch.uint8, [[-1, 2, 1], [5, 3, 6]]),
        # int8 past the maximum: 127 + 1 should wrap to -128, not stay at 127.
        (torch.int8, [127, 1]),
        # int8 past the minimum: -128 + -1 should wrap to 127, not stay at -128.
        (torch.int8, [-128, -1]),
        # int16 past the maximum: 32767 + 1 should wrap to -32768.
        (torch.int16, [32767, 1]),
        # int16 past the minimum: -32768 + -1 should wrap to 32767.
        (torch.int16, [-32768, -1]),
    ]
    for dtype, values in cases:
        # subTest keeps the remaining cases running (and names the failing
        # dtype/values) if one combination regresses.
        with self.subTest(dtype=dtype, values=values):
            cpu_x = torch.tensor(values, dtype=dtype, device="cpu")
            mps_x = torch.tensor(values, dtype=dtype, device="mps")
            self.assertEqual(mps_x.sum(dtype=dtype), cpu_x.sum(dtype=dtype))

# Test forward prod
def test_prod(self):
def helper(shape, dtype=torch.float32):
Expand Down