diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e23bf1b4..ad1b2ea2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: 'build|stubs' +exclude: 'build|stubs|autoparallel/tools/overlap_simulator' default_language_version: python: python3 diff --git a/autoparallel/graph_passes/autobucketing_inductor/bucket_func.py b/autoparallel/graph_passes/autobucketing_inductor/bucket_func.py index d6f47764..6e45c729 100644 --- a/autoparallel/graph_passes/autobucketing_inductor/bucket_func.py +++ b/autoparallel/graph_passes/autobucketing_inductor/bucket_func.py @@ -9,7 +9,6 @@ import torch from torch._inductor import ir, scheduler -from torch._inductor.comms import get_op_idx from torch._inductor.dependencies import StarDep, WeakDep from torch._inductor.utils import is_collective, is_wait from torch._inductor.virtualized import V @@ -26,6 +25,7 @@ bucket_all_gathers, bucket_reduce_scatters, check_ir_node_bucketable, + get_op_idx, ) diff --git a/autoparallel/graph_passes/autobucketing_inductor/bucket_utils.py b/autoparallel/graph_passes/autobucketing_inductor/bucket_utils.py index e11af615..1064cf19 100644 --- a/autoparallel/graph_passes/autobucketing_inductor/bucket_utils.py +++ b/autoparallel/graph_passes/autobucketing_inductor/bucket_utils.py @@ -27,6 +27,21 @@ def get_data_size(size): return reduce(lambda x, y: x * y, size) +def get_op_idx(snode: "scheduler.BaseSchedulerNode") -> int: + if isinstance( + snode, + ( + scheduler.FusedSchedulerNode, + scheduler.GroupedSchedulerNode, + ), + ): + raise TypeError(f"Expected an unfused scheduler node, got {type(snode)}") + op_name = snode.get_name() + if not op_name.startswith("op"): + raise KeyError(f"Expected op name to start with 'op', got {op_name}") + return int(op_name[2:]) + + def _find_recursive_deps_of_snode( snode: "scheduler.BaseSchedulerNode", collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"], diff --git a/tests/test_fsdp_all_gather_tagging.py b/tests/test_fsdp_all_gather_tagging.py index 8ed0b6d2..ab971c76 100644 --- a/tests/test_fsdp_all_gather_tagging.py +++ b/tests/test_fsdp_all_gather_tagging.py @@ -9,6 +9,7 @@ without running the full AutoParallel pipeline. """ +import pytest import torch import torch.fx from torch.utils.checkpoint import CheckpointPolicy @@ -19,6 +20,7 @@ force_save_fsdp_all_gather, mark_fsdp_all_gather_recomputation, ) +from autoparallel.graph_passes.autobucketing_inductor import bucket_utils # --------------------------------------------------------------------------- # Helpers for building minimal FSDP-like graphs @@ -78,6 +80,22 @@ def _build_simple_fsdp_graph() -> torch.fx.Graph: return graph +class _SchedulerNode: + def __init__(self, name): + self.name = name + + def get_name(self): + return self.name + + +class _FusedSchedulerNode(_SchedulerNode): + pass + + +class _GroupedSchedulerNode(_SchedulerNode): + pass + + # --------------------------------------------------------------------------- # Tests for force_recompute_fsdp_all_gather # --------------------------------------------------------------------------- @@ -224,3 +242,30 @@ def test_no_tags_without_fsdp_pattern(): for node in graph.nodes: assert "recompute" not in node.meta assert "ac_graph_id" not in node.meta + + +# --------------------------------------------------------------------------- +# Tests for autobucketing scheduler helpers +# --------------------------------------------------------------------------- + + +def test_get_op_idx(): + assert bucket_utils.get_op_idx(_SchedulerNode("op142")) == 142 + + +def test_get_op_idx_rejects_non_op_name(): + with pytest.raises(KeyError, match="Expected op name"): + bucket_utils.get_op_idx(_SchedulerNode("buf142")) + + +def test_get_op_idx_rejects_fused_and_grouped_snodes(monkeypatch): + monkeypatch.setattr( + bucket_utils.scheduler, "FusedSchedulerNode", _FusedSchedulerNode + ) + monkeypatch.setattr( + bucket_utils.scheduler, "GroupedSchedulerNode", _GroupedSchedulerNode + ) + + for node_cls in (_FusedSchedulerNode, _GroupedSchedulerNode): + with pytest.raises(TypeError, match="Expected an unfused scheduler node"): + bucket_utils.get_op_idx(node_cls("op142")) diff --git a/tests/test_nccl_cost_model.py b/tests/test_nccl_cost_model.py index 9399c7f1..54ac6d54 100644 --- a/tests/test_nccl_cost_model.py +++ b/tests/test_nccl_cost_model.py @@ -653,8 +653,8 @@ def test_cost_increases_with_size(self, func): times = [nccl_collective_time(func, s, topo, config) for s in sizes] for i in range(1, len(times)): assert times[i] >= times[i - 1], ( - f"Cost decreased from {sizes[i-1]} to {sizes[i]}: " - f"{times[i-1]} > {times[i]}" + f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: " + f"{times[i - 1]} > {times[i]}" ) @pytest.mark.parametrize( @@ -810,8 +810,8 @@ def test_monotonicity(self): times = [nccl_all_to_all_cost(s, topo, config) for s in sizes] for i in range(1, len(times)): assert times[i] >= times[i - 1], ( - f"AllToAll cost decreased from {sizes[i-1]} to {sizes[i]}: " - f"{times[i-1]} > {times[i]}" + f"AllToAll cost decreased from {sizes[i - 1]} to {sizes[i]}: " + f"{times[i - 1]} > {times[i]}" ) def test_more_expensive_than_allgather(self): @@ -1210,8 +1210,8 @@ def test_monotonicity(self): times = [nccl_allreduce_cost(s, topo, config) for s in sizes] for i in range(1, len(times)): assert times[i] >= times[i - 1], ( - f"Cost decreased from {sizes[i-1]} to {sizes[i]}: " - f"{times[i-1]} > {times[i]}" + f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: " + f"{times[i - 1]} > {times[i]}" ) def test_ramp_tables_shape(self): @@ -1246,8 +1246,8 @@ def test_monotonicity_16_nodes(self): times = [nccl_allreduce_cost(s, topo, config) for s in sizes] for i in range(1, len(times)): assert times[i] >= times[i - 1], ( - f"Cost decreased from {sizes[i-1]} to {sizes[i]}: " - f"{times[i-1]} > {times[i]}" + f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: " + f"{times[i - 1]} > {times[i]}" ) def test_blackwell_bw_scales(self): @@ -1308,8 +1308,8 @@ def test_monotonicity(self): times = [nccl_allgather_cost(s, topo, config) for s in sizes] for i in range(1, len(times)): assert times[i] >= times[i - 1], ( - f"Cost decreased from {sizes[i-1]} to {sizes[i]}: " - f"{times[i-1]} > {times[i]}" + f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: " + f"{times[i - 1]} > {times[i]}" ) def test_ramp_table_shape(self): @@ -1340,7 +1340,7 @@ def test_table_entries_do_not_jump_more_than_2x(self, n_nodes): for i in range(1, len(table)): assert table[i] <= 2.0 * table[i - 1] + 1e-12, ( f"n_nodes={n_nodes}: table[{i}]={table[i]:.6f} > " - f"2 * table[{i-1}]={2*table[i-1]:.6f}" + f"2 * table[{i - 1}]={2 * table[i - 1]:.6f}" ) @pytest.mark.parametrize("n_nodes", [8, 16, 32])