Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,4 @@ jobs:
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --quiet .
python examples/example_dcp.py
# TODO(#436): Re-enable once OpStrategy.__str__ handles None specs in PyTorch.
# torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py
torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py
24 changes: 2 additions & 22 deletions autoparallel/graph_passes/graph_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
tree_flatten,
)
from torch._inductor.codecache import sha256_hash
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpStrategy

logger: logging.Logger = logging.getLogger(__name__)
Expand All @@ -53,24 +52,7 @@ def _normalize_args(
return (sorted_keys, tuple(_extract_args(arg) for arg in all_args))


def _print_output_specs(op_strategy):
output = []
for s in op_strategy.strategies:
output_placements = []
output_specs = s.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = [output_specs]
for output_spec in output_specs:
if output_spec is None:
output_placements.append("(None)")
continue
plc_str = ",".join([str(p) for p in output_spec.placements])
output_placements.append(f"({plc_str})")
output.append(f"({','.join(output_placements)})")
return ", ".join(output)


def _prepare_op_strategy(op_strategy, output_only=False):
def _prepare_op_strategy(op_strategy):
# hashing op_strategy is expensive, so we hash the string representation
# instead, which is much cheaper and is a reasonable proxy for the
# clustering
Expand All @@ -80,8 +62,6 @@ def _prepare_op_strategy(op_strategy, output_only=False):
# view ops, which propagate the input shardings to the output.
# So we also add the strategy for a node as a hash key to avoid
# clustering nodes that look the same but have different strategies
if output_only:
return _print_output_specs(op_strategy)
return str(op_strategy)


Expand All @@ -93,7 +73,7 @@ def _hash_node(node, strategies, input_pickler):
_normalize_args(node),
_prepare_op_strategy(strategies[node]),
tuple(
_prepare_op_strategy(strategies[s], output_only=True)
_prepare_op_strategy(strategies[s])
for s in node.all_input_nodes
if s in strategies
),
Expand Down
240 changes: 184 additions & 56 deletions tests/test_graph_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel._testing.models.dsv3 import (
DeepSeekV3Model,
DeepSeekV3ModelArgs,
MoEArgs,
)
from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
from autoparallel.api import AutoParallel
from autoparallel.graph_passes.graph_clustering import get_identical_regions
Expand All @@ -27,17 +32,21 @@ def _get_layer_index(node):
return None


def _clustering_stats(graph, strats, n_layers):
"""Compute clustering statistics."""
clusters = get_identical_regions(graph, strats)
def _clustered_nodes(clusters):
    """Collect the set of all nodes found in any region of any cluster group.

    ``clusters`` is iterated as groups -> regions -> nodes; the result is the
    union of every region's nodes.
    """
    return {node for group in clusters for region in group for node in region}

clustered_nodes = set()

def _clustering_stats(graph, strats, clusters, n_layers):
"""Compute per-layer clustering coverage from precomputed clusters."""
clustered_nodes = _clustered_nodes(clusters)
regions_per_group = []
for group in clusters:
regions_per_group.append(len(group))
for region in group:
for node in region:
clustered_nodes.add(node)

per_layer_clustered = Counter()
per_layer_total = Counter()
Expand All @@ -62,6 +71,117 @@ def _clustering_stats(graph, strats, n_layers):
}


def _assert_layer_coverage(stats, n_layers, min_coverage, label):
    """Assert each repeated layer has enough nodes in clustered regions.

    Args:
        stats: Mapping with ``"per_layer_total"`` and ``"per_layer_clustered"``
            entries, each mapping a layer index to a node count. Missing
            layer indices count as 0.
        n_layers: Number of layers to check (indices ``0 .. n_layers - 1``).
        min_coverage: Minimum required ``clustered / total`` ratio per layer.
        label: Prefix used in every failure message.

    Raises:
        AssertionError: If ``n_layers`` is not positive, the layers do not all
            have the same total node count, that total is zero, or any
            layer's coverage falls below ``min_coverage``.
    """
    # Without this guard an empty range falls through to the "different node
    # counts" check and reports a misleading message with an empty list.
    if n_layers < 1:
        raise AssertionError(
            f"{label}: expected at least one layer, got n_layers={n_layers}"
        )

    totals_by_layer = stats["per_layer_total"]
    layer_totals = [totals_by_layer.get(i, 0) for i in range(n_layers)]
    if len(set(layer_totals)) != 1:
        raise AssertionError(
            f"{label}: layers have different node counts: {layer_totals}"
        )

    # All layers share the same total, so the first one is representative.
    total = layer_totals[0]
    if total == 0:
        raise AssertionError(f"{label}: no layer nodes found in clustering stats")

    clustered_by_layer = stats["per_layer_clustered"]
    for layer_idx in range(n_layers):
        clustered = clustered_by_layer.get(layer_idx, 0)
        coverage = clustered / total
        if coverage < min_coverage:
            raise AssertionError(
                f"{label}: layer {layer_idx} clustering coverage too low: "
                f"{clustered}/{total} = {coverage:.1%}, "
                f"expected >= {min_coverage:.1%}"
            )


