94 changes: 94 additions & 0 deletions PLAN_dtensor_native_linear.md
@@ -0,0 +1,94 @@
# Enable Native `view → mm → view` in AutoParallel via DTensor Strided Sharding

## Summary

AutoParallel currently rewrites PyTorch's `view → mm → view` decomposition of `nn.Linear` into `einsum` (see `_APPLY_VIEW_MM_VIEW_PATTERN` in `autoparallel/api.py`). That workaround was introduced in AP #26/#424 because DTensor's view ops could not faithfully propagate sharding across flatten→mm→unflatten.

DTensor has since gained native support for this via `_StridedShard` placements and the `mm_single_dim_strategy` path (upstream pytorch PR #172385). AutoParallel, however, does not reach that path — it uses the legacy `register_op_strategy` mm rule and explicitly strips `_StridedShard` from its placeholder expansion.

This PR wires AP up to use the upstream single-dim mm path, enumerates `_StridedShard` variants from upstream input strategies, and fixes the `is_shard()`-miss bugs in AP's local-shape/FLOP/validity checks. Benchmarks on LLaMA3-8B show lower solver time and, at 2 layers, a slightly lower solver objective; the change also unblocks 32-layer configs, which the einsum path failed to solve within 4 hours.

## Headline Result

Benchmarked on LLaMA3-8B at PR #424-class config (`dim=4096, seqlen=8192, 64-rank 8×8 fake-PG mesh`, `cost_model=nccl`, single-H100, fake collectives):

| Scale | Solver time | Solver objective (NCCL cost proxy) |
|---|---|---|
| LLaMA3-8B **2-layer** | NATIVE 45.7s vs EINSUM 76.1s (**-40%**) | NATIVE 57576 vs EINSUM 57761 (**0.32% cheaper**) |
| LLaMA3-8B **32-layer** | NATIVE **29.5 min** vs EINSUM **>4 h (timed out)** | NATIVE 520184, EINSUM unknown |

- Objectives are reproducible across seeds 0 and 1 (the solver is deterministic given the graph).
- EINSUM's per-node strategy space is ~1.5-2× larger (einsum `bsk,kn->bsn` has 4 axes vs. mm `mk,kn->mn` with 3), which makes the ILP scale superlinearly worse with depth.
- `_StridedShard` never appears in the chosen strategies for the LLaMA3-8B configs tested. Phase 1's `_StridedShard` enumeration leaves results unchanged when dormant and is ready for workloads that exercise it.
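
The 2-layer percentages in the table can be re-derived from the raw numbers. The snippet below is only an arithmetic check; `relative_change` is a helper name introduced here for illustration and is not part of AutoParallel.

```python
def relative_change(new: float, baseline: float) -> float:
    """Signed relative change of `new` against `baseline`, in percent."""
    return (new - baseline) / baseline * 100.0


# 2-layer row of the table: NATIVE vs EINSUM.
time_delta = relative_change(45.7, 76.1)          # solver time in seconds
objective_delta = relative_change(57576, 57761)   # solver objective

print(f"solver time:      {time_delta:.1f}%")
print(f"solver objective: {objective_delta:.2f}%")
```

Both values are negative, meaning NATIVE is lower than EINSUM on each metric.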

## What's Done

### 1. Route mm-family ops through the single-dim path (opt-in)

In `autoparallel/shardings/dtensor_sharding_helpers.py`:
- Added `_PREFER_SINGLE_DIM_OPS = {mm, addmm, bmm, baddbmm, _scaled_mm}`.
- Added `ENABLE_SINGLE_DIM_MM_FAMILY: bool = False` (opt-in toggle).
- `get_op_strategy` now prefers the upstream single-dim path for those ops **when the flag is True**, bypassing the legacy `op_strategy_funcs` that otherwise shadows it. Default behavior is unchanged.

To opt in, set `dtensor_sharding_helpers.ENABLE_SINGLE_DIM_MM_FAMILY = True` before constructing `AutoParallel`, or use the `enable_single_dim_mm_family` pytest fixture in new tests.
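
One way to scope the opt-in, for example inside a test, is a context manager that flips the flag and restores it afterwards. This is a minimal sketch, not the `enable_single_dim_mm_family` fixture itself: `single_dim_mm_family` is a hypothetical helper, and a `SimpleNamespace` stands in for the `dtensor_sharding_helpers` module so the snippet runs without AutoParallel installed.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for the `dtensor_sharding_helpers` module (assumption: the real flag
# is a plain module-level bool, as described above).
helpers = SimpleNamespace(ENABLE_SINGLE_DIM_MM_FAMILY=False)


@contextmanager
def single_dim_mm_family(module, enabled: bool = True):
    """Temporarily set the flag on `module`, restoring the old value on exit."""
    previous = module.ENABLE_SINGLE_DIM_MM_FAMILY
    module.ENABLE_SINGLE_DIM_MM_FAMILY = enabled
    try:
        yield
    finally:
        module.ENABLE_SINGLE_DIM_MM_FAMILY = previous


with single_dim_mm_family(helpers):
    # In real code, AutoParallel would be constructed inside this block.
    inside = helpers.ENABLE_SINGLE_DIM_MM_FAMILY
after = helpers.ENABLE_SINGLE_DIM_MM_FAMILY
```

The flag is set as an attribute on the module object rather than imported by name, because `from module import FLAG` would copy the value and later assignments would not be seen by `get_op_strategy`.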

### 2. Enumerate `_StridedShard` variants in placeholder expansion

`_try_single_dim_strategy` collects candidate `split_factor`s from upstream input OpStrategies and emits `Shard(d)` plus one `_StridedShard(d, sf)` per candidate `sf` for every `_ShardingPlaceholder` slot. Previous plain-`Shard`-only behavior is preserved when no input carries `_StridedShard`.
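
The expansion described above can be sketched with plain dataclasses. The classes below are simplified stand-ins for DTensor's `Shard`, `_StridedShard`, and `_ShardingPlaceholder` (assumptions made so the example runs without PyTorch), and `collect_split_factors` / `expand_placeholders` are illustrative names, not AP functions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class StridedShard:
    dim: int
    split_factor: int


@dataclass(frozen=True)
class Placeholder:
    dim: int


def collect_split_factors(input_placements):
    """Gather every split_factor seen on any candidate placement of any input."""
    return {
        p.split_factor
        for placements in input_placements
        for p in placements
        if isinstance(p, StridedShard)
    }


def expand_placeholders(strategy, candidate_sfs):
    """Return the plain-Shard variant plus one strided variant per split_factor."""
    if not any(isinstance(p, Placeholder) for p in strategy):
        return [list(strategy)]
    variants = [[Shard(p.dim) if isinstance(p, Placeholder) else p for p in strategy]]
    for sf in sorted(candidate_sfs):
        variants.append(
            [StridedShard(p.dim, sf) if isinstance(p, Placeholder) else p for p in strategy]
        )
    return variants


# One single-dim mm rule written as [output, lhs, rhs]: shard M on output and lhs.
rule = [Placeholder(0), Placeholder(0), "Replicate"]

# lhs arrives from a flatten with split_factor 2; rhs carries plain placements only.
strided_inputs = [[StridedShard(0, 2), Shard(0)], [Shard(1)]]
plain_inputs = [[Shard(0)], [Shard(1)]]

with_strided = expand_placeholders(rule, collect_split_factors(strided_inputs))
without_strided = expand_placeholders(rule, collect_split_factors(plain_inputs))
```

With a strided input the rule yields two variants (plain and `split_factor=2`); with plain inputs only the original `Shard` variant is produced, which is the preserved behavior the paragraph above refers to.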

### 3. Add `is_shard_like()` helper and fix `is_shard()`-miss bugs

`_StridedShard.is_shard()` returns `False`, which caused several AP call sites to silently treat `_StridedShard` dims as unsharded (over-counting FLOPs, wrong local shapes, keeping invalid strategies). Fixed by:

- New `is_shard_like(p)` helper in `shardings/dtensor_sharding_helpers.py`.
- Applied at:
  - `apply_sharding.py:_localize_shape_arg`: local shape was not being divided by mesh_size for `_StridedShard` dims.
  - `cost_models/compute_estimation.py:_get_sharded_shape_stride`: over-counted FLOPs for strided strategies.
  - `shardings/propagation_rules.py:remove_invalid_configs` (strategy-shape validity), LayerNorm fwd/bwd reduction-axis checks, `aten.pad` trailing-dim removal.
  - `shardings/placement_options.py`: flex_attention Q/KV dim validity adjustment.
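
The effect of the `is_shard()` miss on local shapes can be reproduced with small stand-in classes. This is a sketch under stated assumptions: `Shard` and `StridedShard` here are toy dataclasses that mimic only the `is_shard()` behavior described above, and `local_shape` is a hypothetical helper modelled on the loop in `_localize_shape_arg`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    dim: int

    def is_shard(self) -> bool:
        return True


@dataclass(frozen=True)
class StridedShard:
    dim: int
    split_factor: int

    def is_shard(self) -> bool:
        # Mirrors the behavior described above: reports False despite sharding `dim`.
        return False


def is_shard_like(p) -> bool:
    return p.is_shard() or isinstance(p, StridedShard)


def local_shape(global_shape, mesh_shape, placements, predicate):
    """Divide each dim selected by `predicate` by the size of its mesh dim."""
    shape = list(global_shape)
    for mesh_size, p in zip(mesh_shape, placements):
        if predicate(p):
            shape[p.dim] //= mesh_size
    return shape


global_shape = [64, 4096]
mesh_shape = (8, 8)
placements = (StridedShard(0, split_factor=8), Shard(1))

# Old check: the strided dim is treated as unsharded.
buggy = local_shape(global_shape, mesh_shape, placements, lambda p: p.is_shard())
# New check: both mesh dims shrink their tensor dim.
fixed = local_shape(global_shape, mesh_shape, placements, is_shard_like)
```

Here `buggy` keeps dim 0 at its global size while `fixed` divides it by the mesh size, which is the same discrepancy that inflated FLOP estimates in `_get_sharded_shape_stride`.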

### 4. Tests

Three new tests in `tests/test_propagation_rules.py`:
- `test_mm_strategy_enumerates_strided_shard` — `_StridedShard`-bearing input yields `_StridedShard`-bearing output with matching `split_factor`.
- `test_mm_strategy_plain_shard_still_present` — regression: plain-Shard inputs do not spuriously produce `_StridedShard` outputs.
- `test_mm_strategy_backward_grad_weight_strided` — backward mm with `_StridedShard` on both contracting-dim inputs yields strategies with Partial output.

All existing tests in `tests/test_optimize_placement.py` (11 tests) pass with both `_APPLY_VIEW_MM_VIEW_PATTERN = True` and `False`. The three new tests also pass in both configurations.

### 5. End-to-end numerical correctness

`pytorch/agent_space/numerical_check_linear3d.py` runs a small 3-D Linear model through AP with both flag values and compares forward output to a single-device reference: **max abs diff = 0.000e+00** in all pairwise comparisons.

## What's Next

1. **Review + merge this PR**, which lands the routing + `_StridedShard` enumeration behind the `ENABLE_SINGLE_DIM_MM_FAMILY` flag (default `False`). Zero default-behavior change.
2. **Real training throughput with `compile=True` and a real multi-rank setup**, with the flag flipped to `True`. The solver objective (NCCL-cost proxy) is already cheaper on the single-dim path; this would confirm step-time parity or improvement. Out of scope for this PR.
3. **Flip `ENABLE_SINGLE_DIM_MM_FAMILY = True`** as the default in a follow-up PR once step-time is confirmed.
4. **Flip `_APPLY_VIEW_MM_VIEW_PATTERN = False`** as the default (separate toggle, but naturally pairs with step 3 for Linear workloads).
5. **Remove `_replace_view_mm_view_with_einsum`** and its pattern matchers in `autoparallel/graph_passes/graph_utils.py` after a release with no regressions.

## Not in Scope

- Making PyTorch stop decomposing `nn.Linear` (separate upstream effort; the TODO at `autoparallel/graph_passes/graph_utils.py:247` points to it).
- `nn.Bilinear`, scaled_dot_product_attention, or other non-mm matmul paths that don't go through the view-flatten.
- `is_shard()`-miss sites in `cost_models/collective_runtime_estimation.py` (lines 128, 146, 176, 194, 235): those are gated behind upstream `redistribute_cost` which returns `inf` for any `_StridedShard`-involving transition, so the solver avoids them regardless. Worth cleaning up later for defense-in-depth.

## Known Caveats

1. **Conservative `_StridedShard` redistribute cost** (`torch/distributed/tensor/_collective_utils.py:535-536`): returns `inf` for any transition between specs where one has `shard_order=None` (true for all `_StridedShard` specs). This means the solver cannot cross-redistribute between strided and non-strided mid-graph — acceptable for the view→mm→view chain (end-to-end zero-cost match), restrictive for graphs that need it elsewhere.
2. **Strategy-space blow-up** from enumerating `_StridedShard(sf)` variants is bounded because sf is drawn only from upstream-observed split_factors. Empirically no impact on LLaMA3-8B solve times.
3. **`_StridedShard` not exercised by LLaMA3-8B at tested hyperparameters**. The solver did not choose strided strategies even when enumerated; NATIVE beats EINSUM on solver time and objective without them. The capability remains useful for workloads that do exercise it (e.g. `[Shard(batch), Shard(seq)]` input on a 2-D mesh where batch×seq sharding interleaves).
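
To illustrate caveat 1, here is a toy cost function showing how an `inf` cost removes strided/non-strided transitions from the set a min-cost solver can pick, while a strided-to-strided match stays free. The function and the unit cost of `1.0` are assumptions for illustration only; they are not the upstream `redistribute_cost` implementation.

```python
import math
from itertools import product


def toy_redistribute_cost(src: str, dst: str) -> float:
    """Toy model: identical specs are free, any strided mismatch is forbidden."""
    if src == dst:
        return 0.0
    if "strided" in (src, dst):
        return math.inf
    return 1.0  # arbitrary finite cost for ordinary redistributions


options = ["shard", "strided", "replicate"]

# Transitions a min-cost solver could actually select.
finite_pairs = [
    (src, dst)
    for src, dst in product(options, repeat=2)
    if math.isfinite(toy_redistribute_cost(src, dst))
]
```

Under this model a strided placement can only be consumed by a node that accepts the same strided placement, which is sufficient for the view→mm→view chain but rules out mixing elsewhere in the graph.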

## Artifacts

Code changes:
- `autoparallel/shardings/dtensor_sharding_helpers.py` — `_PREFER_SINGLE_DIM_OPS`, `is_shard_like`, extended `_try_single_dim_strategy`, updated `get_op_strategy`.
- `autoparallel/apply_sharding.py`, `autoparallel/cost_models/compute_estimation.py`, `autoparallel/shardings/propagation_rules.py`, `autoparallel/shardings/placement_options.py` — `is_shard_like` adoption.
- `tests/test_propagation_rules.py` — three new mm-strategy tests.

Validation scripts (not part of the PR, under `pytorch/agent_space/`):
- `repro_mm_strided.py` — pre-change comparison (upstream single-dim vs. legacy `_mm_like_strategy`).
- `verify_ap_mm_strided.py` — post-change verification on synthetic schemas.
- `bench_llama3_8b.py`, `bench_llama3_8b_einsum_only.py` — full LLaMA3-8B benchmark.
- `numerical_check_linear3d.py` — end-to-end forward numerical check.
5 changes: 4 additions & 1 deletion autoparallel/apply_sharding.py
@@ -24,6 +24,7 @@
from torch.utils._pytree import tree_flatten, tree_map_only

from .graph_passes.graph_utils import all_input_nodes, cleanup_graph
from .shardings.dtensor_sharding_helpers import is_shard_like
from .shardings.ordered_sharding import (
compute_optimal_placement_order_for_parameters,
ordered_redistribute_local_tensor,
@@ -56,8 +57,10 @@ def _localize_shape_arg(node, shape_arg, output_spec):
"""
global_shape = _concretize_shape(node.meta["val"].shape)
local_shape = list(global_shape)
# is_shard_like covers _StridedShard, whose .is_shard() returns False even
# though it shards the dim (same local shape; split_factor affects layout only).
for mesh_size, placement in zip(output_spec.mesh.shape, output_spec.placements):
if placement.is_shard():
if is_shard_like(placement):
local_shape[placement.dim] = local_shape[placement.dim] // mesh_size
# Restore SymInt values from the interpreter (already local)
for i, s in enumerate(shape_arg):
6 changes: 5 additions & 1 deletion autoparallel/cost_models/compute_estimation.py
@@ -283,6 +283,8 @@ def _get_device_gmem_bandwidth():


def _get_sharded_shape_stride(spec):
from autoparallel.shardings.dtensor_sharding_helpers import is_shard_like

mesh = spec.mesh
tensor_shape = spec.tensor_meta.shape
# TODO: take dtype into account as well
@@ -292,8 +294,10 @@ def _get_sharded_shape_stride(spec):
# running DTensor
new_tensor_shape = list(tensor_shape)
new_tensor_stride = list(spec.tensor_meta.stride)
# is_shard_like covers _StridedShard, which shards the dim by mesh_size
# (same local shape as Shard); split_factor affects data layout only.
for mesh_size, placement in zip(mesh.shape, placements):
if placement.is_shard():
if is_shard_like(placement):
dim = placement.dim
new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
if dim - 1 > 0:
2 changes: 1 addition & 1 deletion autoparallel/graph_passes/graph_pp_runner.py
@@ -508,7 +508,7 @@ def _post_fwd_common(

stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates) # type: ignore[assignment]

stage._validate_fwd_outputs(output_tuple)
stage._validate_fwd_outputs(output_tuple) # type: ignore[attr-defined]

schedule._maybe_compute_loss(stage, output, ctx.target_mbs, mb_index)

92 changes: 88 additions & 4 deletions autoparallel/shardings/dtensor_sharding_helpers.py
@@ -27,7 +27,12 @@
_clear_fast_path_sharding_prop_cache,
_clear_python_sharding_prop_cache,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
from torch.distributed.tensor.placement_types import (
Placement,
Replicate,
Shard,
_StridedShard,
)

try:
from torch.utils._cxx_pytree import tree_leaves
@@ -42,6 +47,40 @@
# reference to existing sharding_propagator DTensor upstream
propagator = DTensor._op_dispatcher.sharding_propagator


def is_shard_like(p: Placement) -> bool:
"""Whether placement shards a tensor dim. True for Shard and _StridedShard.

DTensor's Placement.is_shard() returns False for _StridedShard because the
latter subclasses StridedShard (a sibling of Shard) rather than Shard. Code
that conceptually asks "is this dim sharded?" should use this helper so
strategies carrying _StridedShard aren't silently treated as unsharded.
"""
return p.is_shard() or isinstance(p, _StridedShard)


# Ops where AP can route to the single-dim strategy path (with _StridedShard
# variant enumeration in _try_single_dim_strategy) instead of the legacy
# register_op_strategy path. Gated by ENABLE_SINGLE_DIM_MM_FAMILY so the new
# behavior is opt-in; legacy _mm_like_strategy remains the default.
_PREFER_SINGLE_DIM_OPS: frozenset = frozenset(
{
aten.mm.default,
aten.addmm.default,
aten.bmm.default,
aten.baddbmm.default,
aten._scaled_mm.default,
}
)

# When True, route mm/addmm/bmm/baddbmm/_scaled_mm through the upstream
# single-dim strategy path, which emits _StridedShard variants from observed
# input split_factors. Benchmark on LLaMA3-8B shows this is cheaper on solver
# time and objective vs. the legacy _mm_like_strategy path (see
# PLAN_dtensor_native_linear.md). Default False to keep default behavior
# unchanged; flip True at AP entry points or in user code to opt in.
ENABLE_SINGLE_DIM_MM_FAMILY: bool = False

enable_implicit_replication = False
_current_stack = None

@@ -263,9 +302,11 @@ def _extract_spec(arg: object) -> object:
return arg.strategies[0].output_spec
if isinstance(arg, TupleStrategy):
return [
child.strategies[0].output_spec
if isinstance(child, OpStrategy)
else child
(
child.strategies[0].output_spec
if isinstance(child, OpStrategy)
else child
)
for child in arg.children
]
return arg
@@ -294,11 +335,43 @@ def _extract_spec(arg: object) -> object:
strategies = _insert_single_dim_replication_strategy(
strategies, num_outputs, num_inputs
)
# Candidate split_factors drawn from upstream input strategies. Each distinct
# split_factor seen on any OpSpec across any input becomes an additional
# _StridedShard variant for every _ShardingPlaceholder slot. This matches the
# provenance from flatten ops: the upstream view rule emits _StridedShard with
# a split_factor determined by the flattened dim sizes. Bounded this way, the
# enumeration stays small (empirically 1-2 sfs per mm node).
candidate_sfs: set[int] = set()
for arg in op_schema.args_strategy:
for op_spec in arg.strategies:
for p in op_spec.output_spec.placements:
if isinstance(p, _StridedShard):
candidate_sfs.add(p.split_factor)

resolved: list[list[Placement | None]] = []
for s in strategies:
has_placeholder = any(isinstance(p, _ShardingPlaceholder) for p in s)
if not has_placeholder:
# No placeholders, so every element is already Placement | None.
# The list comprehension narrows the element type for mypy.
resolved.append([p for p in s if not isinstance(p, _ShardingPlaceholder)])
continue
# Plain Shard variant (original behavior).
resolved.append(
[Shard(p.dim) if isinstance(p, _ShardingPlaceholder) else p for p in s]
)
# One _StridedShard variant per candidate split_factor.
for sf in candidate_sfs:
resolved.append(
[
(
_StridedShard(p.dim, split_factor=sf)
if isinstance(p, _ShardingPlaceholder)
else p
)
for p in s
]
)

result = expand_to_full_mesh_op_strategy(
mesh,
@@ -325,6 +398,17 @@ def _extract_spec(arg: object) -> object:
def get_op_strategy(op: torch._ops.OpOverload, op_schema: OpSchema) -> StrategyType:
global enable_implicit_replication, _current_stack

# Opt-in: route mm-family ops through the single-dim path so _StridedShard
# variants get enumerated (see _PREFER_SINGLE_DIM_OPS / ENABLE_SINGLE_DIM_MM_FAMILY).
if (
ENABLE_SINGLE_DIM_MM_FAMILY
and op in _PREFER_SINGLE_DIM_OPS
and op in propagator.op_single_dim_strategy_funcs
):
single_dim_result = _try_single_dim_strategy(op, op_schema)
if single_dim_result is not None:
return single_dim_result

if op not in propagator.op_strategy_funcs:
# Check single-dim strategies (newer upstream DTensor registration path)
single_dim_result = _try_single_dim_strategy(op, op_schema)
16 changes: 11 additions & 5 deletions autoparallel/shardings/placement_options.py
@@ -29,7 +29,11 @@

from autoparallel.shardings.propagation_rules import generate_dummy_redistribute_costs

from .dtensor_sharding_helpers import get_op_strategy, with_implicit_strategies
from .dtensor_sharding_helpers import (
get_op_strategy,
is_shard_like,
with_implicit_strategies,
)
from .propagation_rules import _op_rules, remove_invalid_configs

logger = logging.getLogger(__name__)
@@ -286,9 +290,11 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
op,
tuple(_fingerprint_arg(s) for s in specs),
tuple(_fingerprint_arg(a) for a in user_args),
tuple(_fingerprint_arg(v) for v in user_kwargs.values())
if user_kwargs
else (),
(
tuple(_fingerprint_arg(v) for v in user_kwargs.values())
if user_kwargs
else ()
),
)
hash(cache_key) # fail fast if key contains unhashable types (e.g. SymInts)
except TypeError:
@@ -557,7 +563,7 @@ def tensor_placement(t, placement):
dim_to_ref = {0: B, 1: H}
adjusted = []
for mesh_dim, p in enumerate(placement):
if p.is_shard() and p.dim in dim_to_ref:
if is_shard_like(p) and p.dim in dim_to_ref:
t_size = t.shape[p.dim]
ref_size = dim_to_ref[p.dim]
mesh_dim_size = mesh.shape[mesh_dim]
10 changes: 5 additions & 5 deletions autoparallel/shardings/propagation_rules.py
@@ -47,7 +47,7 @@

# need to import this to have the dtype_cast registered
from ..cast_parametrization import dtype_cast # noqa
from .dtensor_sharding_helpers import get_op_strategy
from .dtensor_sharding_helpers import get_op_strategy, is_shard_like

logger = logging.getLogger(__name__)

@@ -174,7 +174,7 @@ def remove_invalid_configs(out_strat, mesh):
continue
shape = list(spec.tensor_meta.shape)
for mesh_shape, plc in zip(mesh.shape, spec.placements):
if plc.is_shard():
if is_shard_like(plc):
dim = plc.dim
if shape[dim] % mesh_shape == 0:
shape[dim] //= mesh_shape
@@ -549,7 +549,7 @@ def native_layer_norm_rule(mesh, op_schema):
for strategy in output_strategy.strategies:
is_valid = True
for plc in strategy.input_specs[0].placements:
if plc.is_shard() and plc.dim >= axis:
if is_shard_like(plc) and plc.dim >= axis:
is_valid = False
break
if is_valid:
@@ -623,7 +623,7 @@ def native_layer_norm_backward_rule(mesh, op_schema):
is_valid = True
input_spec = strategy.input_specs[1]
for plc in input_spec.placements:
if plc.is_shard() and plc.dim >= axis:
if is_shard_like(plc) and plc.dim >= axis:
is_valid = False
break
if is_valid:
@@ -699,7 +699,7 @@ def constant_pad_nd_rule(mesh, op_schema):
for idx, strat in enumerate(out_strat.strategies):
remove_this = False
for plc in strat.output_specs.placements:
if plc.is_shard() and plc.dim in dims_to_remove:
if is_shard_like(plc) and plc.dim in dims_to_remove:
to_remove.append(idx)
remove_this = True
break