# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Graph Transformation Pass for Integer Floor-Division Replacement.

Rewrites integer (int64/int32) floor-division into a float64-domain floor to
work around a torch-2.12 AOTInductor/Inductor CUDA miscompile:

    floor_divide(a, b) -> floor(a.to(float64) / b.to(float64)).to(orig_int_dtype)
"""

import logging
from typing import Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger = logging.getLogger(__name__)

# NOTE: Integer dtypes we rewrite. float64 has a 53-bit mantissa, so the
# rewrite is exact only for |value| < 2**53. That covers the index ranges
# models use, but not arbitrarily large integers.
# NOTE(review): a zero divisor produces inf/nan in the float domain instead of
# taking the integer-division path -- confirm callers never divide by zero.
_INT_DTYPES = (torch.int64, torch.int32)

# Edge ops that perform a floor-rounded integer division.
_FLOOR_DIVIDE_OP = exir_ops.edge.aten.floor_divide.default
_DIV_MODE_OPS = (
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.div.Scalar_mode,
)


class ReplaceInt64FloorDivWithFloatPass(PassBase):
    """
    Pass to rewrite integer floor-division into a float64-domain floor.

    Matches ``floor_divide.default`` and the floor-mode ``div.Tensor_mode`` /
    ``div.Scalar_mode`` overloads on integer operands, and replaces each with
    ``floor(a.to(float64) / b.to(float64)).to(orig_int_dtype)`` built from edge
    dialect ops. Float floor-division and non-integer nodes are left untouched.

    This works around a torch-2.12 AOTInductor/Inductor CUDA miscompile of
    integer (int64) floor-division: fused/broadcast int64 floor_divide is
    mis-lowered (truncation instead of floor; cross-division term bleed under
    dynamic shapes).

    TODO(gasoonjia): remove this pass once the upstream issue is solved.
    Upstream issue: https://github.com/pytorch/pytorch/issues/186164
    """

    def __init__(self) -> None:
        super().__init__()
        # Number of nodes rewritten by the most recent ``call``.
        self._replacement_count = 0

    def call(self, graph_module: GraphModule) -> PassResult:
        """Rewrite every matching integer floor-division node in the graph.

        Args:
            graph_module: Module whose top-level graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same module; ``modified`` is True
            iff at least one node was rewritten.
        """
        self._replacement_count = 0

        # Iterate over a snapshot: ``_replace_node`` inserts new nodes and
        # erases the matched one, so the live node list changes under us.
        for node in list(graph_module.graph.nodes):
            if not self._should_replace_node(node):
                continue
            try:
                self._replace_node(graph_module, node)
            except Exception as e:
                # Best effort: ``_replace_node`` validates before mutating, so
                # a rejected node is left exactly as it was. Keep going.
                logger.warning(
                    "Failed to rewrite floor-div node %s: %s", node.name, e
                )
                continue
            self._replacement_count += 1

        modified = self._replacement_count > 0
        if modified:
            graph_module.recompile()

        logger.info(
            "Rewrote %d integer floor-division nodes into float64-domain floor",
            self._replacement_count,
        )

        return PassResult(graph_module, modified)

    @staticmethod
    def _node_dtype(node: Node) -> Optional[torch.dtype]:
        """Return the dtype recorded in ``node.meta["val"]``, or None.

        None is returned when the node has no ``val`` entry or when the entry
        is not a tensor.
        """
        val = node.meta.get("val", None)
        if isinstance(val, torch.Tensor):
            return val.dtype
        return None

    @staticmethod
    def _rounding_mode(node: Node) -> Optional[str]:
        """Return the node's ``rounding_mode`` argument, or None if absent."""
        if "rounding_mode" in node.kwargs:
            return node.kwargs["rounding_mode"]
        # Trailing positional arg: div(self, other, rounding_mode)
        if len(node.args) > 2:
            return node.args[2]
        return None

    @staticmethod
    def _cast(graph: torch.fx.Graph, value: Node, dtype: torch.dtype) -> Node:
        """Insert an edge ``_to_copy`` node converting ``value`` to ``dtype``."""
        return graph.call_function(
            exir_ops.edge.aten._to_copy.default,
            args=(value,),
            kwargs={"dtype": dtype},
        )

    def _should_replace_node(self, node: Node) -> bool:
        """Return True iff ``node`` is a floor-division with an integer result."""
        if node.op != "call_function":
            return False

        if node.target in _DIV_MODE_OPS:
            # div.*_mode only qualifies when it actually floors.
            if self._rounding_mode(node) != "floor":
                return False
        elif node.target != _FLOOR_DIVIDE_OP:
            return False

        # Only rewrite when the result is an integer tensor. A node without a
        # tensor meta["val"] yields None here and is skipped conservatively.
        return self._node_dtype(node) in _INT_DTYPES

    def _replace_node(self, graph_module: GraphModule, node: Node) -> None:
        """Replace ``node`` with its float64-domain floor equivalent.

        All validation happens before the graph is touched, so a raised
        exception never leaves half-built replacement nodes behind.

        Args:
            graph_module: Module owning the graph that contains ``node``.
            node: A node accepted by ``_should_replace_node``.

        Raises:
            ValueError: If the operands or the result dtype cannot be handled.
        """
        if len(node.args) < 2:
            raise ValueError(
                f"expected positional (dividend, divisor), got {len(node.args)} arg(s)"
            )
        dividend, divisor = node.args[0], node.args[1]
        if not isinstance(dividend, Node):
            raise ValueError(f"unsupported dividend {dividend!r}")
        if not isinstance(divisor, (Node, int, float)):
            raise ValueError(f"unsupported divisor {divisor!r}")
        orig_dtype = self._node_dtype(node)
        if orig_dtype is None:
            raise ValueError("node has no tensor meta['val'] to take the dtype from")

        graph = graph_module.graph
        with graph.inserting_before(node):
            dividend_f = self._cast(graph, dividend, torch.float64)
            if isinstance(divisor, Node):
                divisor_f = self._cast(graph, divisor, torch.float64)
                quotient = graph.call_function(
                    exir_ops.edge.aten.div.Tensor, args=(dividend_f, divisor_f)
                )
            else:
                # Python-scalar divisor: embedded directly as a float literal,
                # so no cast node is needed for it.
                quotient = graph.call_function(
                    exir_ops.edge.aten.div.Scalar, args=(dividend_f, float(divisor))
                )
            floored = graph.call_function(
                exir_ops.edge.aten.floor.default, args=(quotient,)
            )
            new_node = self._cast(graph, floored, orig_dtype)

        # The final cast has the same dtype as the node it replaces, so the
        # original metadata (including "val") is carried over unchanged.
        new_node.meta = node.meta.copy()

        node.replace_all_uses_with(new_node)
        graph.erase_node(node)