def _assert_cross_layer_cluster(
    clusters, expected_layers, min_region_size, phase, label
):
    """Assert some cluster group has one large ``phase`` region per layer.

    A group qualifies when it has exactly as many regions as there are
    expected layers, every region's non-None ``partitioner_tag`` values are
    exactly ``{phase}``, every region has at least ``min_region_size`` nodes
    that map to a single layer index, and the set of those layer indices
    equals ``expected_layers``.

    Raises:
        AssertionError: If no cluster group qualifies.
    """
    wanted = set(expected_layers)

    def _sole_layer(region):
        # Return the region's single layer index, or None when the region
        # does not qualify (wrong phase, too small, or spans != 1 layer).
        phases = {node.meta.get("partitioner_tag") for node in region}
        phases.discard(None)
        if phases != {phase}:
            return None
        layer_ids = {_get_layer_index(node) for node in region}
        layer_ids.discard(None)
        if len(region) < min_region_size or len(layer_ids) != 1:
            return None
        (only_layer,) = layer_ids
        return only_layer

    for group in clusters:
        if len(group) != len(wanted):
            continue
        found_layers = []
        disqualified = False
        for region in group:
            layer = _sole_layer(region)
            if layer is None:
                # One bad region rules out the whole group.
                disqualified = True
                break
            found_layers.append(layer)
        if not disqualified and set(found_layers) == wanted:
            return

    raise AssertionError(
        f"{label}: missing {phase} cross-layer cluster for layers "
        f"{sorted(wanted)} with min region size {min_region_size}"
    )


def _assert_no_forward_backward_mixing(clusters, label):
    """Fail if any clustered region carries more than one partitioner tag.

    For every region, the ``partitioner_tag`` entries of its nodes' ``meta``
    are collected (ignoring nodes without a tag); more than one distinct tag
    in a region raises ``AssertionError`` naming the group and region index.
    """
    for group_idx, group in enumerate(clusters):
        for region_idx, region in enumerate(group):
            phases = {node.meta.get("partitioner_tag") for node in region}
            phases.discard(None)
            if len(phases) <= 1:
                continue
            raise AssertionError(
                f"{label}: cluster group {group_idx}, region {region_idx} mixes phases: {phases}"
            )


def _run_clustering(autop, n_layers, input_sharding, output_sharding=None):
"""Build AutoParallel state and return clustering stats plus raw clusters.

Registers ``input_sharding`` as the single input constraint and
``output_sharding`` as the single output constraint (defaulting to
``input_sharding`` when it is None), runs ``get_identical_regions`` on the
sharding optimizer's graph and strategies, and returns ``(stats, clusters)``
where ``stats`` comes from ``_clustering_stats``.
"""
# NOTE(review): indentation is not visible in this view, so how far the
# ``with autop:`` block extends cannot be determined here — confirm against
# the real file before re-indenting.
if output_sharding is None:
output_sharding = input_sharding
with autop:
autop.add_input_constraints([input_sharding])
autop.add_output_constraints([output_sharding])

# Clusters are computed once and shared with the stats helper.
graph = autop.sharding_optimizer.graph
strats = autop.sharding_optimizer.strats
clusters = get_identical_regions(graph, strats)
stats = _clustering_stats(graph, strats, clusters, n_layers)
return stats, clusters


def _assert_model_clustering(
    stats,
    clusters,
    *,
    label,
    n_layers,
    min_coverage,
    forward_layers,
    backward_layers,
    min_region_size=100,
):
    """Run the full set of clustering checks for one model.

    In order: per-layer coverage, a large cross-layer cluster for the
    ``"is_forward"`` phase over ``forward_layers``, the same for the
    ``"is_backward"`` phase over ``backward_layers``, and finally that no
    region mixes phases. Any failing check raises ``AssertionError``.
    """
    _assert_layer_coverage(stats, n_layers, min_coverage=min_coverage, label=label)

    # Forward is checked before backward.
    phase_expectations = (
        ("is_forward", forward_layers),
        ("is_backward", backward_layers),
    )
    for phase_tag, layers in phase_expectations:
        _assert_cross_layer_cluster(
            clusters,
            layers,
            min_region_size=min_region_size,
            phase=phase_tag,
            label=label,
        )

    _assert_no_forward_backward_mixing(clusters, label=label)


def _setup_llama_autop(device_mesh_2d, n_layers=4):
"""Set up AutoParallel with a small LLaMA model."""
vocab_size = 2048
Expand Down Expand Up @@ -93,6 +213,31 @@ def input_fn():
return autop, model_args


