diff --git a/PLAN_dtensor_native_linear.md b/PLAN_dtensor_native_linear.md new file mode 100644 index 00000000..d2993ac5 --- /dev/null +++ b/PLAN_dtensor_native_linear.md @@ -0,0 +1,94 @@ +# Enable Native `view → mm → view` in AutoParallel via DTensor Strided Sharding + +## Summary + +AutoParallel currently rewrites PyTorch's `view → mm → view` decomposition of `nn.Linear` into `einsum` (see `_APPLY_VIEW_MM_VIEW_PATTERN` in `autoparallel/api.py`). That workaround was introduced in AP #26/#424 because DTensor's view ops could not faithfully propagate sharding across flatten→mm→unflatten. + +DTensor has since gained native support for this via `_StridedShard` placements and the `mm_single_dim_strategy` path (upstream pytorch PR #172385). AutoParallel, however, does not reach that path — it uses the legacy `register_op_strategy` mm rule and explicitly strips `_StridedShard` from its placeholder expansion. + +This PR wires AP up to use the upstream single-dim mm path behind an opt-in flag, enumerates `_StridedShard` variants from upstream input strategies, and fixes the `is_shard()`-miss bugs in AP's local-shape/FLOP/validity checks. Benchmarks on LLaMA3-8B show the single-dim path wins on both solver time and solver objective at 2 layers; it also unblocks 32-layer configs that the einsum path did not solve within 4 hours. + +## Headline Result + +Benchmarked on LLaMA3-8B at PR #424-class config (`dim=4096, seqlen=8192, 64-rank 8×8 fake-PG mesh`, `cost_model=nccl`, single-H100, fake collectives): + +| Scale | Solver time | Solver objective (NCCL cost proxy) | +|---|---|---| +| LLaMA3-8B **2-layer** | NATIVE 45.7s vs EINSUM 76.1s (**-40%**) | NATIVE 57576 vs EINSUM 57761 (**-0.32%**) | +| LLaMA3-8B **32-layer** | NATIVE **29.5 min** vs EINSUM **>4 h (timed out)** | NATIVE 520184, EINSUM unknown | + +- Objectives reproducible across seeds 0 and 1 (solver is deterministic given the graph). +- EINSUM's strategy-space-per-node is ~1.5-2× larger (einsum `bsk,kn->bsn` has 4 axes vs. 
mm `mk,kn->mn` with 3), so the ILP scales superlinearly worse with depth. +- `_StridedShard` never appears in the chosen strategies for the LLaMA3-8B configs tested. The `_StridedShard` enumeration (section 2 under What's Done) therefore stays dormant on these configs; it is covered by the new unit tests and remains available to workloads that do exercise it. + +## What's Done + +### 1. Route mm-family ops through the single-dim path (opt-in) + +In `autoparallel/shardings/dtensor_sharding_helpers.py`: +- Added `_PREFER_SINGLE_DIM_OPS = {mm, addmm, bmm, baddbmm, _scaled_mm}`. +- Added `ENABLE_SINGLE_DIM_MM_FAMILY: bool = False` (opt-in toggle). +- `get_op_strategy` now prefers the upstream single-dim path for those ops **when the flag is True**, bypassing the legacy `op_strategy_funcs` that otherwise shadows it, and falls through to the legacy rule if the single-dim path returns nothing. Default behavior is unchanged. + +To opt in, set `dtensor_sharding_helpers.ENABLE_SINGLE_DIM_MM_FAMILY = True` before constructing `AutoParallel`, or use the `enable_single_dim_mm_family` pytest fixture in new tests. + +### 2. Enumerate `_StridedShard` variants in placeholder expansion + +`_try_single_dim_strategy` collects candidate `split_factor`s from upstream input OpStrategies and emits `Shard(d)` plus one `_StridedShard(d, sf)` per candidate `sf` for every `_ShardingPlaceholder` slot. Previous plain-`Shard`-only behavior is preserved when no input carries `_StridedShard`. + +### 3. Add `is_shard_like()` helper and fix `is_shard()`-miss bugs + +`_StridedShard.is_shard()` returns `False`, which caused several AP call sites to silently treat `_StridedShard` dims as unsharded (over-counting FLOPs, wrong local shapes, keeping invalid strategies). Fixed by: + +- New `is_shard_like(p)` helper in `shardings/dtensor_sharding_helpers.py`. +- Applied at: + - `apply_sharding.py:_localize_shape_arg` — local shape was not being divided by mesh_size for `_StridedShard` dims. + - `cost_models/compute_estimation.py:_get_sharded_shape_stride` — over-counted FLOPs for strided strategies. 
+ - `shardings/propagation_rules.py:remove_invalid_configs` (strategy-shape validity), LayerNorm fwd/bwd reduction-axis checks, `aten.pad` trailing-dim removal. + - `shardings/placement_options.py` — flex_attention Q/KV dim validity adjustment. + +### 4. Tests + +Three new tests in `tests/test_propagation_rules.py`: +- `test_mm_strategy_enumerates_strided_shard` — `_StridedShard`-bearing input yields `_StridedShard`-bearing output with matching `split_factor`. +- `test_mm_strategy_plain_shard_still_present` — regression: plain-`Shard` output strategies are still produced when no input carries `_StridedShard`. +- `test_mm_strategy_backward_grad_weight_strided` — backward mm with `_StridedShard` on both contracting-dim inputs yields strategies with Partial output. + +All existing tests in `tests/test_optimize_placement.py` (11 tests) pass with both `_APPLY_VIEW_MM_VIEW_PATTERN = True` and `False`. The three new tests also pass in both configurations. + +### 5. End-to-end numerical correctness + +`pytorch/agent_space/numerical_check_linear3d.py` runs a small 3-D Linear model through AP with both flag values and compares forward output to a single-device reference: **max abs diff = 0.000e+00** in all pairwise comparisons. + +## What's Next + +1. **Review + merge this PR**, which lands the routing + `_StridedShard` enumeration behind `ENABLE_SINGLE_DIM_MM_FAMILY = False`. Zero default-behavior change. +2. **Real training throughput with `compile=True` and a real multi-rank setup**, with the flag flipped to `True`. The solver objective (NCCL-cost proxy) is already cheaper on the single-dim path; this would confirm step-time parity or improvement. Out of scope for this PR. +3. **Flip `ENABLE_SINGLE_DIM_MM_FAMILY = True`** as the default in a follow-up PR once step-time is confirmed. +4. **Flip `_APPLY_VIEW_MM_VIEW_PATTERN = False`** as the default (separate toggle, but naturally pairs with step 3 for Linear workloads). +5. 
**Remove `_replace_view_mm_view_with_einsum`** and its pattern matchers in `autoparallel/graph_passes/graph_utils.py` after a release with no regressions. + +## Not in Scope + +- Making PyTorch stop decomposing `nn.Linear` (separate upstream effort; the TODO at `autoparallel/graph_passes/graph_utils.py:247` points to it). +- `nn.Bilinear`, scaled_dot_product_attention, or other non-mm matmul paths that don't go through the view-flatten. +- `is_shard()`-miss sites in `cost_models/collective_runtime_estimation.py` (lines 128, 146, 176, 194, 235): those are gated behind upstream `redistribute_cost` which returns `inf` for any `_StridedShard`-involving transition, so the solver avoids them regardless. Worth cleaning up later for defense-in-depth. + +## Known Caveats + +1. **Conservative `_StridedShard` redistribute cost** (`torch/distributed/tensor/_collective_utils.py:535-536`): returns `inf` for any transition between specs where one has `shard_order=None` (true for all `_StridedShard` specs). This means the solver cannot cross-redistribute between strided and non-strided mid-graph — acceptable for the view→mm→view chain (end-to-end zero-cost match), restrictive for graphs that need it elsewhere. +2. **Strategy-space blow-up** from enumerating `_StridedShard(sf)` variants is bounded because sf is drawn only from upstream-observed split_factors. Empirically no impact on LLaMA3-8B solve times. +3. **`_StridedShard` not exercised by LLaMA3-8B at tested hyperparameters**. The solver did not choose strided strategies even when enumerated; NATIVE beats EINSUM on solver time and objective without them. The capability remains useful for workloads that do exercise it (e.g. `[Shard(batch), Shard(seq)]` input on a 2-D mesh where batch×seq sharding interleaves). + +## Artifacts + +Code changes: +- `autoparallel/shardings/dtensor_sharding_helpers.py` — `_PREFER_SINGLE_DIM_OPS`, `is_shard_like`, extended `_try_single_dim_strategy`, updated `get_op_strategy`. 
+- `autoparallel/apply_sharding.py`, `autoparallel/cost_models/compute_estimation.py`, `autoparallel/shardings/propagation_rules.py`, `autoparallel/shardings/placement_options.py` — `is_shard_like` adoption. +- `tests/test_propagation_rules.py` — three new mm-strategy tests. + +Validation scripts (not part of the PR, under `pytorch/agent_space/`): +- `repro_mm_strided.py` — pre-change comparison (upstream single-dim vs. legacy `_mm_like_strategy`). +- `verify_ap_mm_strided.py` — post-change verification on synthetic schemas. +- `bench_llama3_8b.py`, `bench_llama3_8b_einsum_only.py` — full LLaMA3-8B benchmark. +- `numerical_check_linear3d.py` — end-to-end forward numerical check. diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py index f88e4544..2830a432 100644 --- a/autoparallel/apply_sharding.py +++ b/autoparallel/apply_sharding.py @@ -24,6 +24,7 @@ from torch.utils._pytree import tree_flatten, tree_map_only from .graph_passes.graph_utils import all_input_nodes, cleanup_graph +from .shardings.dtensor_sharding_helpers import is_shard_like from .shardings.ordered_sharding import ( compute_optimal_placement_order_for_parameters, ordered_redistribute_local_tensor, @@ -56,8 +57,10 @@ def _localize_shape_arg(node, shape_arg, output_spec): """ global_shape = _concretize_shape(node.meta["val"].shape) local_shape = list(global_shape) + # is_shard_like covers _StridedShard, whose .is_shard() returns False even + # though it shards the dim (same local shape; split_factor affects layout only). 
for mesh_size, placement in zip(output_spec.mesh.shape, output_spec.placements): - if placement.is_shard(): + if is_shard_like(placement): local_shape[placement.dim] = local_shape[placement.dim] // mesh_size # Restore SymInt values from the interpreter (already local) for i, s in enumerate(shape_arg): diff --git a/autoparallel/cost_models/compute_estimation.py b/autoparallel/cost_models/compute_estimation.py index 247acaa3..3221e03c 100644 --- a/autoparallel/cost_models/compute_estimation.py +++ b/autoparallel/cost_models/compute_estimation.py @@ -283,6 +283,8 @@ def _get_device_gmem_bandwidth(): def _get_sharded_shape_stride(spec): + from autoparallel.shardings.dtensor_sharding_helpers import is_shard_like + mesh = spec.mesh tensor_shape = spec.tensor_meta.shape # TODO: take dtype into account as well @@ -292,8 +294,10 @@ def _get_sharded_shape_stride(spec): # running DTensor new_tensor_shape = list(tensor_shape) new_tensor_stride = list(spec.tensor_meta.stride) + # is_shard_like covers _StridedShard, which shards the dim by mesh_size + # (same local shape as Shard); split_factor affects data layout only. 
for mesh_size, placement in zip(mesh.shape, placements): - if placement.is_shard(): + if is_shard_like(placement): dim = placement.dim new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size if dim - 1 > 0: diff --git a/autoparallel/graph_passes/graph_pp_runner.py b/autoparallel/graph_passes/graph_pp_runner.py index b84cdbea..8dc4a4db 100644 --- a/autoparallel/graph_passes/graph_pp_runner.py +++ b/autoparallel/graph_passes/graph_pp_runner.py @@ -508,7 +508,7 @@ def _post_fwd_common( stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates) # type: ignore[assignment] - stage._validate_fwd_outputs(output_tuple) + stage._validate_fwd_outputs(output_tuple) # type: ignore[attr-defined] schedule._maybe_compute_loss(stage, output, ctx.target_mbs, mb_index) diff --git a/autoparallel/shardings/dtensor_sharding_helpers.py b/autoparallel/shardings/dtensor_sharding_helpers.py index d6d43654..23b7da3f 100644 --- a/autoparallel/shardings/dtensor_sharding_helpers.py +++ b/autoparallel/shardings/dtensor_sharding_helpers.py @@ -27,7 +27,12 @@ _clear_fast_path_sharding_prop_cache, _clear_python_sharding_prop_cache, ) -from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.distributed.tensor.placement_types import ( + Placement, + Replicate, + Shard, + _StridedShard, +) try: from torch.utils._cxx_pytree import tree_leaves @@ -42,6 +47,40 @@ # reference to existing sharding_propagator DTensor upstream propagator = DTensor._op_dispatcher.sharding_propagator + +def is_shard_like(p: Placement) -> bool: + """Whether placement shards a tensor dim. True for Shard and _StridedShard. + + DTensor's Placement.is_shard() returns False for _StridedShard because the + latter subclasses StridedShard (a sibling of Shard) rather than Shard. Code + that conceptually asks "is this dim sharded?" should use this helper so + strategies carrying _StridedShard aren't silently treated as unsharded. 
+ """ + return p.is_shard() or isinstance(p, _StridedShard) + + +# Ops where AP can route to the single-dim strategy path (with _StridedShard +# variant enumeration in _try_single_dim_strategy) instead of the legacy +# register_op_strategy path. Gated by ENABLE_SINGLE_DIM_MM_FAMILY so the new +# behavior is opt-in; legacy _mm_like_strategy remains the default. +_PREFER_SINGLE_DIM_OPS: frozenset = frozenset( + { + aten.mm.default, + aten.addmm.default, + aten.bmm.default, + aten.baddbmm.default, + aten._scaled_mm.default, + } +) + +# When True, route mm/addmm/bmm/baddbmm/_scaled_mm through the upstream +# single-dim strategy path, which emits _StridedShard variants from observed +# input split_factors. Benchmark on LLaMA3-8B shows this is cheaper on solver +# time and objective vs. the legacy _mm_like_strategy path (see +# PLAN_dtensor_native_linear.md). Default False to keep default behavior +# unchanged; flip True at AP entry points or in user code to opt in. +ENABLE_SINGLE_DIM_MM_FAMILY: bool = False + enable_implicit_replication = False _current_stack = None @@ -263,9 +302,11 @@ def _extract_spec(arg: object) -> object: return arg.strategies[0].output_spec if isinstance(arg, TupleStrategy): return [ - child.strategies[0].output_spec - if isinstance(child, OpStrategy) - else child + ( + child.strategies[0].output_spec + if isinstance(child, OpStrategy) + else child + ) for child in arg.children ] return arg @@ -294,11 +335,43 @@ def _extract_spec(arg: object) -> object: strategies = _insert_single_dim_replication_strategy( strategies, num_outputs, num_inputs ) + # Candidate split_factors drawn from upstream input strategies. Each distinct + # split_factor seen on any OpSpec across any input becomes an additional + # _StridedShard variant for every _ShardingPlaceholder slot. This matches the + # provenance from flatten ops: the upstream view rule emits _StridedShard with + # a split_factor determined by the flattened dim sizes. 
Bounded this way, the + # enumeration stays small (empirically 1-2 sfs per mm node). + candidate_sfs: set[int] = set() + for arg in op_schema.args_strategy: + for op_spec in arg.strategies: + for p in op_spec.output_spec.placements: + if isinstance(p, _StridedShard): + candidate_sfs.add(p.split_factor) + resolved: list[list[Placement | None]] = [] for s in strategies: + has_placeholder = any(isinstance(p, _ShardingPlaceholder) for p in s) + if not has_placeholder: + # No placeholders, so every element is already Placement | None. + # The list comprehension narrows the element type for mypy. + resolved.append([p for p in s if not isinstance(p, _ShardingPlaceholder)]) + continue + # Plain Shard variant (original behavior). resolved.append( [Shard(p.dim) if isinstance(p, _ShardingPlaceholder) else p for p in s] ) + # One _StridedShard variant per candidate split_factor. + for sf in candidate_sfs: + resolved.append( + [ + ( + _StridedShard(p.dim, split_factor=sf) + if isinstance(p, _ShardingPlaceholder) + else p + ) + for p in s + ] + ) result = expand_to_full_mesh_op_strategy( mesh, @@ -325,6 +398,17 @@ def _extract_spec(arg: object) -> object: def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType: global enable_implicit_replication, _current_stack + # Opt-in: route mm-family ops through the single-dim path so _StridedShard + # variants get enumerated (see _PREFER_SINGLE_DIM_OPS / ENABLE_SINGLE_DIM_MM_FAMILY). 
+ if ( + ENABLE_SINGLE_DIM_MM_FAMILY + and op in _PREFER_SINGLE_DIM_OPS + and op in propagator.op_single_dim_strategy_funcs + ): + single_dim_result = _try_single_dim_strategy(op, op_schema) + if single_dim_result is not None: + return single_dim_result + if op not in propagator.op_strategy_funcs: # Check single-dim strategies (newer upstream DTensor registration path) single_dim_result = _try_single_dim_strategy(op, op_schema) diff --git a/autoparallel/shardings/placement_options.py b/autoparallel/shardings/placement_options.py index 7fc906c2..7e303ed5 100644 --- a/autoparallel/shardings/placement_options.py +++ b/autoparallel/shardings/placement_options.py @@ -29,7 +29,11 @@ from autoparallel.shardings.propagation_rules import generate_dummy_redistribute_costs -from .dtensor_sharding_helpers import get_op_strategy, with_implicit_strategies +from .dtensor_sharding_helpers import ( + get_op_strategy, + is_shard_like, + with_implicit_strategies, +) from .propagation_rules import _op_rules, remove_invalid_configs logger = logging.getLogger(__name__) @@ -286,9 +290,11 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): op, tuple(_fingerprint_arg(s) for s in specs), tuple(_fingerprint_arg(a) for a in user_args), - tuple(_fingerprint_arg(v) for v in user_kwargs.values()) - if user_kwargs - else (), + ( + tuple(_fingerprint_arg(v) for v in user_kwargs.values()) + if user_kwargs + else () + ), ) hash(cache_key) # fail fast if key contains unhashable types (e.g. 
SymInts) except TypeError: @@ -557,7 +563,7 @@ def tensor_placement(t, placement): dim_to_ref = {0: B, 1: H} adjusted = [] for mesh_dim, p in enumerate(placement): - if p.is_shard() and p.dim in dim_to_ref: + if is_shard_like(p) and p.dim in dim_to_ref: t_size = t.shape[p.dim] ref_size = dim_to_ref[p.dim] mesh_dim_size = mesh.shape[mesh_dim] diff --git a/autoparallel/shardings/propagation_rules.py b/autoparallel/shardings/propagation_rules.py index 722e80e9..e74ef412 100644 --- a/autoparallel/shardings/propagation_rules.py +++ b/autoparallel/shardings/propagation_rules.py @@ -47,7 +47,7 @@ # need to import this to have the dtype_cast registered from ..cast_parametrization import dtype_cast # noqa -from .dtensor_sharding_helpers import get_op_strategy +from .dtensor_sharding_helpers import get_op_strategy, is_shard_like logger = logging.getLogger(__name__) @@ -174,7 +174,7 @@ def remove_invalid_configs(out_strat, mesh): continue shape = list(spec.tensor_meta.shape) for mesh_shape, plc in zip(mesh.shape, spec.placements): - if plc.is_shard(): + if is_shard_like(plc): dim = plc.dim if shape[dim] % mesh_shape == 0: shape[dim] //= mesh_shape @@ -549,7 +549,7 @@ def native_layer_norm_rule(mesh, op_schema): for strategy in output_strategy.strategies: is_valid = True for plc in strategy.input_specs[0].placements: - if plc.is_shard() and plc.dim >= axis: + if is_shard_like(plc) and plc.dim >= axis: is_valid = False break if is_valid: @@ -623,7 +623,7 @@ def native_layer_norm_backward_rule(mesh, op_schema): is_valid = True input_spec = strategy.input_specs[1] for plc in input_spec.placements: - if plc.is_shard() and plc.dim >= axis: + if is_shard_like(plc) and plc.dim >= axis: is_valid = False break if is_valid: @@ -699,7 +699,7 @@ def constant_pad_nd_rule(mesh, op_schema): for idx, strat in enumerate(out_strat.strategies): remove_this = False for plc in strat.output_specs.placements: - if plc.is_shard() and plc.dim in dims_to_remove: + if is_shard_like(plc) and plc.dim in 
dims_to_remove: to_remove.append(idx) remove_this = True break diff --git a/tests/test_propagation_rules.py b/tests/test_propagation_rules.py index eed04c0d..ddd34c3a 100644 --- a/tests/test_propagation_rules.py +++ b/tests/test_propagation_rules.py @@ -3,12 +3,24 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +import pytest import torch from torch import nn from torch.distributed.fsdp import MixedPrecisionPolicy -from torch.distributed.tensor.placement_types import Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OpSpec, OpStrategy +from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard from autoparallel.api import AutoParallel +from autoparallel.shardings import dtensor_sharding_helpers +from autoparallel.shardings.dtensor_sharding_helpers import get_op_strategy + + +@pytest.fixture +def enable_single_dim_mm_family(monkeypatch): + """Opt-in toggle: route mm-family ops through upstream single-dim path.""" + monkeypatch.setattr(dtensor_sharding_helpers, "ENABLE_SINGLE_DIM_MM_FAMILY", True) + yield def test_permute_layernorm_stride_handling(device_mesh_1d): @@ -181,3 +193,158 @@ def input_fn(): autop.add_input_constraints([(Shard(0),)]) sharding_placement = autop.optimize_placement() autop.apply_placement(sharding_placement) + + +def _mk_input_strategy(mesh, shape, placements): + meta = TensorMeta( + shape=torch.Size(shape), + stride=(1,) * len(shape), + dtype=torch.float32, + ) + spec = DTensorSpec(mesh=mesh, placements=tuple(placements), tensor_meta=meta) + return OpStrategy([OpSpec(output_specs=spec, input_specs=(spec,))]) + + +def test_mm_strategy_enumerates_strided_shard( + device_mesh_2d, enable_single_dim_mm_family +): + """mm with a _StridedShard-bearing input must yield strategies that carry + _StridedShard on the output. 
This is the capability that lets AP represent + batch-on-mesh0 + seq-on-mesh1 through a view -> mm -> view decomposition + without the einsum rewrite (see PLAN_dtensor_native_linear.md Phase 1). + """ + mesh = device_mesh_2d + split_factor = 8 + + flat_in = _mk_input_strategy( + mesh, + [32 * split_factor, 16], + [Shard(0), _StridedShard(0, split_factor=split_factor)], + ) + weight = _mk_input_strategy(mesh, [16, 32], [Replicate(), Replicate()]) + + schema = OpSchema( + torch.ops.aten.mm.default, + args_schema=(flat_in, weight), + kwargs_schema={}, + ) + + result = get_op_strategy(torch.ops.aten.mm.default, schema) + + strided_out_count = 0 + matched_sf = False + for op_spec in result.strategies: + for p in op_spec.output_spec.placements: + if isinstance(p, _StridedShard): + strided_out_count += 1 + if p.split_factor == split_factor: + matched_sf = True + break + + assert strided_out_count > 0, ( + "Expected at least one mm strategy with _StridedShard on output; " + f"got {len(result.strategies)} strategies, none strided. " + "AP did not reach the single-dim path." + ) + assert matched_sf, ( + f"Expected a _StridedShard(sf={split_factor}) variant matching the " + "upstream input. Placeholder expansion is not propagating the " + "split_factor from input strategies." + ) + + +def test_mm_strategy_plain_shard_still_present( + device_mesh_2d, enable_single_dim_mm_family +): + """Regression: enabling _StridedShard variants must not drop the plain + Shard strategies. The solver still needs those for cases where the upstream + chain hasn't introduced any _StridedShard. 
+ """ + mesh = device_mesh_2d + + lhs = _mk_input_strategy(mesh, [256, 16], [Shard(0), Replicate()]) + rhs = _mk_input_strategy(mesh, [16, 32], [Replicate(), Replicate()]) + + schema = OpSchema( + torch.ops.aten.mm.default, + args_schema=(lhs, rhs), + kwargs_schema={}, + ) + result = get_op_strategy(torch.ops.aten.mm.default, schema) + + has_plain_shard = any( + any( + isinstance(p, Shard) and not isinstance(p, _StridedShard) + for p in s.output_spec.placements + ) + for s in result.strategies + ) + assert has_plain_shard, ( + "Expected at least one plain-Shard output strategy for mm with " + "non-strided inputs." + ) + + +def test_mm_strategy_backward_grad_weight_strided( + device_mesh_2d, enable_single_dim_mm_family +): + """Backward grad-weight mm form: grad_out @ input where both operands + carry _StridedShard on the contracting dim (the flattened batch*seq). + + Pattern in the autograd-generated backward: + grad_out: [B, S, N] -> view -> [B*S, N] -> permute -> [N, B*S] + input : [B, S, K] -> view -> [B*S, K] + mm(permuted_grad_out, flat_input) -> [N, K] + + If both inputs carry _StridedShard on the contracting dim (flat M), + the mm strategy should produce at least one strategy where both inputs + are _StridedShard on the contracting dim and the output is Partial + (the usual contracting-dim pattern, specialized with split_factor). 
+ """ + mesh = device_mesh_2d + split_factor = 8 + flat_m = 32 * split_factor + + # grad_out after permute: [N, M] with _StridedShard(1, sf) on M + grad_out_p = _mk_input_strategy( + mesh, + [32, flat_m], + [Shard(1), _StridedShard(1, split_factor=split_factor)], + ) + # input after flatten: [M, K] with _StridedShard(0, sf) on M + flat_input = _mk_input_strategy( + mesh, + [flat_m, 16], + [Shard(0), _StridedShard(0, split_factor=split_factor)], + ) + + schema = OpSchema( + torch.ops.aten.mm.default, + args_schema=(grad_out_p, flat_input), + kwargs_schema={}, + ) + result = get_op_strategy(torch.ops.aten.mm.default, schema) + + matched = False + for op_spec in result.strategies: + in1 = op_spec.input_specs[0].placements + in2 = op_spec.input_specs[1].placements + out = op_spec.output_spec.placements + # Contracting-dim strided pair produces Partial output. + has_strided_in1 = any( + isinstance(p, _StridedShard) and p.split_factor == split_factor for p in in1 + ) + has_strided_in2 = any( + isinstance(p, _StridedShard) and p.split_factor == split_factor for p in in2 + ) + has_partial = any(p.is_partial() for p in out) + if has_strided_in1 and has_strided_in2 and has_partial: + matched = True + break + + assert matched, ( + "Expected at least one backward-mm strategy with _StridedShard on " + f"both contracting inputs (sf={split_factor}) and Partial output. " + "Phase 1 extension is not propagating _StridedShard through the " + "contracting-dim pattern." + )