diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index b031c793..9ee4d5c9 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -79,5 +79,4 @@ jobs: pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 pip install --quiet . python examples/example_dcp.py - # TODO(#436): Re-enable once OpStrategy.__str__ handles None specs in PyTorch. - # torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py + torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py diff --git a/autoparallel/graph_passes/graph_clustering.py b/autoparallel/graph_passes/graph_clustering.py index 02995c9c..4503e611 100644 --- a/autoparallel/graph_passes/graph_clustering.py +++ b/autoparallel/graph_passes/graph_clustering.py @@ -26,7 +26,6 @@ tree_flatten, ) from torch._inductor.codecache import sha256_hash -from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import OpStrategy logger: logging.Logger = logging.getLogger(__name__) @@ -53,24 +52,7 @@ def _normalize_args( return (sorted_keys, tuple(_extract_args(arg) for arg in all_args)) -def _print_output_specs(op_strategy): - output = [] - for s in op_strategy.strategies: - output_placements = [] - output_specs = s.output_specs - if isinstance(output_specs, DTensorSpec): - output_specs = [output_specs] - for output_spec in output_specs: - if output_spec is None: - output_placements.append("(None)") - continue - plc_str = ",".join([str(p) for p in output_spec.placements]) - output_placements.append(f"({plc_str})") - output.append(f"({','.join(output_placements)})") - return ", ".join(output) - - -def _prepare_op_strategy(op_strategy, output_only=False): +def _prepare_op_strategy(op_strategy): # hasing op_strategy is expensive, so we hash the string representation # instead, which is much cheaper and is a reasonable proxy for the # clustering @@ -80,8 +62,6 @@ def _prepare_op_strategy(op_strategy, output_only=False): # view ops, which propagate the input shardings to the output. # So we also add the strategy for a node as a hash key to avoid # clustering nodes that look the same but have different strategies - if output_only: - return _print_output_specs(op_strategy) return str(op_strategy) @@ -93,7 +73,7 @@ def _hash_node(node, strategies, input_pickler): _normalize_args(node), _prepare_op_strategy(strategies[node]), tuple( - _prepare_op_strategy(strategies[s], output_only=True) + _prepare_op_strategy(strategies[s]) for s in node.all_input_nodes if s in strategies ), diff --git a/tests/test_graph_clustering.py b/tests/test_graph_clustering.py index d290fd91..b4522c59 100644 --- a/tests/test_graph_clustering.py +++ b/tests/test_graph_clustering.py @@ -11,6 +11,11 @@ from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard +from autoparallel._testing.models.dsv3 import ( + DeepSeekV3Model, + DeepSeekV3ModelArgs, + MoEArgs, +) from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel from autoparallel.graph_passes.graph_clustering import get_identical_regions @@ -27,17 +32,21 @@ def _get_layer_index(node): return None -def _clustering_stats(graph, strats, n_layers): - """Compute clustering statistics.""" - clusters = get_identical_regions(graph, strats) +def _clustered_nodes(clusters): + """Return every FX node that appears in any final clustered region.""" + clustered = set() + for group in clusters: + for region in group: + clustered.update(region) + return clustered - clustered_nodes = set() + +def _clustering_stats(graph, strats, clusters, n_layers): + """Compute per-layer clustering coverage from precomputed clusters.""" + clustered_nodes = _clustered_nodes(clusters) regions_per_group = [] for group in clusters: regions_per_group.append(len(group)) - for region in group: - for node in region: - clustered_nodes.add(node) per_layer_clustered = Counter() per_layer_total = Counter() @@ -62,6 +71,117 @@ def _clustering_stats(graph, strats, n_layers): } +def _assert_layer_coverage(stats, n_layers, min_coverage, label): + """Assert each repeated layer has enough nodes in clustered regions.""" + layer_totals = [stats["per_layer_total"].get(i, 0) for i in range(n_layers)] + if len(set(layer_totals)) != 1: + raise AssertionError( + f"{label}: layers have different node counts: {layer_totals}" + ) + + total = layer_totals[0] + if total == 0: + raise AssertionError(f"{label}: no layer nodes found in clustering stats") + for layer_idx in range(n_layers): + clustered = stats["per_layer_clustered"].get(layer_idx, 0) + coverage = clustered / total + if coverage < min_coverage: + raise AssertionError( + f"{label}: layer {layer_idx} clustering coverage too low: " + f"{clustered}/{total} = {coverage:.1%}, " + f"expected >= {min_coverage:.1%}" + ) + + +def _assert_cross_layer_cluster( + clusters, expected_layers, min_region_size, phase, label +): + """Assert there is one large same-phase region for each expected layer.""" + expected_layers = set(expected_layers) + for group in clusters: + if len(group) != len(expected_layers): + continue + region_layers = [] + for region in group: + tags = {n.meta.get("partitioner_tag") for n in region} + tags.discard(None) + if tags != {phase}: + break + layers = {idx for n in region if (idx := _get_layer_index(n)) is not None} + if len(region) < min_region_size or len(layers) != 1: + break + region_layers.append(next(iter(layers))) + else: + if set(region_layers) == expected_layers: + return + raise AssertionError( + f"{label}: missing {phase} cross-layer cluster for layers " + f"{sorted(expected_layers)} with min region size {min_region_size}" + ) + + +def _assert_no_forward_backward_mixing(clusters, label): + """Assert clustered regions never contain both forward and backward nodes.""" + for i, group in enumerate(clusters): + for j, region in enumerate(group): + tags = set(n.meta.get("partitioner_tag") for n in region) + tags.discard(None) + if len(tags) > 1: + raise AssertionError( + f"{label}: cluster group {i}, region {j} mixes phases: {tags}" + ) + + +def _run_clustering(autop, n_layers, input_sharding, output_sharding=None): + """Build AutoParallel state and return clustering stats plus raw clusters.""" + if output_sharding is None: + output_sharding = input_sharding + with autop: + autop.add_input_constraints([input_sharding]) + autop.add_output_constraints([output_sharding]) + + graph = autop.sharding_optimizer.graph + strats = autop.sharding_optimizer.strats + clusters = get_identical_regions(graph, strats) + stats = _clustering_stats(graph, strats, clusters, n_layers) + return stats, clusters + + +def _assert_model_clustering( + stats, + clusters, + *, + label, + n_layers, + min_coverage, + forward_layers, + backward_layers, + min_region_size=100, +): + """Assert coverage, phase separation, and large fwd/bwd layer clusters.""" + _assert_layer_coverage( + stats, + n_layers, + min_coverage=min_coverage, + label=label, + ) + _assert_cross_layer_cluster( + clusters, + forward_layers, + min_region_size=min_region_size, + phase="is_forward", + label=label, + ) + _assert_cross_layer_cluster( + clusters, + backward_layers, + min_region_size=min_region_size, + phase="is_backward", + label=label, + ) + _assert_no_forward_backward_mixing(clusters, label=label) + + def _setup_llama_autop(device_mesh_2d, n_layers=4): """Set up AutoParallel with a small LLaMA model.""" vocab_size = 2048 @@ -93,6 +213,31 @@ def input_fn(): return autop, model_args +def _setup_ds3_local_map_autop(device_mesh_2d, n_layers=2): + global_batch_size = 2 * device_mesh_2d.shape[0] * device_mesh_2d.shape[1] + moe_args = MoEArgs(mesh=device_mesh_2d) + config = DeepSeekV3ModelArgs( + n_layers=n_layers, + n_dense_layers=0, + moe_args=moe_args, + ) + with torch.device("meta"): + model = DeepSeekV3Model(config).bfloat16() + for module in model.modules(): + if hasattr(module, "axis_name"): + module.axis_name = device_mesh_2d.mesh_dim_names[1] + + def input_fn(): + return torch.randint( + 0, + config.vocab_size, + (global_batch_size, config.max_seq_len), + device="cuda", + ) + + return AutoParallel(model, input_fn, device_mesh_2d, dynamic=True) + + def test_clustering_high_coverage(device_mesh_2d): """The vast majority of layer-specific nodes should be clustered. @@ -102,56 +247,39 @@ def test_clustering_high_coverage(device_mesh_2d): """ n_layers = 4 autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers) - with autop: - x_sharding = (Shard(0), Replicate()) - out_sharding = (Shard(0), Shard(2)) - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([out_sharding]) - - stats = _clustering_stats( - autop.sharding_optimizer.graph, - autop.sharding_optimizer.strats, - n_layers, - ) - - # Every layer should have the same total node count - layer_totals = [stats["per_layer_total"].get(i, 0) for i in range(n_layers)] - assert ( - len(set(layer_totals)) == 1 - ), f"Layers have different node counts: {layer_totals}" - - # At least 50% of layer nodes should be clustered across all layers - total = layer_totals[0] - for layer_idx in range(n_layers): - clustered = stats["per_layer_clustered"].get(layer_idx, 0) - coverage = clustered / total - assert coverage >= 0.50, ( - f"Layer {layer_idx} clustering coverage too low: " - f"{clustered}/{total} = {coverage:.1%}" - ) - - -def test_clustering_no_forward_backward_mixing(device_mesh_2d): - """Each cluster group's regions should contain only forward or only - backward nodes, never a mix. Expansion must not cross the phase boundary - by following saved-tensor edges from backward into forward.""" - n_layers = 4 - autop, _ = _setup_llama_autop(device_mesh_2d, n_layers=n_layers) - with autop: - x_sharding = (Shard(0), Replicate()) - out_sharding = (Shard(0), Shard(2)) - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([out_sharding]) - - clusters = get_identical_regions( - autop.sharding_optimizer.graph, autop.sharding_optimizer.strats - ) + stats, clusters = _run_clustering( + autop, + n_layers, + input_sharding=(Shard(0), Replicate()), + output_sharding=(Shard(0), Shard(2)), + ) + _assert_model_clustering( + stats, + clusters, + label="LLaMA", + n_layers=n_layers, + min_coverage=0.50, + forward_layers=range(n_layers), + # Layer 0 has known backward boundary asymmetry. + backward_layers=range(1, n_layers), + ) - for i, group in enumerate(clusters): - for j, region in enumerate(group): - tags = set(n.meta.get("partitioner_tag") for n in region) - tags.discard(None) - assert len(tags) <= 1, f"Cluster group {i}, region {j} mixes phases: {tags}" + n_layers = 2 + autop = _setup_ds3_local_map_autop(device_mesh_2d, n_layers=n_layers) + stats, clusters = _run_clustering( + autop, + n_layers, + input_sharding=(Shard(0), Shard(0)), + ) + _assert_model_clustering( + stats, + clusters, + label="DS3", + n_layers=n_layers, + min_coverage=0.75, + forward_layers=range(n_layers), + backward_layers=range(n_layers), + ) def test_getitem_siblings_are_clustered(device_mesh_2d):