def _setup_ds3_local_map_autop(device_mesh_2d, n_layers=2):
"""Set up AutoParallel with a DeepSeekV3 model built with ``n_dense_layers=0``.

The model is constructed from ``DeepSeekV3ModelArgs`` with ``n_layers``
layers and ``MoEArgs(mesh=device_mesh_2d)``, then cast with ``.bfloat16()``.
Returns ``AutoParallel(model, input_fn, device_mesh_2d, dynamic=True)``.
"""
# NOTE(review): indentation is not visible in this view, so which statements
# sit inside ``with torch.device("meta"):`` cannot be determined here —
# confirm against the real file before re-indenting.
# Global batch is 2 x (mesh dim 0 size) x (mesh dim 1 size).
global_batch_size = 2 * device_mesh_2d.shape[0] * device_mesh_2d.shape[1]
moe_args = MoEArgs(mesh=device_mesh_2d)
config = DeepSeekV3ModelArgs(
n_layers=n_layers,
n_dense_layers=0,
moe_args=moe_args,
)
with torch.device("meta"):
model = DeepSeekV3Model(config).bfloat16()
# Any module exposing ``axis_name`` is pointed at the mesh's second dim name.
for module in model.modules():
if hasattr(module, "axis_name"):
module.axis_name = device_mesh_2d.mesh_dim_names[1]

# Produces integer ids in [0, vocab_size) of shape
# (global_batch_size, max_seq_len) on the "cuda" device.
def input_fn():
return torch.randint(
0,
config.vocab_size,
(global_batch_size, config.max_seq_len),
device="cuda",
)

return AutoParallel(model, input_fn, device_mesh_2d, dynamic=True)


def test_clustering_high_coverage(device_mesh_2d):
"""The vast majority of layer-specific nodes should be clustered.

Expand All @@ -102,56 +247,39 @@ def test_clustering_high_coverage(device_mesh_2d):
"""
n_layers = 4
autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers)
with autop:
x_sharding = (Shard(0), Replicate())
out_sharding = (Shard(0), Shard(2))
autop.add_input_constraints([x_sharding])
autop.add_output_constraints([out_sharding])

stats = _clustering_stats(
autop.sharding_optimizer.graph,
autop.sharding_optimizer.strats,
n_layers,
)

# Every layer should have the same total node count
layer_totals = [stats["per_layer_total"].get(i, 0) for i in range(n_layers)]
assert (
len(set(layer_totals)) == 1
), f"Layers have different node counts: {layer_totals}"

# At least 50% of layer nodes should be clustered across all layers
total = layer_totals[0]
for layer_idx in range(n_layers):
clustered = stats["per_layer_clustered"].get(layer_idx, 0)
coverage = clustered / total
assert coverage >= 0.50, (
f"Layer {layer_idx} clustering coverage too low: "
f"{clustered}/{total} = {coverage:.1%}"
)


def test_clustering_no_forward_backward_mixing(device_mesh_2d):
"""Each cluster group's regions should contain only forward or only
backward nodes, never a mix. Expansion must not cross the phase boundary
by following saved-tensor edges from backward into forward."""
n_layers = 4
autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers)
with autop:
x_sharding = (Shard(0), Replicate())
out_sharding = (Shard(0), Shard(2))
autop.add_input_constraints([x_sharding])
autop.add_output_constraints([out_sharding])

clusters = get_identical_regions(
autop.sharding_optimizer.graph, autop.sharding_optimizer.strats
)
stats, clusters = _run_clustering(
autop,
n_layers,
input_sharding=(Shard(0), Replicate()),
output_sharding=(Shard(0), Shard(2)),
)
_assert_model_clustering(
stats,
clusters,
label="LLaMA",
n_layers=n_layers,
min_coverage=0.50,
forward_layers=range(n_layers),
# Layer 0 has known backward boundary asymmetry.
backward_layers=range(1, n_layers),
)

for i, group in enumerate(clusters):
for j, region in enumerate(group):
tags = set(n.meta.get("partitioner_tag") for n in region)
tags.discard(None)
assert len(tags) <= 1, f"Cluster group {i}, region {j} mixes phases: {tags}"
n_layers = 2
autop = _setup_ds3_local_map_autop(device_mesh_2d, n_layers=n_layers)
stats, clusters = _run_clustering(
autop,
n_layers,
input_sharding=(Shard(0), Shard(0)),
)
_assert_model_clustering(
stats,
clusters,
label="DS3",
n_layers=n_layers,
min_coverage=0.75,
forward_layers=range(n_layers),
backward_layers=range(n_layers),
)


def test_getitem_siblings_are_clustered(device_mesh_2d):
Expand Down
Loading