
Add LP-relaxation and TRW-S approximate sharding solvers#484

Open
AlbedoWang wants to merge 29 commits into main from kaijian/approx_solver_pr

@AlbedoWang

Summary

Adds two faster alternatives to the exact ILP for the sharding-placement solve:

  • Approximate solver (TRW-S). The ILP's flow constraints leave the per-node output strategies as the only free variables, so the problem reduces to a pairwise MRF E(x) = Σ_v U_v(x_v) + Σ_(u,v) B_uv(x_u,x_v) over coupled groups. We solve it with sequential
    tree-reweighted message passing (TRW-S) in place of plain loopy min-sum BP, which settled into globally inconsistent fixed points, and then polish with group-coordinate and star-block local search. The solver reuses the strategies, decision vars, and constraints already
    built by ShardingOptimizer (it replaces only the solve, not problem construction) and writes its assignment back into the PuLP variables, so it is scored with exactly the same objective as the ILP. It runs ~20-30x faster than CBC, with an objective gap of +0.00% on 2D and <0.82% on 3D.
  • LP relaxation. Adds solve_lp_relaxation() plus an LP-based solve used directly: the relaxation is empirically integral for this problem, so it matches the ILP optimum while skipping branch-and-bound. An optimality check logs the certified gap.
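To make the energy above concrete, here is a minimal sketch that evaluates E(x) on a toy three-node chain and finds the minimizer by brute force. The costs and the names `UNARY`, `PAIRWISE`, `energy`, and `brute_force` are illustrative assumptions, not code from this PR; in the real solver the unary and pairwise tables come from the ShardingOptimizer decision variables.

```python
import itertools

# Toy problem (hypothetical numbers): 3 nodes, 2 candidate strategies each.
UNARY = [[1.0, 3.0], [2.0, 0.5], [0.0, 4.0]]
# PAIRWISE[(u, v)][x_u][x_v] is the cost of the edge (u, v) under that choice.
PAIRWISE = {
    (0, 1): [[0.0, 2.0], [2.0, 0.0]],
    (1, 2): [[0.0, 1.5], [1.5, 0.0]],
}


def energy(x, unary=UNARY, pairwise=PAIRWISE):
    """E(x) = sum_v U_v(x_v) + sum_(u,v) B_uv(x_u, x_v)."""
    node_cost = sum(unary[v][x[v]] for v in range(len(unary)))
    edge_cost = sum(table[x[u]][x[v]] for (u, v), table in pairwise.items())
    return node_cost + edge_cost


def brute_force(unary=UNARY, pairwise=PAIRWISE):
    """Exhaustive minimizer; only viable for tiny problems like this one."""
    domains = [range(len(costs)) for costs in unary]
    return min(itertools.product(*domains),
               key=lambda x: energy(x, unary, pairwise))
```

Exhaustive search is exponential in the node count, which is why the PR relies on message passing plus local search at real model sizes.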

It also adds memory-budget support to the approximate path via Lagrangian relaxation: the parameter-memory budget is node-separable, so it folds into the unaries, and a scalar λ-bisection plus a budget-constrained polish lands on the integer optimum. Build-time
speedups shared by all solvers are included as well: parallel decision-var cost computation, node-level cluster_links, DecisionVar slots, skipping the enumeration-time redistribute cost, and a lighter no-PuLP build for the approximate solver.

Authored with Claude.

API

AutoParallel(..., solver="ilp" | "approx" | "lp") selects how the optimizer is built and solved; optimize_placement(solver=..., optimality_check=...) can override the solve.

How to Test

  • Unit tests: python -m pytest tests/ (fake PG, no GPU needed)
    • tests/test_approximate_sharding.py — TRW-S objective vs ILP, faithfulness, input/output-constraint honoring, lite-build parity, memory-constrained (Lagrangian) vs ILP.
    • tests/test_lp_relaxation.py — LP relaxation matches the ILP optimum and certifies the example search.
    • tests/test_optimize_placement.py — solver selection.
  • tests/conftest.py adds an autouse fixture clearing the process-global placement-options cache between tests (fixes cross-test contamination).

Notes for reviewers

  • Suggested reading order: LP relaxation -> approximate/TRW-S solver -> Lagrangian memory solve -> build speedups.
  • This branch also carries some EXPERIMENTAL, off-by-default scaffolding (Shardy-like annotation propagation, a DP-solver stub). It is opt-in only and not part of the supported solve path; it is documented as experimental and the default flow never invokes it.
  • Qwen3 model + examples are shipped in a separate PR and are excluded here.

AlbedoWang added 28 commits May 26, 2026 11:37
Record optimizer setup and solve profiling in ShardingOptimizer, add a contributor pipeline document, and include the profiling result artifacts used to inspect LLaMA and Qwen behavior.

Authored with Claude.
Adds LP-relaxation lower-bound plumbing and initial DP topology construction coverage, while removing generated profile/log artifacts from tracking and ignoring future outputs.

Authored with Claude.
Snapshot the current working tree before adding the approximate sharding
solver. Tracks the scratch _bench_lp_3d.py benchmark and adds *.pdf to
.gitignore so reference papers stay out of git history.

Authored with Claude.
Adds a heuristic alternative to the ILP for the placement problem,
formulated as pairwise MRF energy minimization on the strategy DAG and
solved with a sequential min-sum belief propagation over coupled groups,
followed by coordinate-descent and star-block local search. The energy is
an exact transcription of the ILP objective, so the assignment is scored
identically and the gap is small (LP-certified within ~3-8% on LLaMA3 1B),
while the solve runs ~10x faster than CBC and works on 3D meshes where the
ILP is intractable. Exposed via optimize_placement(solver="approx").

Review order: optimize_sharding.py (idempotent _set_objective) and api.py
(solver dispatch) are the integration points; approximate_sharding.py is the
solver; test_approximate_sharding.py checks the objective gap, energy
faithfulness, and flow feasibility against the ILP.

Authored with Claude.
The optimizer build (strategy enumeration, decision vars, PuLP variables and
constraints) dominates end-to-end time, especially on 3D meshes where it
constructs ~14M PuLP variables and ~6M constraints that the approximate solver
never needs. Two result-preserving changes cut build time:

- Hoist the per-node _all_input_nodes / producer-strategy lookups out of the
  inner decision-var loops (they were recomputed once per decision variable,
  ~14M times on 3D); this also speeds up the ILP build.
- Add ShardingOptimizer(build_pulp=False), selected via
  AutoParallel(solver="approx"), which skips PuLP variable and constraint
  construction entirely. The approximate solver then derives the constraint
  topology directly from the graph + cluster links + constraint log
  (_topology_direct), verified byte-identical to parsing the PuLP constraints.

On LLaMA3 1B the build drops ~2.1x (2D) and ~3.3x (3D, ~13min -> ~4min) with
byte-identical placements; 3D end-to-end goes ~17min -> ~5min.
test_lite_build_matches_full guards the equivalence.

Authored with Claude.
The sharding ILP's LP relaxation is naturally integral, so CBC reaches the
optimum at the root with zero branch-and-bound. The solve time was dominated
by CBC's integer preprocessing churning through hundreds of thousands of
binary columns, ~30% of which are invalid (infinite-cost) strategy edges the
optimizer materialized only to immediately constrain to zero.

To review, start with optimize_sharding.py: _build_decision_vars now computes
each edge's cost up front and only creates a variable when it is finite,
recording the survivors in _valid_keys. The constraint builders and
_create_pulp_variables tolerate the pruned keys (a missing key is an empty,
i.e. zero, term), the same-output and flow constraints key explicitly by
output index instead of relying on positional alignment, and
add_inf_cost_constraint becomes a no-op for fresh builds. _solve then passes
"preprocess off" to CBC. serialization.py seeds _valid_keys on load so saved
optimizers match freshly built ones, and test_optimize_placement.py adds a
regression test for the invariant.

On LLaMA3-1B with a 2D mesh this drops the problem from 476176 to 335390
variables and 173442 to 29643 constraints, and the solve from ~66s to ~11s,
with the objective unchanged (48449.3483).

Authored with Claude.
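The pruning idea in the commit above can be sketched as follows. This is a simplified illustration under assumed names: the key layout `(node, arg, out_idx, inp_idx)` and the helpers `build_valid_keys`, `objective`, and `is_forbidden` are hypothetical and do not reproduce the actual `_build_decision_vars` or `_valid_keys` code.

```python
import math


def build_valid_keys(edge_costs):
    """Keep only finite-cost edges; pruned keys never get a variable."""
    return {key for key, cost in edge_costs.items() if math.isfinite(cost)}


def is_forbidden(key, valid_keys):
    """A key absent from the surviving set is treated as forbidden."""
    return key not in valid_keys


def objective(edge_costs, valid_keys, assignment):
    """Sum cost * value over surviving keys.

    A key missing from `assignment` contributes an empty (zero) term,
    so pruned edges need no explicit "fix to zero" constraint.
    """
    return sum(edge_costs[key] * assignment.get(key, 0) for key in valid_keys)
```

Dropping the infinite-cost edges before any variable is created is what shrinks the problem; the consumers only need to tolerate missing keys.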
…LP search space

Users can express a tensor-parallel plan as a few annotations and have it propagated through the graph, turning the unambiguous part into ILP constraints while leaving the genuine cost tradeoffs (FSDP/data axis, residual sequence-parallelism, collective placement) to the solver.

Review in this order: propagation.py introduces the propagation engine (per-mesh-axis, reshard-free, worklist fixpoint with priority rounds, pinning only Shard placements so the optimum stays reachable); optimize_sharding.py adds the primitives it emits -- per-axis node constraints (add_node_axis_constraint, with method="fix" that prunes decision variables instead of adding equality rows), memory-budget awareness of per-axis-pinned params, and solve_lp_relaxation for diagnosing/short-circuiting the solve; api.py exposes the user-facing annotate_* and propagate_annotations entry points. Then tests, example, and docs.

On LLaMA3-1B (2D mesh) the annotated path reaches the same objective as the full ILP on a ~36% smaller search space and solves faster. The LP relaxation is integral on this problem, so solve_lp_relaxation(extract=True) gives an even larger, exact speedup.

Authored with Claude.
…merge_joint_opt

# Conflicts:
#	autoparallel/api.py
#	autoparallel/optimize_sharding.py
#	examples/example_llama3.py
Make the approximate (dp) solver work with the pruned search space and the
propagated per-axis annotations, the two pieces neither branch had on its own:

- Pruning removes infinite-cost edges from decision_vars entirely, so the
  approx solver must treat a key absent from decision_vars as forbidden
  (_is_forbidden) and read per-strategy costs from any surviving inp_idx
  (_surviving_dv). Applied across the forbidden checks and decision_var reads.
- Replay add_node_axis_constraint from _constraint_log in both the PuLP and
  the lite (no-PuLP) topology paths so propagated Shard pins restrict the
  approx solver's per-node out_idx domain (method="fix" leaves no PuLP row).
- Port the forward param-dtype constraint (current main) into _topology_direct
  so the lite build matches the full build exactly under mixed precision.
- Guard _fix_node_output_indices / add_node_axis_constraint against pruned
  (None) variables and the lite build.

Authored with Claude.
A loaded optimizer (ShardingOptimizer.load) is built via __new__ and never ran
the dp_solver init-time profiling, so resolve()/get_solution() -> _log_solve_profile
hit a missing self.profile. Guard the solve profiler to no-op without init
timings, and initialize profile/build_pulp/_node_axis_constraints/_fixed_vars in
load_optimizer so loaded optimizers carry the full attribute set.

Authored with Claude.
The LP relaxation lower bound must include the parameter-memory budget, or it
bounds a different (unconstrained) problem and reads below the true ILP optimum.
With the fix the LP bound equals the exact constrained optimum on LLaMA3-1B,
making it a tight optimality certificate (used for the 3D gap, where the ILP is
intractable).

Authored with Claude.
_bench_merge.py compares the four configurations (prune ILP, annotated ILP,
prune+dp approx, prune+dp+annotated) on one traced model, reporting per-phase
timings, objectives, the LP-relaxation optimality certificate, and the
acceptance checks. _bench_dp_alone.py isolates the approx-without-prune baseline
(run against the dp_solver checkout) for the dp-alone comparison.

Authored with Claude.
At full-1B 3D scale the PuLP problem has ~8M binary variables (strategy count is
rank x mesh-dims, independent of tensor size and -- via clustering -- of layer
count), so the exact ILP is intractable. _bench_3d_cert.py certifies the merged
gap on full 3D via the LP-relaxation lower bound (tight: it equals the exact
optimum on 2D). _bench_dp_alone.py gains a MERGED flag (annotate+propagate) and
_bench_merge.py a MODEL=small mode.

Authored with Claude.
CBC's simplex on the 8M-variable 3D LP runs for hours; HiGHS solves it in
minutes. Validated on 2D: HiGHS lower bound (72011.5) matches CBC and the exact
ILP optimum to the decimal. The cert now does one full build -> prune+dp +
merged approx objectives + HiGHS LP lower bound -> certified gaps.

Authored with Claude.
…fast build)

Strategy enumeration fills each OpSpec's redistribute_cost via torch's
generate_redistribute_costs (~50% of 3D build time per py-spy), but
_build_decision_vars overwrites every edge with the NCCL-aware
estimate_strategy_comms_cost, and nothing reads the enumeration costs in between
(remove_invalid_configs/keep_unique_configs select on placements/shapes only).
So during build_sharding_metadata we patch torch's _ops.utils.redistribute_cost
to a structure-preserving dummy. Autoparallel's own cost model uses a separate
redistribute_cost and is unaffected. A/B verified byte-identical decision_vars
(dv_hash) and approx objective on tiny + 1B/2D; toggle via AP_FAST_BUILD=0.

Authored with Claude.
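As a rough illustration of the patching trick in the commit above, the sketch below temporarily swaps an expensive cost function for a dummy that returns the same kind of value. `enum_utils`, `enumerate_costs`, and `dummy_redistribute_cost` are hypothetical stand-ins; the sketch does not touch torch and does not claim to match how the PR patches it.

```python
from contextlib import contextmanager
from types import SimpleNamespace
from unittest import mock


def _real_redistribute_cost(src, dst):
    # Stand-in for an expensive per-edge cost estimate.
    return float(abs(src - dst)) * 1.5


# Hypothetical stand-in for the module whose function gets patched.
enum_utils = SimpleNamespace(redistribute_cost=_real_redistribute_cost)


def enumerate_costs(sources, dst):
    # The function is looked up on the module at call time, so a patch
    # applied to the module attribute is visible here.
    return [enum_utils.redistribute_cost(src, dst) for src in sources]


@contextmanager
def dummy_redistribute_cost(enabled=True):
    """Replace the cost function with a cheap dummy of the same shape."""
    if not enabled:
        yield
        return
    with mock.patch.object(enum_utils, "redistribute_cost",
                           lambda src, dst: 0.0):
        yield
```

The dummy keeps the result structure (one float per edge) intact, which matters only if later code overwrites those values anyway, as the commit message states is the case here.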
…Var slots

create_cluster_links materialized one dict entry per (arg,out,inp) option-tuple
per cluster copy (~120M entries, ~80s, huge memory on 3D), but the mapping is
purely node-level (copy->root, identical option indices) and every consumer
reduced it back to node level. Store cluster_links as {copy_node_idx:
root_node_idx} and reconstruct option keys on demand (_cluster_root_key /
_linked_option_keys / _root_to_copies). Serialization already used the
node-level form on disk. Also @dataclass(slots=True) on DecisionVar (millions of
instances). A/B verified byte-identical decision_vars + objective vs the prior
commit (tiny + 1B/2D); all 50 cluster/serialization/approx/propagation tests pass.

Authored with Claude.
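The node-level mapping described in the commit above might look like the sketch below. It assumes option keys have the layout `(node_idx, arg, out_idx, inp_idx)`; that layout and the helper bodies are assumptions for illustration, loosely named after the helpers mentioned in the commit rather than copied from them.

```python
def cluster_root_key(cluster_links, key):
    """Map an option key on a cluster copy to the same key on its root."""
    node, arg, out_idx, inp_idx = key
    # Nodes that are not copies map to themselves.
    return (cluster_links.get(node, node), arg, out_idx, inp_idx)


def root_to_copies(cluster_links):
    """Invert {copy_node_idx: root_node_idx} into {root: [copies]}."""
    copies = {}
    for copy_node, root_node in cluster_links.items():
        copies.setdefault(root_node, []).append(copy_node)
    return copies


def linked_option_keys(cluster_links, root_option_key):
    """Rebuild, on demand, the option keys of every copy of a root node."""
    root, arg, out_idx, inp_idx = root_option_key
    inverse = root_to_copies(cluster_links)
    return [(copy, arg, out_idx, inp_idx) for copy in inverse.get(root, [])]
```

Because the option indices are identical between a copy and its root, one dict entry per copy node carries the same information as one entry per option tuple.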
decision_var_build (estimate_strategy_comms_cost over millions of edges) is the
last build bottleneck and is per-node independent. Split _build_decision_vars
into Phase A (compute per-edge costs, fork-parallel) + Phase B (assemble
DecisionVars / PuLP vars, serial). Workers read the optimizer from the
fork-inherited address space (no pickling of the mesh / strategy graph) and
return only primitive cost tuples; the deterministic computation makes the
result byte-identical to serial. Workers fork before any PuLP object exists.

Cumulative build result on LLaMA3-1B 3D (2,4,8): 777s -> 62s (12.5x), now
comparable to the ~50s approximate solve. A/B byte-identical (tiny + 1B/2D);
3D end-to-end objective unchanged (50222.7); all 50 build/approx/serialization/
propagation tests pass. Serial fallback via AP_PARALLEL_BUILD=1.

Authored with Claude.
examples/_bench_sizes.py runs the prune+dp approximate search across LLaMA3
1B/3B/8B/70B on a configurable mesh, reporting end-to-end latency (lite build +
approx solve) and an accuracy reference: the gap of the approximate objective
against a HiGHS LP-relaxation lower bound (the sharding LP is integral, so the
bound equals the exact ILP optimum). Controlled via MODEL/MESH/SEQLEN/ACCURACY/
LP_METHOD env vars.

Authored with Claude.
…ild work

The approximate solver's min-sum belief propagation settled into globally
inconsistent fixed points on the sharding MRF (the undirected factor graph is
loopy: residual and multi-branch reconvergence give ~129 cycles after
clustering), leaving the objective 5-16% above the optimum on 2D and up to 12%
on 3D. The factor graph and objective are faithful and the optimum is
representable (verified against an exact CBC solve on 2D and an integral LP on
3D), so this was purely a solver failure.

_belief_propagation now runs tree-reweighted sequential message passing (TRW-S):
a node ordering induces monotonic chains, each node is reweighted by
1/max(in,out)-degree, and forward/backward half-sweeps send min-sum messages
only along the pass direction. On this integral problem TRW-S converges to the
exact MAP: the bare approx (no annotation) drops to +0.00% on 2D (1B/3B/8B/70B,
matching CBC) and +0.08-0.82% on 3D, ~20-30x faster than solving the LP. The
decoded energy converges in long irregular plateaus, so a fixed sweep budget
(time-bounded) is used rather than an early-stop heuristic; the now-dominated
greedy second candidate is dropped.

Two algorithm-preserving build speedups also land here: validate() skips
cluster-copy nodes (the root covers them), and graph clustering memoizes each
node's op-strategy string instead of rebuilding it per consumer.

Authored with Claude.
Helper scripts used to diagnose and validate the TRW-S fix: factor-graph
faithfulness/representability check, LP integrality check, hyperparameter and
iterated-local-search sweeps, a standalone TRW-S prototype, an annotation
ablation, and a per-phase build profiler.

Authored with Claude.
…inks

examples/_sanity_llama3.py traces LLaMA3, selects a strategy with the approximate
(TRW-S) solver, applies it as DTensor, and trains a fixed random batch for a few
steps on real GPUs over a 2D or 3D mesh, verifying the loss curve descends. Also
removes three dangling loss-curve symlinks left over from an earlier run.

Authored with Claude.
…filing

Adds solver="lp" (use the empirically-integral LP relaxation directly,
skipping branch-and-bound) and an optimality_check option that solves the
LP lower bound and logs the certified gap of the achieved objective. The
sanity script now reports steady-state per-step latency. Tests cover the
LP solver matching the ILP optimum and the gap logging.

Authored with Claude.
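The certified gap mentioned in the commit above is simple arithmetic: the LP relaxation gives a lower bound on the optimum, so the relative distance between the achieved objective and that bound caps how far the solution can be from optimal. A minimal sketch follows; the function name `certified_gap` and the tolerance are hypothetical, and the PR's logging is not reproduced.

```python
def certified_gap(achieved, lp_lower_bound, tol=1e-9):
    """Relative gap of a feasible objective against an LP lower bound.

    The true optimum lies in [lp_lower_bound, achieved], so the returned
    value is an upper bound on the relative suboptimality.
    """
    if lp_lower_bound <= 0:
        raise ValueError("lower bound must be positive for a relative gap")
    if achieved < lp_lower_bound - tol:
        # A feasible solution can never beat a valid lower bound.
        raise ValueError("achieved objective is below the lower bound")
    return max(achieved - lp_lower_bound, 0.0) / lp_lower_bound
```

A gap of zero certifies optimality; a positive gap only bounds the error, because the LP bound itself may be below the integer optimum when the relaxation is not integral.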
The parameter-memory budget Sum_param ratio*x in [low,high] is a single
node-separable linear coupling, so penalizing it by lambda folds lambda*ratio
into the param unaries and leaves the pairwise MRF untouched. A scalar
bisection on lambda drives the achieved memory into the budget, and a
budget-constrained coordinate/star polish recovers integer solutions inside
the (memory, cost) hull that no single lambda reaches. Routed in only for
non-tight budgets; the tight default still uses build-time param pinning.

Authored with Claude.
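The λ-bisection in the commit above can be sketched on a deliberately simplified problem. The sketch assumes independent nodes (no pairwise terms, so the inner solve is a per-node argmin) and only an upper memory budget, whereas the PR handles a [low, high] range on the full MRF and adds a budget-constrained polish that is omitted here. All names and numbers are hypothetical.

```python
def solve_penalized(options, lam):
    """options[v] is a list of (cost, memory) pairs for node v.

    Folding lam * memory into each unary leaves an unconstrained problem,
    which for independent nodes is a per-node argmin.
    """
    return [min(range(len(opts)),
                key=lambda k: opts[k][0] + lam * opts[k][1])
            for opts in options]


def totals(options, x):
    cost = sum(options[v][k][0] for v, k in enumerate(x))
    memory = sum(options[v][k][1] for v, k in enumerate(x))
    return cost, memory


def bisect_lambda(options, budget, lam_hi=16.0, iters=40):
    """Return (cost, assignment) of the cheapest feasible solution found.

    Returns None if no tested multiplier yields a feasible assignment.
    """
    x = solve_penalized(options, 0.0)
    cost, memory = totals(options, x)
    if memory <= budget:
        return cost, x  # the budget is not binding
    lo, hi, best = 0.0, lam_hi, None
    for _ in range(iters):
        lam = 0.5 * (lo + hi)
        x = solve_penalized(options, lam)
        cost, memory = totals(options, x)
        if memory <= budget:
            hi = lam  # feasible: try a smaller penalty
            if best is None or cost < best[0]:
                best = (cost, x)
        else:
            lo = lam  # over budget: penalise memory harder
    return best
```

A single multiplier can only reach solutions on the lower convex hull of (memory, cost) pairs, which is the reason the commit pairs the bisection with a polish step for integer solutions inside the hull.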
The Shardy-like annotation propagation (annotate_* / propagate_annotations)
and the DP-based solver are opt-in and off by default: annotations do nothing
unless explicitly propagated before optimize_placement(), and the DP solver is
only reachable via the non-default solver_backend="dp" (not exposed through
AutoParallel) and still raises NotImplementedError. Document them as
experimental / unstable so the default solve path is unambiguous.

Authored with Claude.
…isolation

Excludes the Qwen3 model + examples (shipped in a separate PR) and the scratch
_bench_/_sanity scripts from this branch, and applies black/isort formatting.

Adds an autouse conftest fixture that clears the process-global
placement-options cache before each test, so an optimizer build never reuses
stale strategies from a prior test's model (this otherwise made
test_lp_relaxation fail when run after test_approximate_sharding).

Authored with Claude.
meta-cla bot added the CLA Signed label Jun 8, 2026
The parallel decision-var cost build forks workers, but forking a process that
has already initialized CUDA crashes them with "Cannot re-initialize CUDA in
forked subprocess" once they touch the NCCL cost model. Real-GPU runs (example
scripts, torchrun) and CUDA-touching tests hit this. Skip the fork and use the
byte-identical serial path whenever torch.cuda.is_initialized().

Authored with Claude.
sanketpurandare self-requested a review June 9, 2026 21:27