import unittest

import torch

# Import through the ``executorch`` package, as cuda_backend.py does, so the
# test exercises the same module object that production code loads.
from executorch.backends.cuda.passes.replace_int64_floordiv import (
    ReplaceInt64FloorDivWithFloatPass,
)
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export


_DIV_MODE_OPS = (
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.div.Scalar_mode,
)
_INT_DIV_OPS = (exir_ops.edge.aten.floor_divide.default,) + _DIV_MODE_OPS
_INT_DTYPES = (torch.int64, torch.int32)


def _rounding_mode(node):
    """Return a node's ``rounding_mode`` (keyword or trailing positional)."""
    if "rounding_mode" in node.kwargs:
        return node.kwargs["rounding_mode"]
    if len(node.args) > 2:
        return node.args[2]
    return None


def _count_int_floordiv(graph_module) -> int:
    """Count integer floor-division nodes remaining in the graph.

    Uses the same matching rule as the pass under test: ``floor_divide`` or a
    floor-mode ``div`` overload whose ``meta["val"]`` is an int64/int32 tensor.
    """
    count = 0
    for node in graph_module.graph.nodes:
        if node.op != "call_function" or node.target not in _INT_DIV_OPS:
            continue
        if node.target in _DIV_MODE_OPS and _rounding_mode(node) != "floor":
            continue
        val = node.meta.get("val", None)
        if isinstance(val, torch.Tensor) and val.dtype in _INT_DTYPES:
            count += 1
    return count


