Add LP-relaxation and TRW-S approximate sharding solvers#484
Open
AlbedoWang wants to merge 29 commits into
Open
Add LP-relaxation and TRW-S approximate sharding solvers#484AlbedoWang wants to merge 29 commits into
AlbedoWang wants to merge 29 commits into
Conversation
Record optimizer setup and solve profiling in ShardingOptimizer, add a contributor pipeline document, and include the profiling result artifacts used to inspect LLaMA and Qwen behavior. Authored with Claude.
Adds LP-relaxation lower-bound plumbing and initial DP topology construction coverage, while removing generated profile/log artifacts from tracking and ignoring future outputs.\n\nAuthored with Claude.
Snapshot the current working tree before adding the approximate sharding solver. Tracks the scratch _bench_lp_3d.py benchmark and adds *.pdf to .gitignore so reference papers stay out of git history. Authored with Claude.
Adds a heuristic alternative to the ILP for the placement problem, formulated as pairwise MRF energy minimization on the strategy DAG and solved with a sequential min-sum belief propagation over coupled groups, followed by coordinate-descent and star-block local search. The energy is an exact transcription of the ILP objective, so the assignment is scored identically and the gap is small (LP-certified within ~3-8% on LLaMA3 1B), while the solve runs ~10x faster than CBC and works on 3D meshes where the ILP is intractable. Exposed via optimize_placement(solver="approx"). Review order: optimize_sharding.py (idempotent _set_objective) and api.py (solver dispatch) are the integration points; approximate_sharding.py is the solver; test_approximate_sharding.py checks the objective gap, energy faithfulness, and flow feasibility against the ILP. Authored with Claude.
The optimizer build (strategy enumeration, decision vars, PuLP variables and constraints) dominates end-to-end time, especially on 3D meshes where it constructs ~14M PuLP variables and ~6M constraints that the approximate solver never needs. Two result-preserving changes cut build time: - Hoist the per-node _all_input_nodes / producer-strategy lookups out of the inner decision-var loops (they were recomputed once per decision variable, ~14M times on 3D); this also speeds up the ILP build. - Add ShardingOptimizer(build_pulp=False), selected via AutoParallel(solver="approx"), which skips PuLP variable and constraint construction entirely. The approximate solver then derives the constraint topology directly from the graph + cluster links + constraint log (_topology_direct), verified byte-identical to parsing the PuLP constraints. On LLaMA3 1B the build drops ~2.1x (2D) and ~3.3x (3D, ~13min -> ~4min) with byte-identical placements; 3D end-to-end goes ~17min -> ~5min. test_lite_build_matches_full guards the equivalence. Authored with Claude.
The sharding ILP's LP relaxation is naturally integral, so CBC reaches the optimum at the root with zero branch-and-bound. The solve time was dominated by CBC's integer preprocessing churning through hundreds of thousands of binary columns, ~30% of which are invalid (infinite-cost) strategy edges the optimizer materialized only to immediately constrain to zero. To review, start with optimize_sharding.py: _build_decision_vars now computes each edge's cost up front and only creates a variable when it is finite, recording the survivors in _valid_keys. The constraint builders and _create_pulp_variables tolerate the pruned keys (a missing key is an empty, i.e. zero, term), the same-output and flow constraints key explicitly by output index instead of relying on positional alignment, and add_inf_cost_constraint becomes a no-op for fresh builds. _solve then passes "preprocess off" to CBC. serialization.py seeds _valid_keys on load so saved optimizers match freshly built ones, and test_optimize_placement.py adds a regression test for the invariant. On LLaMA3-1B with a 2D mesh this drops the problem from 476176 to 335390 variables and 173442 to 29643 constraints, and the solve from ~66s to ~11s, with the objective unchanged (48449.3483). Authored with Claude.
…LP search space Users can express a tensor-parallel plan as a few annotations and have it propagated through the graph, turning the unambiguous part into ILP constraints while leaving the genuine cost tradeoffs (FSDP/data axis, residual sequence-parallelism, collective placement) to the solver. Review in this order: propagation.py introduces the propagation engine (per-mesh-axis, reshard-free, worklist fixpoint with priority rounds, pinning only Shard placements so the optimum stays reachable); optimize_sharding.py adds the primitives it emits -- per-axis node constraints (add_node_axis_constraint, with method="fix" that prunes decision variables instead of adding equality rows), memory-budget awareness of per-axis-pinned params, and solve_lp_relaxation for diagnosing/short-circuiting the solve; api.py exposes the user-facing annotate_* and propagate_annotations entry points. Then tests, example, and docs. On LLaMA3-1B (2D mesh) the annotated path reaches the same objective as the full ILP on a ~36% smaller search space and solves faster. The LP relaxation is integral on this problem, so solve_lp_relaxation(extract=True) gives an even larger, exact speedup. Authored with Claude.
…aijian/merge_joint_opt
…merge_joint_opt # Conflicts: # autoparallel/api.py # autoparallel/optimize_sharding.py # examples/example_llama3.py
Make the approximate (dp) solver work with the pruned search space and the propagated per-axis annotations, the two pieces neither branch had on its own: - Pruning removes infinite-cost edges from decision_vars entirely, so the approx solver must treat a key absent from decision_vars as forbidden (_is_forbidden) and read per-strategy costs from any surviving inp_idx (_surviving_dv). Applied across the forbidden checks and decision_var reads. - Replay add_node_axis_constraint from _constraint_log in both the PuLP and the lite (no-PuLP) topology paths so propagated Shard pins restrict the approx solver's per-node out_idx domain (method="fix" leaves no PuLP row). - Port the forward param-dtype constraint (current main) into _topology_direct so the lite build matches the full build exactly under mixed precision. - Guard _fix_node_output_indices / add_node_axis_constraint against pruned (None) variables and the lite build. Authored with Claude.
A loaded optimizer (ShardingOptimizer.load) is built via __new__ and never ran the dp_solver init-time profiling, so resolve()/get_solution() -> _log_solve_profile hit a missing self.profile. Guard the solve profiler to no-op without init timings, and initialize profile/build_pulp/_node_axis_constraints/_fixed_vars in load_optimizer so loaded optimizers carry the full attribute set. Authored with Claude.
The LP relaxation lower bound must include the parameter-memory budget, or it bounds a different (unconstrained) problem and reads below the true ILP optimum. With the fix the LP bound equals the exact constrained optimum on LLaMA3-1B, making it a tight optimality certificate (used for the 3D gap, where the ILP is intractable). Authored with Claude.
_bench_merge.py compares the four configurations (prune ILP, annotated ILP, prune+dp approx, prune+dp+annotated) on one traced model, reporting per-phase timings, objectives, the LP-relaxation optimality certificate, and the acceptance checks. _bench_dp_alone.py isolates the approx-without-prune baseline (run against the dp_solver checkout) for the dp-alone comparison. Authored with Claude.
At full-1B 3D scale the PuLP problem has ~8M binary variables (strategy count is rank x mesh-dims, independent of tensor size and -- via clustering -- of layer count), so the exact ILP is intractable. _bench_3d_cert.py certifies the merged gap on full 3D via the LP-relaxation lower bound (tight: it equals the exact optimum on 2D). _bench_dp_alone.py gains a MERGED flag (annotate+propagate) and _bench_merge.py a MODEL=small mode. Authored with Claude.
CBC's simplex on the 8M-variable 3D LP runs for hours; HiGHS solves it in minutes. Validated on 2D: HiGHS lower bound (72011.5) matches CBC and the exact ILP optimum to the decimal. The cert now does one full build -> prune+dp + merged approx objectives + HiGHS LP lower bound -> certified gaps. Authored with Claude.
…fast build) Strategy enumeration fills each OpSpec's redistribute_cost via torch's generate_redistribute_costs (~50% of 3D build time per py-spy), but _build_decision_vars overwrites every edge with the NCCL-aware estimate_strategy_comms_cost, and nothing reads the enumeration costs in between (remove_invalid_configs/keep_unique_configs select on placements/shapes only). So during build_sharding_metadata we patch torch's _ops.utils.redistribute_cost to a structure-preserving dummy. Autoparallel's own cost model uses a separate redistribute_cost and is unaffected. A/B verified byte-identical decision_vars (dv_hash) and approx objective on tiny + 1B/2D; toggle via AP_FAST_BUILD=0. Authored with Claude.
…Var slots
create_cluster_links materialized one dict entry per (arg,out,inp) option-tuple
per cluster copy (~120M entries, ~80s, huge memory on 3D), but the mapping is
purely node-level (copy->root, identical option indices) and every consumer
reduced it back to node level. Store cluster_links as {copy_node_idx:
root_node_idx} and reconstruct option keys on demand (_cluster_root_key /
_linked_option_keys / _root_to_copies). Serialization already used the
node-level form on disk. Also @DataClass(slots=True) on DecisionVar (millions of
instances). A/B verified byte-identical decision_vars + objective vs the prior
commit (tiny + 1B/2D); all 50 cluster/serialization/approx/propagation tests pass.
Authored with Claude.
decision_var_build (estimate_strategy_comms_cost over millions of edges) is the last build bottleneck and is per-node independent. Split _build_decision_vars into Phase A (compute per-edge costs, fork-parallel) + Phase B (assemble DecisionVars / PuLP vars, serial). Workers read the optimizer from the fork-inherited address space (no pickling of the mesh / strategy graph) and return only primitive cost tuples; the deterministic computation makes the result byte-identical to serial. Workers fork before any PuLP object exists. Cumulative build result on LLaMA3-1B 3D (2,4,8): 777s -> 62s (12.5x), now comparable to the ~50s approximate solve. A/B byte-identical (tiny + 1B/2D); 3D end-to-end objective unchanged (50222.7); all 50 build/approx/serialization/ propagation tests pass. Serial fallback via AP_PARALLEL_BUILD=1. Authored with Claude.
examples/_bench_sizes.py runs the prune+dp approximate search across LLaMA3 1B/3B/8B/70B on a configurable mesh, reporting end-to-end latency (lite build + approx solve) and an accuracy reference: the gap of the approximate objective against a HiGHS LP-relaxation lower bound (the sharding LP is integral, so the bound equals the exact ILP optimum). Controlled via MODEL/MESH/SEQLEN/ACCURACY/ LP_METHOD env vars. Authored with Claude.
…ild work The approximate solver's min-sum belief propagation settled into globally inconsistent fixed points on the sharding MRF (the undirected factor graph is loopy: residual and multi-branch reconvergence give ~129 cycles after clustering), leaving the objective 5-16% above the optimum on 2D and up to 12% on 3D. The factor graph and objective are faithful and the optimum is representable (verified against an exact CBC solve on 2D and an integral LP on 3D), so this was purely a solver failure. _belief_propagation now runs tree-reweighted sequential message passing (TRW-S): a node ordering induces monotonic chains, each node is reweighted by 1/max(in,out)-degree, and forward/backward half-sweeps send min-sum messages only along the pass direction. On this integral problem TRW-S converges to the exact MAP: the bare approx (no annotation) drops to +0.00% on 2D (1B/3B/8B/70B, matching CBC) and +0.08-0.82% on 3D, ~20-30x faster than solving the LP. The decoded energy converges in long irregular plateaus, so a fixed sweep budget (time-bounded) is used rather than an early-stop heuristic; the now-dominated greedy second candidate is dropped. Two algorithm-preserving build speedups also land here: validate() skips cluster-copy nodes (the root covers them), and graph clustering memoizes each node's op-strategy string instead of rebuilding it per consumer. Authored with Claude.
Helper scripts used to diagnose and validate the TRW-S fix: factor-graph faithfulness/representability check, LP integrality check, hyperparameter and iterated-local-search sweeps, a standalone TRW-S prototype, an annotation ablation, and a per-phase build profiler. Authored with Claude.
…inks examples/_sanity_llama3.py traces LLaMA3, selects a strategy with the approximate (TRW-S) solver, applies it as DTensor, and trains a fixed random batch for a few steps on real GPUs over a 2D or 3D mesh, verifying the loss curve descends. Also removes three dangling loss-curve symlinks left over from an earlier run. Authored with Claude.
Authored with Claude.
…filing Adds solver="lp" (use the empirically-integral LP relaxation directly, skipping branch-and-bound) and an optimality_check option that solves the LP lower bound and logs the certified gap of the achieved objective. The sanity script now reports steady-state per-step latency. Tests cover the LP solver matching the ILP optimum and the gap logging. Authored with Claude.
The parameter-memory budget Sum_param ratio*x in [low,high] is a single node-separable linear coupling, so penalizing it by lambda folds lambda*ratio into the param unaries and leaves the pairwise MRF untouched. A scalar bisection on lambda drives the achieved memory into the budget, and a budget-constrained coordinate/star polish recovers integer solutions inside the (memory, cost) hull that no single lambda reaches. Routed in only for non-tight budgets; the tight default still uses build-time param pinning. Authored with Claude.
The Shardy-like annotation propagation (annotate_* / propagate_annotations) and the DP-based solver are opt-in and off by default: annotations do nothing unless explicitly propagated before optimize_placement(), and the DP solver is only reachable via the non-default solver_backend="dp" (not exposed through AutoParallel) and still raises NotImplementedError. Document them as experimental / unstable so the default solve path is unambiguous. Authored with Claude.
…isolation Excludes the Qwen3 model + examples (shipped in a separate PR) and the scratch _bench_/_sanity scripts from this branch, and applies black/isort formatting. Adds an autouse conftest fixture that clears the process-global placement-options cache before each test, so an optimizer build never reuses stale strategies from a prior test's model (this otherwise made test_lp_relaxation fail when run after test_approximate_sharding). Authored with Claude.
The parallel decision-var cost build forks workers, but forking a process that has already initialized CUDA crashes them with "Cannot re-initialize CUDA in forked subprocess" once they touch the NCCL cost model. Real-GPU runs (example scripts, torchrun) and CUDA-touching tests hit this. Skip the fork and use the byte-identical serial path whenever torch.cuda.is_initialized(). Authored with Claude.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds two faster alternatives to the exact ILP for the sharding-placement solve:
E(x) = Σ_v U_v(x_v) + Σ_(u,v) B_uv(x_u,x_v)over coupled groups. We solve it with sequentialtree-reweighted message passing (TRW-S), which replaces plain loopy min-sum BP (that settled into globally-inconsistent fixed points), then polish with group coordinate + star-block local search. The solver reuses the strategies / decision-vars / constraints already
built by
ShardingOptimizer(it replaces only the solve, not problem construction) and writes its assignment back into the PuLP variables, so it is scored with the exact same objective as the ILP. ~20-30x faster than CBC at 2D +0.00% / 3D <0.82% objective gap.solve_lp_relaxation()plus an LP-based solve used directly: the relaxation is empirically integral for this problem, so it matches the ILP optimum while skipping branch-and-bound, with an optimality-check that logs the certified gap.It also adds memory-budget support to the approximate path via Lagrangian relaxation (the parameter-memory budget is node-separable, so it folds into the unaries; a scalar λ-bisection plus a budget-constrained polish lands on the integer optimum), and build-time
speedups shared by all solvers (parallel decision-var cost, node-level
cluster_links,DecisionVarslots, skip enumeration redistribute-cost, lighter no-PuLP build for the approximate solver).Authored with Claude.
API
AutoParallel(..., solver="ilp" | "approx" | "lp")selects how the optimizer is built and solved;optimize_placement(solver=..., optimality_check=...)can override the solve.How to Test
python -m pytest tests/(fake PG, no GPU needed)tests/test_approximate_sharding.py— TRW-S objective vs ILP, faithfulness, input/output-constraint honoring, lite-build parity, memory-constrained (Lagrangian) vs ILP.tests/test_lp_relaxation.py— LP relaxation matches the ILP optimum and certifies the example search.tests/test_optimize_placement.py— solver selection.tests/conftest.pyadds an autouse fixture clearing the process-global placement-options cache between tests (fixes cross-test contamination).Notes for reviewers