[MPS] Fix sum reduction saturated cast for integer overflow #37

Draft
wdvr wants to merge 1 commit into main from learning/issue-179415

@wdvr wdvr commented Apr 12, 2026

Fixes pytorch#179415

Problem

The MPS backend's sum reduction uses [MPSGraph castTensor:toType:] to convert the float32 reduction result back to the original integer dtype. This API performs a saturating cast (it clamps to the type's range), whereas the CPU backend wraps (it keeps only the least significant bits).

Example:

import torch

x = torch.tensor([[255, 2, 1], [5, 3, 6]], dtype=torch.uint8, device="mps")
x.sum(dtype=torch.uint8)
# MPS (before fix): 255 (saturated)
# CPU (expected):   16  (wrapping: 272 % 256)

Fix

Before the final castTensor call in reduction_out_mps(), apply floor-modulo wrapping for small integer output types (uint8, int8, int16):

  • Unsigned: value - floor(value / range) * range (equivalent to value % range, where range is 256 for uint8)
  • Signed: shift into the unsigned range, apply the same modulo, then shift back

This guarantees that the float32 value is already within the target type's range before the saturating cast, so the clamp becomes a no-op and the result matches a wrapping cast.
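The wrapping step can be sketched in plain Python to check the arithmetic. This is only an illustration of the formulas above, not the MPSGraph code in reduction_out_mps(); the helper names wrap_unsigned, wrap_signed and saturating_cast are hypothetical, and the signed shift is assumed to be half the range (128 for int8).

```python
import math


def wrap_unsigned(value: float, bits: int) -> float:
    """Floor-modulo wrap: value - floor(value / range) * range."""
    rng = float(1 << bits)
    return value - math.floor(value / rng) * rng


def wrap_signed(value: float, bits: int) -> float:
    """Shift into the unsigned range, wrap, then shift back.

    Assumption: the shift is half the range (e.g. 128 for int8).
    """
    half = float(1 << (bits - 1))
    return wrap_unsigned(value + half, bits) - half


def saturating_cast(value: float, lo: int, hi: int) -> int:
    """Clamp to [lo, hi]; a stand-in for the saturating cast described above."""
    return int(min(max(value, lo), hi))


uint8_sum = float(255 + 2 + 1 + 5 + 3 + 6)  # 272.0
int8_sum = float(127 + 1)                   # 128.0

# Without wrapping, the clamp saturates at the type's maximum.
print(saturating_cast(uint8_sum, 0, 255))                     # 255
print(saturating_cast(int8_sum, -128, 127))                   # 127

# With wrapping first, the clamp has nothing left to do.
print(saturating_cast(wrap_unsigned(uint8_sum, 8), 0, 255))   # 16
print(saturating_cast(wrap_signed(int8_sum, 8), -128, 127))   # -128
```

Because the wrapped value always lies inside the type's range, the clamp that follows never changes it, which is why the saturating cast then agrees with a wrapping one.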

Test plan

  • Added regression test test_sum_integer_overflow_wrapping in test/test_mps.py
  • Tests uint8 overflow: [255, 2, 1, 5, 3, 6].sum() == 16
  • Tests int8 overflow: [127, 1].sum() == -128
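The expected values in the test plan can be cross-checked without an MPS device. The sketch below is not the regression test itself: it uses the standard library's fixed-width ctypes integers, which do no overflow checking and keep only the low bits, as a stand-in for CPU wrapping. The variable names are illustrative.

```python
import ctypes

uint8_values = [255, 2, 1, 5, 3, 6]  # true sum is 272
int8_values = [127, 1]               # true sum is 128

# Constructing a fixed-width ctypes integer truncates to the low 8 bits.
uint8_expected = ctypes.c_uint8(sum(uint8_values)).value
int8_expected = ctypes.c_int8(sum(int8_values)).value

print(uint8_expected)  # 16
print(int8_expected)   # -128
```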

… cast for MPS sum reduction

MPS backend's sum reduction was using MPSGraph's castTensor which performs
saturated casting (clamps to type range), causing incorrect results for
integer overflow. For example, summing uint8 values [255, 2, 1, 5, 3, 6]
returned 255 (saturated) instead of 16 (wrapping).

This adds modular arithmetic (floor-modulo) before the final cast for
small integer types (uint8, int8, int16), matching CPU's wrapping behavior:
- Unsigned: value - floor(value / range) * range
- Signed: shift to unsigned, modulo, shift back

Fixes pytorch#179415
Development

Successfully merging this pull request may close these issues.

[MPS] sum uses saturated cast