class TestReplaceInt64FloorDivWithFloatPass(unittest.TestCase):
    """Test the ReplaceInt64FloorDivWithFloatPass transformation pass."""

    def _edge_gm(self, module, inputs):
        """Export ``module`` to edge dialect; return (manager, graph module)."""
        ep = to_edge(export(module, inputs, strict=True))
        return ep, ep.exported_program().graph_module

    def _run_pass(self, gm, require_match=True):
        """Run the pass on ``gm`` and assert no integer floor-div remains.

        When ``require_match`` is True, also assert that the graph contained
        at least one integer floor-div beforehand, so the test cannot pass
        vacuously.
        """
        if require_match:
            self.assertGreater(_count_int_floordiv(gm), 0)
        result = ReplaceInt64FloorDivWithFloatPass()(gm)
        self.assertEqual(_count_int_floordiv(gm), 0)
        return result

    def test_tensor_tensor_floordiv_rewritten(self):
        """int64 a // b (tensor/tensor), including negative numerators."""

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a // b

        a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long)
        b = torch.tensor([2, 3, 4, 5, 3, 7], dtype=torch.long)
        ep, gm = self._edge_gm(M().eval(), (a, b))

        self._run_pass(gm)

        out = ep.exported_program().module()(a, b)
        self.assertEqual(out.dtype, torch.int64)
        self.assertTrue(torch.equal(out, a // b))

    def test_scalar_divisor_floordiv_rewritten(self):
        """int64 a // 3 (scalar divisor lifted to a 0-d tensor constant)."""

        class M(torch.nn.Module):
            def forward(self, a):
                return a // 3

        a = torch.tensor([-5, 7, -8, 9, -1, 0], dtype=torch.long)
        ep, gm = self._edge_gm(M().eval(), (a,))

        self._run_pass(gm)

        out = ep.exported_program().module()(a)
        self.assertTrue(torch.equal(out, a // 3))

    def test_div_rounding_mode_floor_rewritten(self):
        """torch.div(..., rounding_mode='floor') on int64 is rewritten."""

        class M(torch.nn.Module):
            def forward(self, a, b):
                return torch.div(a, b, rounding_mode="floor")

        a = torch.tensor([-5, 7, -8, 9], dtype=torch.long)
        b = torch.tensor([2, 3, 4, 5], dtype=torch.long)
        ep, gm = self._edge_gm(M().eval(), (a, b))

        self._run_pass(gm)

        out = ep.exported_program().module()(a, b)
        self.assertTrue(torch.equal(out, torch.div(a, b, rounding_mode="floor")))

    def test_int32_floordiv_rewritten(self):
        """int32 floor-division is also rewritten and stays int32."""

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a // b

        a = torch.tensor([-5, 7, -8, 9], dtype=torch.int32)
        b = torch.tensor([2, 3, 4, 5], dtype=torch.int32)
        ep, gm = self._edge_gm(M().eval(), (a, b))

        self._run_pass(gm)

        out = ep.exported_program().module()(a, b)
        self.assertEqual(out.dtype, torch.int32)
        self.assertTrue(torch.equal(out, a // b))

    def test_float_division_untouched(self):
        """Real float division must not be rewritten."""

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a / b

        a = torch.tensor([1.0, 2.0, 3.0])
        b = torch.tensor([2.0, 3.0, 4.0])
        _, gm = self._edge_gm(M().eval(), (a, b))

        before = [n.target for n in gm.graph.nodes if n.op == "call_function"]
        result = ReplaceInt64FloorDivWithFloatPass()(gm)
        self.assertFalse(result.modified)
        after = [n.target for n in gm.graph.nodes if n.op == "call_function"]
        self.assertEqual(before, after)

    def test_trunc_rounding_mode_untouched(self):
        """div with rounding_mode='trunc' must not be rewritten."""

        class M(torch.nn.Module):
            def forward(self, a, b):
                return torch.div(a, b, rounding_mode="trunc")

        a = torch.tensor([-5, 7, -8, 9], dtype=torch.long)
        b = torch.tensor([2, 3, 4, 5], dtype=torch.long)
        _, gm = self._edge_gm(M().eval(), (a, b))

        result = ReplaceInt64FloorDivWithFloatPass()(gm)
        self.assertFalse(result.modified)

    def test_floor_divide_default_branch(self):
        """Exercise the floor_divide.default match/rewrite branch.

        This pin lowers ``//`` to ``div.Tensor_mode``; floor_divide.default does
        not appear naturally, so we synthesize it by retargeting a node.
        """

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a // b

        a = torch.tensor([-5, 7, -8, 9], dtype=torch.long)
        b = torch.tensor([2, 3, 4, 5], dtype=torch.long)
        ep, gm = self._edge_gm(M().eval(), (a, b))

        # Retarget the div.Tensor_mode node to floor_divide.default.
        for node in list(gm.graph.nodes):
            if node.target == exir_ops.edge.aten.div.Tensor_mode:
                with gm.graph.inserting_before(node):
                    new = gm.graph.call_function(
                        exir_ops.edge.aten.floor_divide.default, args=node.args
                    )
                new.meta = node.meta.copy()
                node.replace_all_uses_with(new)
                gm.graph.erase_node(node)
        gm.recompile()

        self._run_pass(gm)

        out = ep.exported_program().module()(a, b)
        self.assertTrue(torch.equal(out, a // b))

    def test_ring_buffer_mask_analog(self):
        """gemma4_31b sliding-window analog: negative numerators + scalar divisor."""

        class M(torch.nn.Module):
            def forward(self, input_pos):
                buf_size = 8
                seq_len = input_pos.shape[0]
                total_written = input_pos[0] + seq_len
                j = torch.arange(buf_size, dtype=torch.long)
                wraps = (total_written - 1 - j) // buf_size
                return j + wraps * buf_size

        input_pos = torch.arange(3, dtype=torch.long)
        ep, gm = self._edge_gm(M().eval(), (input_pos,))

        # As in the original test, no precondition on the pre-pass graph here.
        self._run_pass(gm, require_match=False)

        out = ep.exported_program().module()(input_pos)
        ref = M()(input_pos)
        self.assertTrue(torch.equal(out, ref))


if __name__ == "__main__":
    unittest.main()