From f58ff16aed40856a28c54cdce47bf0ab6766be9d Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 3 Jun 2026 16:06:05 -0700 Subject: [PATCH 1/4] init --- examples/models/gemma4_31b/model.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index 657c79e0c4c..d051ff7e497 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -459,7 +459,19 @@ def _build_masks( seq_len = input_pos.shape[0] total_written = input_pos[0] + seq_len j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device) - ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size + # NOTE(torch-2.12 AOTI/Inductor CUDA): int64 floor-division is + # mis-lowered for multi-token prefill (dynamic T>1). The negative + # numerators here (slot j > last-written position) truncate toward zero + # instead of flooring, and the fused index-math codegen scrambles the + # result further -- producing a wrong sliding-window mask, hence wrong + # attention and a wrong first prefill token. Decode (T=1) is unaffected + # because that lone, un-broadcast floor-div still lowers correctly (same + # "lone div OK / fused-broadcast div wrong" failure mode as the vision + # pooler bug). Values here are < 2**24, so doing the floor-division in + # float32 is bit-exact vs the int64 path and lowers correctly. Mirrors + # the Gemma4VisionPooler._avg_pool_by_positions fix in vision_tower.py. 
+ wraps = torch.floor((total_written - 1 - j).float() / buf_size).long() + ring_pos = j + wraps * buf_size delta = q_pos - ring_pos.unsqueeze(0) sliding = (ring_pos >= 0) & (delta >= 0) & (delta < self.config.sliding_window) sliding_mask = sliding.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, buf_size) From eb9978c871ca9745debfcf80d155c4c9e8a2b887 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 3 Jun 2026 17:43:12 -0700 Subject: [PATCH 2/4] replace floor_div with float div --- backends/cuda/cuda_backend.py | 5 +- .../cuda/passes/replace_int64_floordiv.py | 151 ++++++++++++ .../tests/test_replace_int64_floordiv.py | 216 ++++++++++++++++++ 3 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 backends/cuda/passes/replace_int64_floordiv.py create mode 100644 backends/cuda/passes/tests/test_replace_int64_floordiv.py diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index d732a12a8fe..2914e36e7ff 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -19,6 +19,9 @@ from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import ( MoveCondPredicateToCpuPass, ) +from executorch.backends.cuda.passes.replace_int64_floordiv import ( + ReplaceInt64FloorDivWithFloatPass, +) from executorch.backends.cuda.triton.replacement_pass import ( ReplaceEdgeOpWithTritonOpPass, ) @@ -257,7 +260,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] f"Expected 'ON' or 'OFF'." 
) triton_kernel_mode = mode - passes = [MoveCondPredicateToCpuPass()] + passes = [MoveCondPredicateToCpuPass(), ReplaceInt64FloorDivWithFloatPass()] if triton_kernel_mode == "ON": passes.append(ReplaceEdgeOpWithTritonOpPass()) return passes diff --git a/backends/cuda/passes/replace_int64_floordiv.py b/backends/cuda/passes/replace_int64_floordiv.py new file mode 100644 index 00000000000..7d8b14b611f --- /dev/null +++ b/backends/cuda/passes/replace_int64_floordiv.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph Transformation Pass for Integer Floor-Division Replacement. + +Rewrites integer (int64/int32) floor-division into a float64-domain floor to +work around a torch-2.12 AOTInductor/Inductor CUDA miscompile: + + floor_divide(a, b) -> floor(a.to(float64) / b.to(float64)).to(orig_int_dtype) +""" + +import logging + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +logger = logging.getLogger(__name__) + +# Integer dtypes we rewrite. float64 (53-bit mantissa) is exact for +# |value| < 2**53, which covers these models' index ranges. +_INT_DTYPES = (torch.int64, torch.int32) + +# Edge ops that perform a floor-rounded integer division. +_FLOOR_DIVIDE_OP = exir_ops.edge.aten.floor_divide.default +_DIV_MODE_OPS = ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.div.Scalar_mode, +) + + +class ReplaceInt64FloorDivWithFloatPass(PassBase): + # Work around a torch-2.12 AOTInductor/Inductor CUDA miscompile of integer + # (int64) floor-division: fused/broadcast int64 floor_divide is mis-lowered + # (truncation instead of floor; cross-division term bleed under dynamic shapes). + # Rewriting into a float64-domain floor lowers correctly. 
Upstream issue: TODO(link). + """ + Pass to rewrite integer floor-division into a float64-domain floor. + + Matches ``floor_divide.default`` and the floor-mode ``div.Tensor_mode`` / + ``div.Scalar_mode`` overloads on integer operands, and replaces each with + ``floor(a.to(float64) / b.to(float64)).to(orig_int_dtype)`` built from edge + dialect ops. Float floor-division and non-integer nodes are left untouched. + """ + + def __init__(self): + super().__init__() + self._replacement_count = 0 + + def call(self, graph_module: GraphModule) -> PassResult: + self._replacement_count = 0 + modified = False + + for node in graph_module.graph.nodes: + if not self._should_replace_node(node): + continue + try: + self._replace_node(graph_module, node) + modified = True + self._replacement_count += 1 + except Exception as e: + logger.warning(f"Failed to rewrite floor-div node {node.name}: {e}") + # Continue with other nodes even if one fails. + + if modified: + graph_module.recompile() + + logger.info( + f"Rewrote {self._replacement_count} integer floor-division nodes " + f"into float64-domain floor" + ) + + return PassResult(graph_module, modified) + + @staticmethod + def _node_dtype(node: Node): + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor): + return val.dtype + return None + + @staticmethod + def _rounding_mode(node: Node): + if "rounding_mode" in node.kwargs: + return node.kwargs["rounding_mode"] + # Trailing positional arg: div(self, other, rounding_mode) + if len(node.args) > 2: + return node.args[2] + return None + + def _should_replace_node(self, node: Node) -> bool: + if node.op != "call_function": + return False + + if node.target == _FLOOR_DIVIDE_OP: + pass + elif node.target in _DIV_MODE_OPS: + if self._rounding_mode(node) != "floor": + return False + else: + return False + + # Only rewrite when the result is an integer tensor. Guard meta access: + # a node may lack meta["val"]; skip conservatively if so. 
+ out_dtype = self._node_dtype(node) + if out_dtype not in _INT_DTYPES: + return False + + return True + + def _replace_node(self, graph_module: GraphModule, node: Node) -> None: + orig_dtype = self._node_dtype(node) + a = node.args[0] + b = node.args[1] + + graph = graph_module.graph + with graph.inserting_before(node): + a_f = graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(a,), + kwargs={"dtype": torch.float64}, + ) + if isinstance(b, Node): + b_f = graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(b,), + kwargs={"dtype": torch.float64}, + ) + q = graph.call_function(exir_ops.edge.aten.div.Tensor, args=(a_f, b_f)) + else: + # Python-scalar divisor: stays bit-exact, no cast needed for b. + q = graph.call_function( + exir_ops.edge.aten.div.Scalar, args=(a_f, float(b)) + ) + fl = graph.call_function(exir_ops.edge.aten.floor.default, args=(q,)) + new_node = graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(fl,), + kwargs={"dtype": orig_dtype}, + ) + + new_node.meta = node.meta.copy() + + node.replace_all_uses_with(new_node) + graph.erase_node(node) diff --git a/backends/cuda/passes/tests/test_replace_int64_floordiv.py b/backends/cuda/passes/tests/test_replace_int64_floordiv.py new file mode 100644 index 00000000000..9632611890b --- /dev/null +++ b/backends/cuda/passes/tests/test_replace_int64_floordiv.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from backends.cuda.passes.replace_int64_floordiv import ( + ReplaceInt64FloorDivWithFloatPass, +) +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import export + + +_INT_DIV_OPS = ( + exir_ops.edge.aten.floor_divide.default, + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.div.Scalar_mode, +) + + +def _count_int_floordiv(graph_module) -> int: + """Count integer floor-division nodes remaining in the graph.""" + n = 0 + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in _INT_DIV_OPS: + continue + if node.target in ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.div.Scalar_mode, + ): + rmode = node.kwargs.get("rounding_mode", None) + if rmode != "floor": + continue + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and val.dtype in ( + torch.int64, + torch.int32, + ): + n += 1 + return n + + +class TestReplaceInt64FloorDivWithFloatPass(unittest.TestCase): + """Test the ReplaceInt64FloorDivWithFloatPass transformation pass.""" + + def _edge_gm(self, module, inputs): + ep = to_edge(export(module, inputs, strict=True)) + return ep, ep.exported_program().graph_module + + def test_tensor_tensor_floordiv_rewritten(self): + """int64 a // b (tensor/tensor), including negative numerators.""" + + class M(torch.nn.Module): + def forward(self, a, b): + return a // b + + a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long) + b = torch.tensor([2, 3, 4, 5, 3, 7], dtype=torch.long) + ep, gm = self._edge_gm(M().eval(), (a, b)) + + self.assertGreater(_count_int_floordiv(gm), 0) + ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertEqual(_count_int_floordiv(gm), 0) + + out = ep.exported_program().module()(a, b) + self.assertEqual(out.dtype, torch.int64) + self.assertTrue(torch.equal(out, a // b)) + + def test_scalar_divisor_floordiv_rewritten(self): + """int64 a // 3 (scalar divisor lifted to a 0-d 
tensor constant).""" + + class M(torch.nn.Module): + def forward(self, a): + return a // 3 + + a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long) + ep, gm = self._edge_gm(M().eval(), (a,)) + + self.assertGreater(_count_int_floordiv(gm), 0) + ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertEqual(_count_int_floordiv(gm), 0) + + out = ep.exported_program().module()(a) + self.assertTrue(torch.equal(out, a // 3)) + + def test_div_rounding_mode_floor_rewritten(self): + """torch.div(..., rounding_mode='floor') on int64 is rewritten.""" + + class M(torch.nn.Module): + def forward(self, a, b): + return torch.div(a, b, rounding_mode="floor") + + a = torch.tensor([-5, 7, -8, 9], dtype=torch.long) + b = torch.tensor([2, 3, 4, 5], dtype=torch.long) + ep, gm = self._edge_gm(M().eval(), (a, b)) + + self.assertGreater(_count_int_floordiv(gm), 0) + ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertEqual(_count_int_floordiv(gm), 0) + + out = ep.exported_program().module()(a, b) + self.assertTrue(torch.equal(out, torch.div(a, b, rounding_mode="floor"))) + + def test_int32_floordiv_rewritten(self): + """int32 floor-division is also rewritten and stays int32.""" + + class M(torch.nn.Module): + def forward(self, a, b): + return a // b + + a = torch.tensor([-5, 7, -8, 9], dtype=torch.int32) + b = torch.tensor([2, 3, 4, 5], dtype=torch.int32) + ep, gm = self._edge_gm(M().eval(), (a, b)) + + self.assertGreater(_count_int_floordiv(gm), 0) + ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertEqual(_count_int_floordiv(gm), 0) + + out = ep.exported_program().module()(a, b) + self.assertEqual(out.dtype, torch.int32) + self.assertTrue(torch.equal(out, a // b)) + + def test_float_division_untouched(self): + """Real float division must not be rewritten.""" + + class M(torch.nn.Module): + def forward(self, a, b): + return a / b + + a = torch.tensor([1.0, 2.0, 3.0]) + b = torch.tensor([2.0, 3.0, 4.0]) + ep, gm = self._edge_gm(M().eval(), (a, b)) + + before = [n.target for n in 
gm.graph.nodes if n.op == "call_function"] + result = ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertFalse(result.modified) + after = [n.target for n in gm.graph.nodes if n.op == "call_function"] + self.assertEqual(before, after) + + def test_trunc_rounding_mode_untouched(self): + """div with rounding_mode='trunc' must not be rewritten.""" + + class M(torch.nn.Module): + def forward(self, a, b): + return torch.div(a, b, rounding_mode="trunc") + + a = torch.tensor([-5, 7, -8, 9], dtype=torch.long) + b = torch.tensor([2, 3, 4, 5], dtype=torch.long) + ep, gm = self._edge_gm(M().eval(), (a, b)) + + result = ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertFalse(result.modified) + + def test_floor_divide_default_branch(self): + """Exercise the floor_divide.default match/rewrite branch. + + This pin lowers ``//`` to ``div.Tensor_mode``; floor_divide.default does + not appear naturally, so we synthesize it by retargeting a node. + """ + + class M(torch.nn.Module): + def forward(self, a, b): + return a // b + + a = torch.tensor([-5, 7, -8, 9], dtype=torch.long) + b = torch.tensor([2, 3, 4, 5], dtype=torch.long) + ep, gm = self._edge_gm(M().eval(), (a, b)) + + # Retarget the div.Tensor_mode node to floor_divide.default. 
+ for node in list(gm.graph.nodes): + if node.target == exir_ops.edge.aten.div.Tensor_mode: + with gm.graph.inserting_before(node): + new = gm.graph.call_function( + exir_ops.edge.aten.floor_divide.default, args=node.args + ) + new.meta = node.meta.copy() + node.replace_all_uses_with(new) + gm.graph.erase_node(node) + gm.recompile() + + self.assertGreater(_count_int_floordiv(gm), 0) + ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertEqual(_count_int_floordiv(gm), 0) + + out = ep.exported_program().module()(a, b) + self.assertTrue(torch.equal(out, a // b)) + + def test_ring_buffer_mask_analog(self): + """gemma4_31b sliding-window analog: negative numerators + scalar divisor.""" + + class M(torch.nn.Module): + def forward(self, input_pos): + buf_size = 8 + seq_len = input_pos.shape[0] + total_written = input_pos[0] + seq_len + j = torch.arange(buf_size, dtype=torch.long) + wraps = (total_written - 1 - j) // buf_size + return j + wraps * buf_size + + input_pos = torch.arange(3, dtype=torch.long) + ep, gm = self._edge_gm(M().eval(), (input_pos,)) + + ReplaceInt64FloorDivWithFloatPass()(gm) + self.assertEqual(_count_int_floordiv(gm), 0) + + out = ep.exported_program().module()(input_pos) + ref = M()(input_pos) + self.assertTrue(torch.equal(out, ref)) + + +if __name__ == "__main__": + unittest.main() From 84a543289cad9ee27d0999175372774aae73280c Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 3 Jun 2026 17:57:15 -0700 Subject: [PATCH 3/4] revert model definition changes --- examples/models/gemma4_31b/model.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index d051ff7e497..657c79e0c4c 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -459,19 +459,7 @@ def _build_masks( seq_len = input_pos.shape[0] total_written = input_pos[0] + seq_len j = torch.arange(buf_size, dtype=torch.long, device=input_pos.device) - # 
NOTE(torch-2.12 AOTI/Inductor CUDA): int64 floor-division is - # mis-lowered for multi-token prefill (dynamic T>1). The negative - # numerators here (slot j > last-written position) truncate toward zero - # instead of flooring, and the fused index-math codegen scrambles the - # result further -- producing a wrong sliding-window mask, hence wrong - # attention and a wrong first prefill token. Decode (T=1) is unaffected - # because that lone, un-broadcast floor-div still lowers correctly (same - # "lone div OK / fused-broadcast div wrong" failure mode as the vision - # pooler bug). Values here are < 2**24, so doing the floor-division in - # float32 is bit-exact vs the int64 path and lowers correctly. Mirrors - # the Gemma4VisionPooler._avg_pool_by_positions fix in vision_tower.py. - wraps = torch.floor((total_written - 1 - j).float() / buf_size).long() - ring_pos = j + wraps * buf_size + ring_pos = j + ((total_written - 1 - j) // buf_size) * buf_size delta = q_pos - ring_pos.unsqueeze(0) sliding = (ring_pos >= 0) & (delta >= 0) & (delta < self.config.sliding_window) sliding_mask = sliding.unsqueeze(0).unsqueeze(0) # (1, 1, T_q, buf_size) From fc846cff07f21b89f35952f677f6a5f4ce847797 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 3 Jun 2026 22:35:54 -0700 Subject: [PATCH 4/4] update doc and increase test timeout --- .github/workflows/cuda.yml | 2 +- backends/cuda/passes/replace_int64_floordiv.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index f19b937994f..2f8c1c12643 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -267,7 +267,7 @@ jobs: name: "whisper-large-v3-turbo" quant: "non-quantized" with: - timeout: 90 + timeout: 150 secrets-env: EXECUTORCH_HF_TOKEN runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda diff 
--git a/backends/cuda/passes/replace_int64_floordiv.py b/backends/cuda/passes/replace_int64_floordiv.py index 7d8b14b611f..85cd201416e 100644 --- a/backends/cuda/passes/replace_int64_floordiv.py +++ b/backends/cuda/passes/replace_int64_floordiv.py @@ -17,14 +17,14 @@ import torch from executorch.exir.dialects._ops import ops as exir_ops - from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult logger = logging.getLogger(__name__) -# Integer dtypes we rewrite. float64 (53-bit mantissa) is exact for -# |value| < 2**53, which covers these models' index ranges. +# NOTE: Integer dtypes we rewrite. float64 (53-bit mantissa) is exact for +# |value| < 2**53, which covers models' index ranges but is not enough +# for extremely large numbers. _INT_DTYPES = (torch.int64, torch.int32) # Edge ops that perform a floor-rounded integer division. @@ -39,7 +39,8 @@ class ReplaceInt64FloorDivWithFloatPass(PassBase): # Work around a torch-2.12 AOTInductor/Inductor CUDA miscompile of integer # (int64) floor-division: fused/broadcast int64 floor_divide is mis-lowered # (truncation instead of floor; cross-division term bleed under dynamic shapes). - # Rewriting into a float64-domain floor lowers correctly. Upstream issue: TODO(link). + # TODO(gasoonjia): remove this pass once the upstream issue is solved. + # Upstream issue: https://github.com/pytorch/pytorch/issues/186164 """ Pass to rewrite integer floor-division into a float64-domain floor.