Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
exclude: 'build|stubs'
exclude: 'build|stubs|autoparallel/tools/overlap_simulator'

default_language_version:
python: python3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import torch
from torch._inductor import ir, scheduler
from torch._inductor.comms import get_op_idx
from torch._inductor.dependencies import StarDep, WeakDep
from torch._inductor.utils import is_collective, is_wait
from torch._inductor.virtualized import V
Expand All @@ -26,6 +25,7 @@
bucket_all_gathers,
bucket_reduce_scatters,
check_ir_node_bucketable,
get_op_idx,
)


Expand Down
15 changes: 15 additions & 0 deletions autoparallel/graph_passes/autobucketing_inductor/bucket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ def get_data_size(size):
return reduce(lambda x, y: x * y, size)


def get_op_idx(snode: "scheduler.BaseSchedulerNode") -> int:
if isinstance(
snode,
(
scheduler.FusedSchedulerNode,
scheduler.GroupedSchedulerNode,
),
):
raise TypeError(f"Expected an unfused scheduler node, got {type(snode)}")
op_name = snode.get_name()
if not op_name.startswith("op"):
raise KeyError(f"Expected op name to start with 'op', got {op_name}")
return int(op_name[2:])


def _find_recursive_deps_of_snode(
snode: "scheduler.BaseSchedulerNode",
collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
Expand Down
45 changes: 45 additions & 0 deletions tests/test_fsdp_all_gather_tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
without running the full AutoParallel pipeline.
"""

import pytest
import torch
import torch.fx
from torch.utils.checkpoint import CheckpointPolicy
Expand All @@ -19,6 +20,7 @@
force_save_fsdp_all_gather,
mark_fsdp_all_gather_recomputation,
)
from autoparallel.graph_passes.autobucketing_inductor import bucket_utils

# ---------------------------------------------------------------------------
# Helpers for building minimal FSDP-like graphs
Expand Down Expand Up @@ -78,6 +80,22 @@ def _build_simple_fsdp_graph() -> torch.fx.Graph:
return graph


class _SchedulerNode:
    """Minimal test stand-in for a scheduler node.

    It only stores a name and hands it back through ``get_name``.
    """

    def __init__(self, name: str) -> None:
        # Stored verbatim; ``get_name`` returns exactly this value.
        self.name = name

    def get_name(self) -> str:
        """Return the name supplied at construction time."""
        return self.name


class _FusedSchedulerNode(_SchedulerNode):
    """Distinct subclass of ``_SchedulerNode`` used as the fake fused-node type in tests."""


class _GroupedSchedulerNode(_SchedulerNode):
    """Distinct subclass of ``_SchedulerNode`` used as the fake grouped-node type in tests."""


# ---------------------------------------------------------------------------
# Tests for force_recompute_fsdp_all_gather
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -224,3 +242,30 @@ def test_no_tags_without_fsdp_pattern():
for node in graph.nodes:
assert "recompute" not in node.meta
assert "ac_graph_id" not in node.meta


# ---------------------------------------------------------------------------
# Tests for autobucketing scheduler helpers
# ---------------------------------------------------------------------------


def test_get_op_idx():
    """A name of the form ``op<N>`` yields ``N`` as an int."""
    node = _SchedulerNode("op142")
    op_index = bucket_utils.get_op_idx(node)
    assert op_index == 142


def test_get_op_idx_rejects_non_op_name():
    """A node whose name lacks the ``op`` prefix is rejected with ``KeyError``."""
    wrongly_named = _SchedulerNode("buf142")
    with pytest.raises(KeyError, match="Expected op name"):
        bucket_utils.get_op_idx(wrongly_named)


def test_get_op_idx_rejects_fused_and_grouped_snodes(monkeypatch):
    """Fused and grouped scheduler nodes are rejected with ``TypeError``.

    The scheduler classes that ``bucket_utils`` checks against are swapped for
    the local fakes, so the ``isinstance`` test in ``get_op_idx`` matches them.
    """
    fake_classes = {
        "FusedSchedulerNode": _FusedSchedulerNode,
        "GroupedSchedulerNode": _GroupedSchedulerNode,
    }
    # Patch both attributes first, then exercise each fake in the same order.
    for attr_name, fake_cls in fake_classes.items():
        monkeypatch.setattr(bucket_utils.scheduler, attr_name, fake_cls)

    for fake_cls in fake_classes.values():
        snode = fake_cls("op142")
        with pytest.raises(TypeError, match="Expected an unfused scheduler node"):
            bucket_utils.get_op_idx(snode)
22 changes: 11 additions & 11 deletions tests/test_nccl_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,8 @@ def test_cost_increases_with_size(self, func):
times = [nccl_collective_time(func, s, topo, config) for s in sizes]
for i in range(1, len(times)):
assert times[i] >= times[i - 1], (
f"Cost decreased from {sizes[i-1]} to {sizes[i]}: "
f"{times[i-1]} > {times[i]}"
f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: "
f"{times[i - 1]} > {times[i]}"
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -810,8 +810,8 @@ def test_monotonicity(self):
times = [nccl_all_to_all_cost(s, topo, config) for s in sizes]
for i in range(1, len(times)):
assert times[i] >= times[i - 1], (
f"AllToAll cost decreased from {sizes[i-1]} to {sizes[i]}: "
f"{times[i-1]} > {times[i]}"
f"AllToAll cost decreased from {sizes[i - 1]} to {sizes[i]}: "
f"{times[i - 1]} > {times[i]}"
)

def test_more_expensive_than_allgather(self):
Expand Down Expand Up @@ -1210,8 +1210,8 @@ def test_monotonicity(self):
times = [nccl_allreduce_cost(s, topo, config) for s in sizes]
for i in range(1, len(times)):
assert times[i] >= times[i - 1], (
f"Cost decreased from {sizes[i-1]} to {sizes[i]}: "
f"{times[i-1]} > {times[i]}"
f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: "
f"{times[i - 1]} > {times[i]}"
)

def test_ramp_tables_shape(self):
Expand Down Expand Up @@ -1246,8 +1246,8 @@ def test_monotonicity_16_nodes(self):
times = [nccl_allreduce_cost(s, topo, config) for s in sizes]
for i in range(1, len(times)):
assert times[i] >= times[i - 1], (
f"Cost decreased from {sizes[i-1]} to {sizes[i]}: "
f"{times[i-1]} > {times[i]}"
f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: "
f"{times[i - 1]} > {times[i]}"
)

def test_blackwell_bw_scales(self):
Expand Down Expand Up @@ -1308,8 +1308,8 @@ def test_monotonicity(self):
times = [nccl_allgather_cost(s, topo, config) for s in sizes]
for i in range(1, len(times)):
assert times[i] >= times[i - 1], (
f"Cost decreased from {sizes[i-1]} to {sizes[i]}: "
f"{times[i-1]} > {times[i]}"
f"Cost decreased from {sizes[i - 1]} to {sizes[i]}: "
f"{times[i - 1]} > {times[i]}"
)

def test_ramp_table_shape(self):
Expand Down Expand Up @@ -1340,7 +1340,7 @@ def test_table_entries_do_not_jump_more_than_2x(self, n_nodes):
for i in range(1, len(table)):
assert table[i] <= 2.0 * table[i - 1] + 1e-12, (
f"n_nodes={n_nodes}: table[{i}]={table[i]:.6f} > "
f"2 * table[{i-1}]={2*table[i-1]:.6f}"
f"2 * table[{i - 1}]={2 * table[i - 1]:.6f}"
)

@pytest.mark.parametrize("n_nodes", [8, 16, 32])
Expand Down
Loading