Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1a4d279
Add Qwen3 AutoParallel model and examples
AlbedoWang May 26, 2026
b02ac05
Add sharding optimizer profiling snapshot
AlbedoWang May 28, 2026
2f4f102
Add LP relaxation support for sharding optimizer
AlbedoWang May 28, 2026
ad7ee80
Checkpoint scratch LP benchmark and ignore reference PDFs
AlbedoWang May 29, 2026
6613928
Add approximate belief-propagation sharding solver
AlbedoWang May 30, 2026
d06957f
Speed up optimizer build by skipping PuLP for the approximate solver
AlbedoWang May 30, 2026
17fdb4e
Prune invalid sharding strategies and skip CBC integer preprocessing
AlbedoWang May 30, 2026
238443e
Add sharding annotations with Shardy-like propagation to shrink the I…
AlbedoWang May 30, 2026
ed5bfb3
Merge remote-tracking branch 'origin/kaijian/annotated_search' into k…
AlbedoWang May 31, 2026
ecdb8d5
Merge remote-tracking branch 'origin/kaijian/dp_solver' into kaijian/…
AlbedoWang May 31, 2026
c33c0ef
Integrate prune + dp_solver + annotated into a joint optimization
AlbedoWang May 31, 2026
f7af135
Fix loaded-optimizer resolve() under dp_solver profiling
AlbedoWang May 31, 2026
b767f2d
Apply the memory constraint in get_lower_bound
AlbedoWang May 31, 2026
6fcf844
Add joint-optimization benchmark for LLaMA3 on 2D/3D meshes
AlbedoWang May 31, 2026
e8689cd
Extend benches for 3D: MODEL=small, MERGED flag, LP-bound certificate
AlbedoWang May 31, 2026
523f3aa
Use HiGHS (scipy.linprog) for the 3D LP-bound certificate
AlbedoWang May 31, 2026
fc434d5
Skip enumeration redistribute-cost computation (algorithm-preserving …
AlbedoWang May 31, 2026
c78555a
Store cluster_links node-level (drop per-option expansion) + Decision…
AlbedoWang May 31, 2026
f493ab8
Parallelize decision-var cost computation across forked workers
AlbedoWang May 31, 2026
496e7b3
Add cross-size prune+dp benchmark (latency + LP-relaxation accuracy)
AlbedoWang May 31, 2026
2f06359
Approx solver: replace loopy min-sum BP with TRW-S; skip-clustered bu…
AlbedoWang Jun 1, 2026
2ce86ea
Add approx-solver diagnostic + accuracy benchmarks
AlbedoWang Jun 1, 2026
ba02ea5
Add real-GPU LLaMA3 training sanity check; drop stale loss-curve syml…
AlbedoWang Jun 1, 2026
cbd9575
Drop committed loss-curve/profiling artifacts; ignore png/svg/csv
AlbedoWang Jun 1, 2026
a7ea958
Add LP-relaxation solver, optimality-check gap logging, step-time pro…
AlbedoWang Jun 6, 2026
1435b7b
Approx solver: memory-constrained solve via Lagrangian relaxation
AlbedoWang Jun 7, 2026
99339c3
Mark annotation propagation and DP solver as experimental
AlbedoWang Jun 7, 2026
65c9503
Prepare approx-solver PR: drop qwen3/scratch benches, fix test cache …
AlbedoWang Jun 8, 2026
8eda68d
Fall back to serial build when CUDA is initialized (fix fork crash)
AlbedoWang Jun 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
*.pyc
*.pyo
*.so
*.log

.mypy_cache/
*.egg-info/

*.pdf
*.png
*.svg
*.csv

build/
dist/
tmp/
out/
profile_results/

.vscode/
293 changes: 290 additions & 3 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from .module_construction import make_parallel_module
from .optimize_sharding import ShardingOptimizer
from .propagation import ShardingAnnotation, ShardingPropagator
from .shardings.placement_options import _get_device_from_mesh
from .tracing import (
_add_unused_params_and_buffers,
Expand Down Expand Up @@ -193,6 +194,11 @@ class AutoParallel:
The meta model is moved to a fake device based on mesh.device_type.
"""

# Selectable solvers. "ilp": exact PuLP/CBC. "approx": heuristic TRW-S
# (light build, no PuLP). "lp": LP relaxation used directly as the solve
# (empirically integral for this problem, so much cheaper than CBC).
SOLVER_CHOICES = ("ilp", "approx", "lp")

def __init__(
self,
model,
Expand All @@ -203,8 +209,20 @@ def __init__(
dynamic: bool = False,
cost_model: Any = "nccl",
repeated_subgraphs: bool = True,
solver: str = "ilp",
):
self.stack = ExitStack()
# The solver chosen here decides how the optimizer is built: "ilp"/"lp"
# build the full PuLP problem (CBC exact solve / LP relaxation solve);
# "approx" builds a lighter optimizer (no PuLP variables/constraints),
# much faster to construct, solved heuristically. optimize_placement(
# solver=...) may override the solve as long as it is compatible with
# this build.
if solver not in self.SOLVER_CHOICES:
raise ValueError(
f"Unknown solver={solver!r}; expected one of {self.SOLVER_CHOICES}"
)
self.solver = solver
self.fake_mode = (
FakeTensorMode()
) # TODO: maybe need to reuse the model's fake mode
Expand Down Expand Up @@ -281,12 +299,15 @@ def __enter__(self):
self.mesh,
force_grad_reduce_in_higher_precision,
repeated_subgraphs=self.repeated_subgraphs,
build_pulp=self.solver in ("ilp", "lp"),
)

self.sharding_optimizer = sharding_optimizer

self.input_constraints = None
self.output_constraints = None
self._annotations: list[tuple[Any, ShardingAnnotation]] = []
self.propagation_result = None

self.active = True

Expand Down Expand Up @@ -356,10 +377,240 @@ def add_output_constraints(self, constraints):
self.sharding_optimizer.add_sharded_output_constraint(constraints)
self.output_constraints = constraints

def optimize_placement(self, verbose=True):
# ---- Sharding annotations (Shardy-like propagation) ----
# EXPERIMENTAL: opt-in only. These have no effect unless you call an
# annotate_* method and then propagate_annotations() before
# optimize_placement(); the default solve path never invokes them. The
# propagation may shrink the search space in ways that move the objective off
# the full-ILP optimum, so treat results as unstable.

def _normalize_placements(self, placements):
    """Return ``placements`` as a tuple of exactly ``mesh.ndim`` entries.

    Shorter inputs are right-padded with ``None`` (an open mesh axis);
    inputs longer than the mesh rank are rejected.

    Raises:
        ValueError: if more placements are given than the mesh has dims.
    """
    placements = tuple(placements)
    missing = self.mesh.ndim - len(placements)
    if missing < 0:
        raise ValueError(
            f"annotation has {len(placements)} placements but mesh has "
            f"{self.mesh.ndim} dims"
        )
    # Trailing axes that were not specified stay unconstrained.
    return placements + (None,) * missing

def _param_fqn_to_node(self):
    """Map each parameter descriptor's ``target`` to its graph node.

    The mapping is built from ``get_param_and_grad_nodes`` on the sharding
    optimizer's graph; the gradient half of each pair is ignored.
    """
    from torch._functorch._aot_autograd.fx_utils import get_param_and_grad_nodes

    param_nodes = get_param_and_grad_nodes(self.sharding_optimizer.graph)
    by_fqn = {}
    for desc, (node, _grad) in param_nodes.items():
        by_fqn[desc.target] = node
    return by_fqn

def annotate_parameter(self, fqn, placements, priority=1):
    """Annotate the sharding of one or more parameters.

    ``fqn`` is a parameter fully-qualified name, optionally a glob pattern
    (e.g. ``"layers.*.attention.wq.weight"``) to annotate the matching
    parameter in every layer at once. ``placements`` is a tuple of
    :class:`Placement` (or ``None`` to leave a mesh axis open — typical for
    the data/FSDP axis of a weight). Weights default to a lower priority
    than activations so the data-parallel axis wins shared-axis conflicts.

    Matching is case-sensitive on every platform.

    Returns:
        The list of graph nodes that matched ``fqn``.

    Raises:
        ValueError: if no parameter matches ``fqn``, or if ``placements``
            has more entries than the mesh has dims.
    """
    import fnmatch

    placements = self._normalize_placements(placements)
    fqn_map = self._param_fqn_to_node()
    # fnmatchcase, not fnmatch: fnmatch() runs both strings through
    # os.path.normcase, which folds case (and rewrites separators) on
    # Windows. Parameter FQNs are identifiers, not filesystem paths.
    matched = [
        node for name, node in fqn_map.items() if fnmatch.fnmatchcase(name, fqn)
    ]
    if not matched:
        raise ValueError(
            f"No parameter matches {fqn!r}. Available parameters: "
            f"{sorted(fqn_map)}"
        )
    for node in matched:
        self._annotations.append((node, ShardingAnnotation(placements, priority)))
    return matched

def annotate_input(self, idx, placements, priority=0):
    """Register a sharding annotation on the graph input with index ``idx``.

    Returns the annotated input node. Raises ``ValueError`` when no plain
    graph input carries that index.
    """
    from torch._functorch._aot_autograd.fx_utils import (
        get_plain_input_and_grad_nodes,
    )

    placements = self._normalize_placements(placements)
    input_nodes = get_plain_input_and_grad_nodes(self.sharding_optimizer.graph)
    # Index the forward input nodes by descriptor index; grads are unused here.
    nodes = {}
    for desc, (node, _grad) in input_nodes.items():
        nodes[desc.idx] = node
    if idx not in nodes:
        raise ValueError(f"No graph input with index {idx}; have {sorted(nodes)}")
    target = nodes[idx]
    self._annotations.append((target, ShardingAnnotation(placements, priority)))
    return target

def annotate_output(self, idx, placements, priority=0):
    """Register a sharding annotation on the graph output with index ``idx``.

    Returns the annotated output node. Raises ``ValueError`` when no plain
    graph output carries that index.
    """
    from torch._functorch._aot_autograd.fx_utils import (
        get_plain_output_and_tangent_nodes,
    )

    placements = self._normalize_placements(placements)
    output_nodes = get_plain_output_and_tangent_nodes(self.sharding_optimizer.graph)
    # Index the forward output nodes by descriptor index; tangents are unused here.
    nodes = {}
    for desc, (node, _t) in output_nodes.items():
        nodes[desc.idx] = node
    if idx not in nodes:
        raise ValueError(f"No graph output with index {idx}; have {sorted(nodes)}")
    target = nodes[idx]
    self._annotations.append((target, ShardingAnnotation(placements, priority)))
    return target

def annotate_node(self, node, placements, priority=0):
    """Register a sharding annotation on ``node`` and return that node.

    ``placements`` is padded to the mesh rank before being recorded.
    """
    annotation = ShardingAnnotation(self._normalize_placements(placements), priority)
    self._annotations.append((node, annotation))
    return node

def _mirror_annotations_to_backward(self):
    """Return propagation seeds for the backward twins of annotated nodes.

    A gradient shares the sharding of the value it is the gradient of, so
    each annotated forward node that has a twin (parameter->grad,
    input->grad, output->tangent) yields an extra ``(twin, annotation)``
    seed. These seeds are meant for propagation only, letting the plan flow
    through the backward pass; the twins themselves are left to the
    forward/backward consistency constraints.

    Returns:
        A list of ``(twin_node, annotation)`` pairs, in annotation order,
        for the annotated nodes that have a non-``None`` twin.
    """
    from torch._functorch._aot_autograd.fx_utils import (
        get_param_and_grad_nodes,
        get_plain_input_and_grad_nodes,
        get_plain_output_and_tangent_nodes,
    )

    graph = self.sharding_optimizer.graph
    pair_sources = (
        get_param_and_grad_nodes,
        get_plain_input_and_grad_nodes,
        get_plain_output_and_tangent_nodes,
    )
    # forward node -> backward twin; later sources overwrite earlier ones.
    twin = {}
    for source in pair_sources:
        for forward, backward in source(graph).values():
            if backward is not None:
                twin[forward] = backward

    return [(twin[node], ann) for node, ann in self._annotations if node in twin]

def propagate_annotations(self, verbose=True, aggressive=False, method="fix"):
    """EXPERIMENTAL (opt-in, off by default; may be unstable).

    Run Shardy-style propagation from the registered annotations (plus
    their backward twins) and apply the determined nodes to the sharding
    optimizer, shrinking the search space. The outcome is stored on
    ``self.propagation_result`` and returned (a :class:`PropagationResult`).

    Invoke after the ``annotate_*`` / ``add_*_constraint`` calls and before
    :meth:`optimize_placement`; the default solve path never calls this.

    ``aggressive=False`` (default) pins only genuine ``Shard`` axes, which
    keeps the full-ILP optimum reachable. ``aggressive=True`` additionally
    pins ``Replicate`` / ``Partial`` axes for a larger reduction, at the
    cost of possibly forbidding cheaper reshard placements (e.g. sequence
    parallelism), so the objective may move slightly off the optimum.

    ``method`` selects how a pin is enforced: ``"fix"`` (default) removes
    the ruled-out decision variables (shrinks the problem; scales best on
    large meshes); ``"constraint"`` adds removable ``== 1`` rows instead.
    """
    self._assert_entered()

    seeds = self._annotations + self._mirror_annotations_to_backward()
    propagator = ShardingPropagator(self.sharding_optimizer)
    propagator.run(seeds)

    result = propagator.apply_to_optimizer(aggressive=aggressive, method=method)
    self.propagation_result = result

    if verbose:
        logger.info(
            "Annotation propagation reduced the output-strategy search "
            "space by %.1f%% (%d -> %d) via %d per-axis constraints on %d "
            "nodes",
            100.0 * result.reduction,
            result.strategies_before,
            result.strategies_after,
            result.axis_constraints,
            result.nodes_determined,
        )
    return result

def optimize_placement(
self,
verbose=True,
solver=None,
approximate_options=None,
optimality_check=False,
):
"""Solve for the optimal placement.

solver selects how the placement is solved (defaults to the solver chosen
at AutoParallel construction):
- "ilp": exact PuLP/CBC solve.
- "approx": heuristic TRW-S ApproximateShardingSolver — trades a small
objective gap for a much faster solve.
- "lp": solve the LP relaxation and use it directly. This problem is
empirically integral, so the relaxation optimum equals the ILP optimum
while skipping branch-and-bound; raises if it comes out fractional.
approximate_options is forwarded as kwargs to the approximate solver
(e.g. candidate_limit, max_sweeps). The requested solver must be
compatible with how the optimizer was built: "ilp"/"lp" need a PuLP
problem (build with solver="ilp" or "lp").

optimality_check: after solving, solve the LP relaxation as a lower bound
and log the certified gap of the achieved objective from the optimum.
Requires a PuLP problem (i.e. an "ilp"/"lp" build).
"""
self._assert_entered()
if solver is None:
solver = self.solver

opt = self.sharding_optimizer
if solver in ("approx", "approximate"):
from .approximate_sharding import ApproximateShardingSolver

approx = ApproximateShardingSolver(opt, **(approximate_options or {}))
self.sharding_placement = approx.get_solution(verbose=verbose)
elif solver == "ilp":
if opt.prob is None:
raise RuntimeError(
"solver='ilp' requires a PuLP problem, but this AutoParallel "
"was constructed without one (e.g. solver='approx'). "
"Construct with solver='ilp' to use the exact solver."
)
self.sharding_placement = opt.get_solution(verbose=False)
elif solver in ("lp", "lp_relax", "lp_relaxation"):
if opt.prob is None:
raise RuntimeError(
"solver='lp' requires a PuLP problem, but this AutoParallel "
"was constructed without one (e.g. solver='approx'). "
"Construct with solver='lp' or 'ilp' to use the LP solver."
)
opt._set_objective()
res = opt.solve_lp_relaxation(verbose=verbose, extract=True)
if res["solution"] is None:
raise RuntimeError(
"solver='lp' requires an integral LP relaxation, but it came "
f"out fractional ({res['n_fractional']}/{res['n_vars']} "
"variables). Use solver='ilp' for an exact integral solve."
)
self.sharding_placement = res["solution"]
else:
raise ValueError(
f"Unknown solver={solver!r}; expected one of {self.SOLVER_CHOICES}"
)

self.sharding_placement = self.sharding_optimizer.get_solution(verbose=False)
if optimality_check:
self._log_optimality_check(solver, verbose=verbose)

if verbose:
logger.info(self.sharding_optimizer.get_log(verbose=True))
Expand All @@ -375,7 +626,10 @@ def optimize_placement(self, verbose=True):
),
)

if self.sharding_optimizer.prob.status == -1:
if (
self.sharding_optimizer.prob is not None
and self.sharding_optimizer.prob.status == -1
):
raise RuntimeError(
"The sharding optimizer could not find a feasible solution. "
"This typically means the user-specified constraints are "
Expand All @@ -387,6 +641,39 @@ def optimize_placement(self, verbose=True):

return self.sharding_placement

def _log_optimality_check(self, solver, verbose=False):
    """Log how far the achieved objective is from the LP-relaxation bound.

    Solves the LP relaxation as a lower bound and logs the relative gap
    ``(achieved - lower_bound) / lower_bound`` of the current objective.
    Needs a PuLP problem; when the optimizer was built without one the
    check is skipped with a warning.

    Args:
        solver: name of the solver that produced the current solution;
            used only in the log message.
        verbose: forwarded to ``get_lower_bound``.
    """
    opt = self.sharding_optimizer
    if opt.prob is None:
        logger.warning(
            "optimality_check skipped: solver=%r build has no PuLP problem; "
            "construct with solver='ilp' or 'lp' to enable it.",
            self.solver,
        )
        return

    # Imported only once a PuLP problem is known to exist, so the skip path
    # above stays usable for PuLP-free builds.
    import pulp

    achieved = opt._safe_float(pulp.value(opt.prob.objective))
    lb_res = opt.get_lower_bound(verbose=verbose)
    lb = lb_res.objective
    # A missing, zero or negative bound cannot be used as the denominator of
    # a relative gap.
    if not lb or lb <= 0 or achieved is None:
        logger.warning(
            "optimality_check inconclusive: lower_bound=%s achieved=%s",
            lb,
            achieved,
        )
        return
    gap = (achieved - lb) / lb
    logger.info(
        "optimality check (solver=%s): objective=%.4f LP lower bound=%.4f "
        "=> within %.2f%% of optimum (certified)",
        solver,
        achieved,
        lb,
        gap * 100,
    )

def _apply_placement_common(self, sharding_placement):
t0 = time.perf_counter()
self._assert_entered()
Expand Down
Loading
Loading