diff --git a/.gitignore b/.gitignore index bcaae24d..4936ecca 100644 --- a/.gitignore +++ b/.gitignore @@ -4,13 +4,20 @@ *.pyc *.pyo *.so +*.log .mypy_cache/ *.egg-info/ +*.pdf +*.png +*.svg +*.csv + build/ dist/ tmp/ out/ +profile_results/ .vscode/ diff --git a/autoparallel/api.py b/autoparallel/api.py index 1670d509..5e99d7d4 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -44,6 +44,7 @@ ) from .module_construction import make_parallel_module from .optimize_sharding import ShardingOptimizer +from .propagation import ShardingAnnotation, ShardingPropagator from .shardings.placement_options import _get_device_from_mesh from .tracing import ( _add_unused_params_and_buffers, @@ -193,6 +194,11 @@ class AutoParallel: The meta model is moved to a fake device based on mesh.device_type. """ + # Selectable solvers. "ilp": exact PuLP/CBC. "approx": heuristic TRW-S + # (light build, no PuLP). "lp": LP relaxation used directly as the solve + # (empirically integral for this problem, so much cheaper than CBC). + SOLVER_CHOICES = ("ilp", "approx", "lp") + def __init__( self, model, @@ -203,8 +209,20 @@ def __init__( dynamic: bool = False, cost_model: Any = "nccl", repeated_subgraphs: bool = True, + solver: str = "ilp", ): self.stack = ExitStack() + # The solver chosen here decides how the optimizer is built: "ilp"/"lp" + # build the full PuLP problem (CBC exact solve / LP relaxation solve); + # "approx" builds a lighter optimizer (no PuLP variables/constraints), + # much faster to construct, solved heuristically. optimize_placement( + # solver=...) may override the solve as long as it is compatible with + # this build. + if solver not in self.SOLVER_CHOICES: + raise ValueError( + f"Unknown solver={solver!r}; expected one of {self.SOLVER_CHOICES}" + ) + self.solver = solver self.fake_mode = ( FakeTensorMode() ) # TODO: maybe need to reuse the model's fake mode @@ -281,12 +299,15 @@ def __enter__(self): self.mesh, force_grad_reduce_in_higher_precision, repeated_subgraphs=self.repeated_subgraphs, + build_pulp=self.solver in ("ilp", "lp"), ) self.sharding_optimizer = sharding_optimizer self.input_constraints = None self.output_constraints = None + self._annotations: list[tuple[Any, ShardingAnnotation]] = [] + self.propagation_result = None self.active = True @@ -356,10 +377,240 @@ def add_output_constraints(self, constraints): self.sharding_optimizer.add_sharded_output_constraint(constraints) self.output_constraints = constraints - def optimize_placement(self, verbose=True): + # ---- Sharding annotations (Shardy-like propagation) ---- + # EXPERIMENTAL: opt-in only. These have no effect unless you call an + # annotate_* method and then propagate_annotations() before + # optimize_placement(); the default solve path never invokes them. The + # propagation may shrink the search space in ways that move the objective off + # the full-ILP optimum, so treat results as unstable. + + def _normalize_placements(self, placements): + """Pad/validate a placement tuple to mesh.ndim, leaving missing trailing + axes open (``None``).""" + placements = tuple(placements) + if len(placements) > self.mesh.ndim: + raise ValueError( + f"annotation has {len(placements)} placements but mesh has " + f"{self.mesh.ndim} dims" + ) + return placements + (None,) * (self.mesh.ndim - len(placements)) + + def _param_fqn_to_node(self): + from torch._functorch._aot_autograd.fx_utils import get_param_and_grad_nodes + + graph = self.sharding_optimizer.graph + return { + desc.target: node + for desc, (node, _grad) in get_param_and_grad_nodes(graph).items() + } + + def annotate_parameter(self, fqn, placements, priority=1): + """Annotate the sharding of one or more parameters. + + ``fqn`` is a parameter fully-qualified name, optionally a glob pattern + (e.g. ``"layers.*.attention.wq.weight"``) to annotate the matching + parameter in every layer at once. ``placements`` is a tuple of + :class:`Placement` (or ``None`` to leave a mesh axis open — typical for + the data/FSDP axis of a weight). Weights default to a lower priority + than activations so the data-parallel axis wins shared-axis conflicts. + """ + import fnmatch + + placements = self._normalize_placements(placements) + fqn_map = self._param_fqn_to_node() + matched = [node for name, node in fqn_map.items() if fnmatch.fnmatch(name, fqn)] + if not matched: + raise ValueError( + f"No parameter matches {fqn!r}. Available parameters: " + f"{sorted(fqn_map)}" + ) + for node in matched: + self._annotations.append((node, ShardingAnnotation(placements, priority))) + return matched + + def annotate_input(self, idx, placements, priority=0): + """Annotate the sharding of graph input ``idx``.""" + from torch._functorch._aot_autograd.fx_utils import ( + get_plain_input_and_grad_nodes, + ) + + placements = self._normalize_placements(placements) + graph = self.sharding_optimizer.graph + nodes = { + desc.idx: node + for desc, (node, _grad) in get_plain_input_and_grad_nodes(graph).items() + } + if idx not in nodes: + raise ValueError(f"No graph input with index {idx}; have {sorted(nodes)}") + self._annotations.append((nodes[idx], ShardingAnnotation(placements, priority))) + return nodes[idx] + + def annotate_output(self, idx, placements, priority=0): + """Annotate the sharding of graph output ``idx``.""" + from torch._functorch._aot_autograd.fx_utils import ( + get_plain_output_and_tangent_nodes, + ) + + placements = self._normalize_placements(placements) + graph = self.sharding_optimizer.graph + nodes = { + desc.idx: node + for desc, (node, _t) in get_plain_output_and_tangent_nodes(graph).items() + } + if idx not in nodes: + raise ValueError(f"No graph output with index {idx}; have {sorted(nodes)}") + self._annotations.append((nodes[idx], ShardingAnnotation(placements, priority))) + return nodes[idx] + + def annotate_node(self, node, placements, priority=0): + """Annotate the sharding of an arbitrary graph node.""" + placements = self._normalize_placements(placements) + self._annotations.append((node, ShardingAnnotation(placements, priority))) + return node + + def _mirror_annotations_to_backward(self): + """Build extra propagation seeds on the backward twins of annotated + forward tensors. + + A gradient shares the sharding of the value it is the gradient of, so a + forward annotation also pins its twin (parameter->grad, input->grad, + output->tangent). Seeding the twins lets the TP plan propagate through + the backward pass too. These seeds are only used for propagation: the + twins themselves stay unconstrained (handled by the forward/backward + consistency constraints), but their neighbors get determined. + """ + from torch._functorch._aot_autograd.fx_utils import ( + get_param_and_grad_nodes, + get_plain_input_and_grad_nodes, + get_plain_output_and_tangent_nodes, + ) + + graph = self.sharding_optimizer.graph + twin = {} + for _d, (node, grad) in get_param_and_grad_nodes(graph).items(): + if grad is not None: + twin[node] = grad + for _d, (node, grad) in get_plain_input_and_grad_nodes(graph).items(): + if grad is not None: + twin[node] = grad + for _d, (node, tangent) in get_plain_output_and_tangent_nodes(graph).items(): + if tangent is not None: + twin[node] = tangent + + mirrored = [] + for node, ann in self._annotations: + if node in twin: + mirrored.append((twin[node], ann)) + return mirrored + + def propagate_annotations(self, verbose=True, aggressive=False, method="fix"): + """EXPERIMENTAL (opt-in, off by default; may be unstable). + + Propagate the registered annotations Shardy-style and turn the + unambiguously-determined nodes into ILP constraints, shrinking the + search space. Returns a :class:`PropagationResult`. + + Call this after the ``annotate_*`` / ``add_*_constraint`` calls and + before :meth:`optimize_placement`. The default solve path does not call + this; nothing happens unless you invoke it explicitly. + + With ``aggressive=False`` (the default) only genuine ``Shard`` axes are + pinned, which keeps the full-ILP optimum reachable. ``aggressive=True`` + also pins ``Replicate`` / ``Partial`` axes for a larger reduction at the + cost of possibly forbidding cheaper reshard placements (e.g. sequence + parallelism), so the objective may move slightly off the optimum. + + ``method`` is how each pin is enforced: ``"fix"`` (default) removes the + ruled-out decision variables (shrinks the problem; scales best on large + meshes), ``"constraint"`` adds removable ``== 1`` rows instead. + """ + self._assert_entered() + propagator = ShardingPropagator(self.sharding_optimizer) + seeds = self._annotations + self._mirror_annotations_to_backward() + propagator.run(seeds) + self.propagation_result = propagator.apply_to_optimizer( + aggressive=aggressive, method=method + ) + if verbose: + logger.info( + "Annotation propagation reduced the output-strategy search " + "space by %.1f%% (%d -> %d) via %d per-axis constraints on %d " + "nodes", + 100.0 * self.propagation_result.reduction, + self.propagation_result.strategies_before, + self.propagation_result.strategies_after, + self.propagation_result.axis_constraints, + self.propagation_result.nodes_determined, + ) + return self.propagation_result + + def optimize_placement( + self, + verbose=True, + solver=None, + approximate_options=None, + optimality_check=False, + ): + """Solve for the optimal placement. + + solver selects how the placement is solved (defaults to the solver chosen + at AutoParallel construction): + - "ilp": exact PuLP/CBC solve. + - "approx": heuristic TRW-S ApproximateShardingSolver — trades a small + objective gap for a much faster solve. + - "lp": solve the LP relaxation and use it directly. This problem is + empirically integral, so the relaxation optimum equals the ILP optimum + while skipping branch-and-bound; raises if it comes out fractional. + approximate_options is forwarded as kwargs to the approximate solver + (e.g. candidate_limit, max_sweeps). The requested solver must be + compatible with how the optimizer was built: "ilp"/"lp" need a PuLP + problem (build with solver="ilp" or "lp"). + + optimality_check: after solving, solve the LP relaxation as a lower bound + and log the certified gap of the achieved objective from the optimum. + Requires a PuLP problem (i.e. an "ilp"/"lp" build). + """ self._assert_entered() + if solver is None: + solver = self.solver + + opt = self.sharding_optimizer + if solver in ("approx", "approximate"): + from .approximate_sharding import ApproximateShardingSolver + + approx = ApproximateShardingSolver(opt, **(approximate_options or {})) + self.sharding_placement = approx.get_solution(verbose=verbose) + elif solver == "ilp": + if opt.prob is None: + raise RuntimeError( + "solver='ilp' requires a PuLP problem, but this AutoParallel " + "was constructed without one (e.g. solver='approx'). " + "Construct with solver='ilp' to use the exact solver." + ) + self.sharding_placement = opt.get_solution(verbose=False) + elif solver in ("lp", "lp_relax", "lp_relaxation"): + if opt.prob is None: + raise RuntimeError( + "solver='lp' requires a PuLP problem, but this AutoParallel " + "was constructed without one (e.g. solver='approx'). " + "Construct with solver='lp' or 'ilp' to use the LP solver." + ) + opt._set_objective() + res = opt.solve_lp_relaxation(verbose=verbose, extract=True) + if res["solution"] is None: + raise RuntimeError( + "solver='lp' requires an integral LP relaxation, but it came " + f"out fractional ({res['n_fractional']}/{res['n_vars']} " + "variables). Use solver='ilp' for an exact integral solve." + ) + self.sharding_placement = res["solution"] + else: + raise ValueError( + f"Unknown solver={solver!r}; expected one of {self.SOLVER_CHOICES}" + ) - self.sharding_placement = self.sharding_optimizer.get_solution(verbose=False) + if optimality_check: + self._log_optimality_check(solver, verbose=verbose) if verbose: logger.info(self.sharding_optimizer.get_log(verbose=True)) @@ -375,7 +626,10 @@ def optimize_placement(self, verbose=True): ), ) - if self.sharding_optimizer.prob.status == -1: + if ( + self.sharding_optimizer.prob is not None + and self.sharding_optimizer.prob.status == -1 + ): raise RuntimeError( "The sharding optimizer could not find a feasible solution. " "This typically means the user-specified constraints are " @@ -387,6 +641,39 @@ def optimize_placement(self, verbose=True): return self.sharding_placement + def _log_optimality_check(self, solver, verbose=False): + """Solve the LP relaxation as a lower bound and log the certified gap of + the achieved objective from the optimum. Needs a PuLP problem.""" + import pulp + + opt = self.sharding_optimizer + if opt.prob is None: + logger.warning( + "optimality_check skipped: solver=%r build has no PuLP problem; " + "construct with solver='ilp' or 'lp' to enable it.", + self.solver, + ) + return + achieved = opt._safe_float(pulp.value(opt.prob.objective)) + lb_res = opt.get_lower_bound(verbose=verbose) + lb = lb_res.objective + if not lb or lb <= 0 or achieved is None: + logger.warning( + "optimality_check inconclusive: lower_bound=%s achieved=%s", + lb, + achieved, + ) + return + gap = (achieved - lb) / lb + logger.info( + "optimality check (solver=%s): objective=%.4f LP lower bound=%.4f " + "=> within %.2f%% of optimum (certified)", + solver, + achieved, + lb, + gap * 100, + ) + def _apply_placement_common(self, sharding_placement): t0 = time.perf_counter() self._assert_entered() diff --git a/autoparallel/approximate_sharding.py b/autoparallel/approximate_sharding.py new file mode 100644 index 00000000..3e102da5 --- /dev/null +++ b/autoparallel/approximate_sharding.py @@ -0,0 +1,1552 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Approximate sharding solver. + +The ILP in :mod:`optimize_sharding` selects, for every operation, an output +placement and (per argument) the input placement of its producer. The flow +constraint forces a consumer's input placement to equal its producer's chosen +output placement, so the only genuinely free variables are the per-node output +strategy indices ``x_v``. The problem therefore reduces to a pairwise discrete +energy minimization over a DAG:: + + E(x) = Σ_v U_v(x_v) + Σ_{(u,v)} B_{uv}(x_u, x_v) + +where ``U_v`` is the compute cost and ``B_{uv}`` is the communication + +sharding-transition cost on the edge from producer ``u`` to consumer ``v``. + +This is a pairwise MRF. The autograd DAG has small in-degree (<3) but large +out-degree (tens) and a wide topological frontier (hundreds), so exact +frontier/junction-tree DP blows up. We instead solve it with **min-sum belief +propagation** (max-product in min-sum form) on the graph of *coupled groups*, +which propagates coordinated decisions globally, then polish with group-level +coordinate descent and a star-block local search. + +Nodes that must be chosen jointly are merged into groups: repeated-subgraph +cluster copies share a strategy index, and forward/backward pairs share an +output placement. The solver reuses the strategies, decision variables and +constraints already built by ``ShardingOptimizer`` (it replaces only the +CBC/ILP *solve*, not problem construction) and writes its assignment back into +the PuLP variables, so the result is scored with the exact same objective as the +ILP (``pulp.value(prob.objective)``). +""" + +import logging +import math +import time +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np +import pulp +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Replicate, Shard + +from .cost_models.compute_estimation import _get_sharded_shape_stride + +logger = logging.getLogger(__name__) + +INF = float("inf") +BIG = 1e12 # finite stand-in for forbidden combinations (avoids NaN in min-sum) + +# Paired forward/backward constraints couple two nodes to the *same output +# placement* (the strategy index may differ between the two strategy lists). +_PAIRED_PREFIXES = ( + "grad_param_constraint", + "grad_input_constraint", + "grad_output_constraint", +) + + +@dataclass +class ApproximateSolveResult: + objective: float + status: str + build_s: float + solve_s: float + total_s: float + num_groups: int + num_nodes: int + + +@dataclass +class _Group: + """A set of node indices chosen jointly (cluster copies share a strategy + index; forward/backward pairs share an output placement).""" + + members: list[int] + cost_bearing: list[int] = field(default_factory=list) + choices: list[dict[int, int]] = field(default_factory=list) # member -> out_idx + current: int = 0 + + @property + def domain(self) -> int: + return len(self.choices) + + +class _UnionFind: + def __init__(self, n: int): + self.parent = list(range(n)) + + def find(self, x: int) -> int: + root = x + while self.parent[root] != root: + root = self.parent[root] + while self.parent[x] != root: + self.parent[x], x = root, self.parent[x] + return root + + def union(self, a: int, b: int) -> None: + ra, rb = self.find(a), self.find(b) + if ra != rb: + self.parent[rb] = ra + + +class ApproximateShardingSolver: + """Approximate solver for the sharding placement problem on an already-built + :class:`ShardingOptimizer`. + + Call :meth:`get_solution` for a ``{node: OpSpec}`` dict (same format as + ``ShardingOptimizer.get_solution``); it also fills the PuLP variables and + ``optimizer.selected_keys`` so the assignment can be scored/inspected exactly + like an ILP solution. + """ + + def __init__( + self, + optimizer, + candidate_limit: Optional[int] = 64, + bp_iters: int = 400, + bp_tol: float = 1e-3, + max_sweeps: int = 12, + max_time_s: float = 60.0, + star_passes: int = 2, + max_star_children: int = 32, + group_domain_limit: int = 512, + ): + self.opt = optimizer + self.candidate_limit = candidate_limit + self.bp_iters = bp_iters + self.bp_tol = bp_tol + self.max_sweeps = max_sweeps + self.max_time_s = max_time_s + self.star_passes = star_passes + self.max_star_children = max_star_children + self.group_domain_limit = group_domain_limit + + # Populated by _build_problem(). + self.cost_bearing: list[int] = [] + self.node_mult: dict[int, int] = {} + self.forbidden: set[tuple] = set() + self.allowed_out: dict[int, list[int]] = {} + self.groups: list[_Group] = [] + self.node_to_group: dict[int, int] = {} + self.input_edges: dict[int, list[tuple[int, int]]] = {} + self._arg_prod: dict[int, dict[int, int]] = {} + self.consumers: dict[int, list[tuple[int, int]]] = defaultdict(list) + self.cur_out: dict[int, int] = {} + self._memory: Optional[dict[str, Any]] = None + # When False, the hard memory-budget checks in local search are skipped + # (used by the Lagrangian solve, which enforces the budget softly via a + # penalty folded into the unaries instead). + self._mem_enforce: bool = True + self._mem_unary: list[np.ndarray] = [] + + # Populated by _build_factors(). + self.g_unary: list[np.ndarray] = [] + self.C: dict[tuple, np.ndarray] = {} + self.nbrs: list[list[int]] = [] + + # ------------------------------------------------------------------ # + # Public entry point + # ------------------------------------------------------------------ # + def get_solution(self, verbose: bool = False): + result, solution = self._solve(verbose=verbose) + self.result = result + return solution + + def _solve(self, verbose: bool = False): + opt = self.opt + if getattr(opt, "solver_backend", "ilp") != "ilp": + raise RuntimeError( + "ApproximateShardingSolver requires an ILP-built optimizer " + "(decision_vars / pulp_variables / constraints)." + ) + t0 = time.perf_counter() + self._build_problem() + t_bp = time.perf_counter() + self._build_factors() + t_bf = time.perf_counter() + t_build = t_bf - t0 + if verbose: + logger.info( + "approx build: problem=%.2fs %s factors=%.2fs groups=%d " + "cost_bearing=%d edges=%d max_domain=%d", + t_bp - t0, + getattr(self, "_build_times", {}), + t_bf - t_bp, + len(self.groups), + len(self.cost_bearing), + sum(len(v) for v in self.input_edges.values()), + max((g.domain for g in self.groups), default=0), + ) + + deadline = t0 + self.max_time_s + # TRW-S init, then local-search polish. TRW-S reaches the exact MAP on the + # (integral) sharding problem, so the old greedy second candidate it used + # to be compared against is strictly dominated and has been dropped; the + # polish remains for the memory budget and as a local-search safety net. + t_bp0 = time.perf_counter() + mem = self._memory + if mem is not None and not mem.get("tight"): + # A non-tight budget can bind the runtime-optimal placement; solve it + # exactly via Lagrangian relaxation (folds λ·ratio into the unaries). + # A tight budget is already handled by build-time param pinning, and + # the no-memory case has nothing to relax, so both take the plain path. + res = self.solve_lagrangian( + mem["budget_low"], + mem["budget_high"], + deadline=deadline, + verbose=verbose, + ) + if verbose: + logger.info( + "approx phase: lagrangian lam=%.4g memory=%.4f feasible=%s", + res["lam"], + res["memory"], + res["feasible"], + ) + else: + self._belief_propagation(deadline) + if verbose: + logger.info( + "approx phase: trws iter=%s delta=%.4g in %.2fs; " + "decode energy=%.1f", + getattr(self, "_bp_last_iter", None), + getattr(self, "_bp_last_delta", float("nan")), + time.perf_counter() - t_bp0, + self._fast_total_energy(), + ) + self._memory_repair() + self._coordinate_descent(deadline) + self._star_block_search(deadline) + bp_energy = self._fast_total_energy() + if verbose: + logger.info("approx phase: polished energy=%.1f", bp_energy) + t_solve = time.perf_counter() - t0 - t_build + + objective = self._write_back() + total_s = time.perf_counter() - t0 + infeasible = not math.isfinite(objective) + status = "Infeasible" if infeasible else "Heuristic" + result = ApproximateSolveResult( + objective=objective, + status=status, + build_s=t_build, + solve_s=t_solve, + total_s=total_s, + num_groups=len(self.groups), + num_nodes=len(self.cost_bearing), + ) + logger.info( + "ApproximateShardingSolver: status=%s objective=%.4f " + "(trws+polish=%.1f) groups=%d nodes=%d " + "timings={build=%.3fs,solve=%.3fs,total=%.3fs}", + status, + objective, + bp_energy, + len(self.groups), + len(self.cost_bearing), + t_build, + t_solve, + total_s, + ) + opt.profile["approximate"] = { + "objective": objective, + "status": status, + "build_s": t_build, + "solve_s": t_solve, + "total_s": total_s, + "groups": len(self.groups), + "bp_energy": bp_energy, + } + if infeasible: + raise RuntimeError( + "ApproximateShardingSolver could not find a feasible assignment. " + "User constraints may be contradictory or the mesh too small." + ) + solution = opt._to_orig_solution(opt._extract_and_validate_solution()) + return result, solution + + # ------------------------------------------------------------------ # + # Problem construction + # ------------------------------------------------------------------ # + def _build_problem(self): + opt = self.opt + # cluster_links is node-level: copy node idx -> root node idx. + cluster_linked = set(opt.cluster_links) + self.cost_bearing = [ + opt.node_map[node] + for node in opt.strats + if node.op != "output" and opt.node_map[node] not in cluster_linked + ] + + root_to_copies: dict[int, set] = defaultdict(set) + for copy_idx, root_idx in opt.cluster_links.items(): + root_to_copies[root_idx].add(copy_idx) + self.node_mult = { + v: 1 + len(root_to_copies.get(v, ())) for v in self.cost_bearing + } + + self.allowed_out = {} + for node, strat in opt.strats.items(): + if node.op == "output": + continue + self.allowed_out[opt.node_map[node]] = list(range(len(strat.strategies))) + + t = time.perf_counter() + if opt.prob is None: + # Lite build: no PuLP problem was constructed, derive topology directly. + paired_edges, authoritative = self._topology_direct() + else: + paired_edges, authoritative = self._parse_constraints() + # Flow edges are taken from the ILP's output_input_consistent constraints + # (the authoritative producer per consumer-arg), NOT from _all_input_nodes: + # the two disagree for some ops (einsum list-args, alias/backward nodes), + # and trusting _all_input_nodes yields flow-infeasible assignments. The + # producer here is the (possibly cluster-resolved) node carrying the + # producer's pulp variable; the ILP guarantees its out_idx range matches + # the consumer's inp_idx range for that arg. + self._arg_prod: dict[int, dict[int, int]] = defaultdict(dict) + flow_couplings = [] # producer sets forced to share an out_idx + for (c_idx, argi), producers in authoritative.items(): + rep = min(producers) # all coupled -> same out, any representative + self._arg_prod[c_idx][argi] = rep + if len(producers) > 1: + flow_couplings.append(producers) + self.input_edges = {} + self.consumers = defaultdict(list) + for v in self.cost_bearing: + edges = sorted(self._arg_prod.get(v, {}).items()) + self.input_edges[v] = edges + for argi, p in edges: + self.consumers[p].append((v, argi)) + t_parse = time.perf_counter() + + # Remove fully-forbidden out_idx for cost-bearing nodes. + for v in self.cost_bearing: + node = opt.nodes[v] + self.allowed_out[v] = [ + o + for o in self.allowed_out[v] + if not self._out_fully_forbidden(v, node, o) + ] + t_forbid = time.perf_counter() + + self._build_memory_info() # also pins params when the budget is tight + t_mem = time.perf_counter() + self._build_groups(paired_edges, flow_couplings) + t_groups = time.perf_counter() + self._prune_candidates() + self._build_times = { + "parse": t_parse - t, + "forbid": t_forbid - t_parse, + "memory": t_mem - t_forbid, + "groups": t_groups - t_mem, + "prune": time.perf_counter() - t_groups, + } + + # Constraint families that never restrict the per-node out_idx domain and + # are handled structurally (flow/uniqueness) or via the cost sentinel below. + # Skipping them by name avoids materializing items() for the ~majority of the + # (often >100k) constraints. + _SKIP_PREFIXES = ( + "unique_decision", + "same_across_args", + "inf_cases", + "memory_constraint", + ) + + def _parse_constraints(self): + opt = self.opt + # inf-cost keys are forced to 0 by add_inf_cost_constraint, which also + # stamps dv.cost = 10000.0. Detect them directly instead of parsing the + # (very numerous) inf_cases constraints. + for key, dv in opt.decision_vars.items(): + if dv.cost == 10000.0: + self.forbidden.add(key) + + var_to_key = {var: key for key, var in opt.pulp_variables.items()} + restrict: dict[int, set] = {} + paired_edges: list[tuple[int, int, frozenset]] = [] + # (consumer_idx, argi) -> set of producer_idx, from flow constraints. A + # clustered consumer's single inp variable is shared across all its + # copies, so the ILP couples one producer per copy (resolved to its root) + # to that inp, forcing them all equal; we collect the whole set. + authoritative: dict[tuple[int, int], set] = {} + for name, c in opt.prob.constraints.items(): + if name.startswith("output_input_consistent"): + # +side = producer (grouped by out), -side = consumer (grouped by + # inp at a fixed arg). One +var and one -var pin down the edge. + pos_key = neg_key = None + for var, coeff in c.items(): + k = var_to_key.get(var) + if k is None: + continue + if coeff > 0: + pos_key = pos_key or k + else: + neg_key = neg_key or k + if pos_key is not None and neg_key is not None: + break + if pos_key is not None and neg_key is not None: + authoritative.setdefault((neg_key[0], neg_key[1]), set()).add( + pos_key[0] + ) + continue + if name.startswith(self._SKIP_PREFIXES): + continue + items = list(c.items()) + if not items: + continue + rhs = -c.constant + coeffs = [coeff for _, coeff in items] + keys = [var_to_key.get(var) for var, _ in items] + if any(k is None for k in keys): + continue + all_pos = all(coeff > 0 for coeff in coeffs) + if c.sense == pulp.LpConstraintEQ and rhs == 0 and all_pos: + self.forbidden.update(keys) # Σ vars == 0 (inf / dtype / disable) + elif c.sense == pulp.LpConstraintEQ and rhs == 1 and all_pos: + nodes = {k[0] for k in keys} + if len(nodes) == 1: + n = next(iter(nodes)) + out_set = {k[2] for k in keys} + restrict[n] = restrict.get(n, out_set) & out_set + elif ( + c.sense == pulp.LpConstraintEQ + and rhs == 0 + and any(name.startswith(p) for p in _PAIRED_PREFIXES) + and "disable" not in name + ): + pos = {k for k, coeff in zip(keys, coeffs) if coeff > 0} + neg = {k for k, coeff in zip(keys, coeffs) if coeff < 0} + na, nb = {k[0] for k in neg}, {k[0] for k in pos} + oa, ob = {k[2] for k in neg}, {k[2] for k in pos} + if len(na) == 1 and len(nb) == 1 and len(oa) == 1 and len(ob) == 1: + paired_edges.append( + ( + next(iter(na)), + next(iter(nb)), + frozenset({(next(iter(oa)), next(iter(ob)))}), + ) + ) + # method="fix" axis pins leave no PuLP row to parse above, so replay the + # log to recover them (constraint-method pins are also picked up here, + # idempotently with their == 1 rows). + for n, out_set in self._axis_restrict_from_log().items(): + restrict[n] = restrict.get(n, out_set) & out_set + for n, out_set in restrict.items(): + if n in self.allowed_out: + self.allowed_out[n] = [o for o in self.allowed_out[n] if o in out_set] + return paired_edges, authoritative + + def _topology_direct(self): + """Compute the same topology (forbidden / out_idx restrictions / paired + edges / flow producers) that _parse_constraints extracts, but directly + from the graph + cluster_links + _constraint_log, WITHOUT a PuLP problem. + This lets the optimizer skip building millions of PuLP variables and + constraints when only the approximate solver is used. + + Mirrors ShardingOptimizer.add_inf_cost_constraint / + add_grad_reduce_dtype_constraints / add_forward_backward_consistency_constraints / + _add_paired_output_constraint / add_node_constraint / + add_output_input_consistent_constraint. Verified byte-identical to + _parse_constraints on a full build (see tests).""" + from torch._functorch._aot_autograd.fx_utils import ( + get_param_and_grad_nodes, + get_plain_input_and_grad_nodes, + get_plain_output_and_tangent_nodes, + ) + + opt = self.opt + cl = opt.cluster_links # node-level: copy node idx -> root node idx + + def rootkey(k): + return opt._cluster_root_key(k) + + cluster_linked = set(cl) + node_root = dict(cl) + + def nroot(idx): + return node_root.get(idx, idx) + + # 1. inf-cost forbidden (== add_inf_cost_constraint). + for key, dv in opt.decision_vars.items(): + if not math.isfinite(dv.cost) or dv.cost == 10000.0: + self.forbidden.add(key) + + # 2a. forward param-dtype forbidden (== add_grad_reduce_dtype_constraints + # forward part, unconditional). Force the FSDP allgather to run after + # a downcasting param dtype_cast (in the smaller param_dtype) by + # forbidding any pre-cast redistribution. + cast_op = torch.ops.autoparallel.dtype_cast.default + fwd_pre_cast: set[int] = set() + for param, _grad in get_param_and_grad_nodes(opt.graph).values(): + n = param + while True: + if n.target == cast_op: + break + users = list(n.users.keys()) + if len(users) != 1: + break + child = users[0] + if len(child.all_input_nodes) != 1: + break + n = child + if n.target != cast_op: + continue + if n.meta["val"].dtype.itemsize >= param.meta["val"].dtype.itemsize: + continue # only constrain downcasts + node = n + while node != param: + if node in opt.node_map: + fwd_pre_cast.add(opt.node_map[node]) + node = node.all_input_nodes[0] + for key, dv in opt.decision_vars.items(): + if key[0] in fwd_pre_cast and dv.comm_cost > 0: + self.forbidden.add(key) + + # 2. grad-reduce-dtype (backward) forbidden + # (== add_grad_reduce_dtype_constraints backward part). + if getattr(opt, "force_grad_reduce_in_higher_precision", False): + cast_op = torch.ops.autoparallel.dtype_cast.default + pre_cast: set[int] = set() + for param, grad in get_param_and_grad_nodes(opt.graph).values(): + if grad is None: + continue + chain = [grad] + n = grad + while len(n.all_input_nodes) == 1: + parent = n.all_input_nodes[0] + if len(parent.all_input_nodes) != 1: + break + chain.append(parent) + n = parent + cast_idx = next( + (i for i, nd in enumerate(chain) if nd.target == cast_op), None + ) + if cast_idx is None: + continue + for nd in chain[cast_idx:]: + if nd in opt.node_map: + pre_cast.add(opt.node_map[nd]) + for key, dv in opt.decision_vars.items(): + if key[0] in pre_cast and dv.comm_cost > 0: + self.forbidden.add(key) + + # 3. forward/backward paired output constraints + disables + # (== add_forward_backward_consistency_constraints / _add_paired_output_constraint). + paired_edges: list[tuple[int, int, frozenset]] = [] + + def add_paired(node_a, node_b): + idx_a, idx_b = opt.node_map[node_a], opt.node_map[node_b] + strat_a = [str(s.output_specs) for s in opt.strats[node_a].strategies] + strat_b = [str(s.output_specs) for s in opt.strats[node_b].strategies] + num_inp_a = len(opt.strats[node_a].strategies[0].redistribute_cost[0]) + for out_idx, sp in enumerate(strat_a): + if sp not in strat_b: + for inp in range(num_inp_a): + self.forbidden.add(rootkey((idx_a, 0, out_idx, inp))) + continue + out_idx_b = strat_b.index(sp) + ra = rootkey((idx_a, 0, out_idx, 0))[0] + rb = rootkey((idx_b, 0, out_idx_b, 0))[0] + paired_edges.append((ra, rb, frozenset({(out_idx, out_idx_b)}))) + + for param, grad in get_param_and_grad_nodes(opt.graph).values(): + if grad is not None: + add_paired(param, grad) + for node, gnode in get_plain_input_and_grad_nodes(opt.graph).values(): + if gnode is not None: + add_paired(node, gnode) + for node, tnode in get_plain_output_and_tangent_nodes(opt.graph).values(): + if tnode is not None: + add_paired(node, tnode) + + # 4. user node/input/output placement restrictions (== add_node_constraint), + # replayed from _constraint_log. + restrict: dict[int, set] = {} + for fname, kwargs in getattr(opt, "_constraint_log", []): + if fname != "add_node_constraint": + continue + node = next( + (nd for nd in opt.nodes if nd.name == kwargs["node_name"]), None + ) + if node is None or node not in opt.strats: + continue + placement = kwargs["placement"] + if placement is None: + placement = (Shard(0),) + (Replicate(),) * (opt.mesh.ndim - 1) + out_set = set() + for i, s in enumerate(opt.strats[node].strategies): + specs = s.output_specs + if isinstance(specs, DTensorSpec): + if specs.placements == placement: + out_set.add(i) + elif isinstance(specs, (list, tuple)): + for spec in specs: + if isinstance(spec, DTensorSpec): + if spec.placements == placement: + out_set.add(i) + break + r = nroot(opt.node_map[node]) + restrict[r] = restrict.get(r, out_set) & out_set + # 4b. per-axis placement restrictions (== add_node_axis_constraint), what + # sharding propagation emits. With method="fix" these leave no PuLP + # row to parse, so replaying the log is the only way the approx solver + # sees the pin. + for r, out_set in self._axis_restrict_from_log().items(): + restrict[r] = restrict.get(r, out_set) & out_set + for n_idx, out_set in restrict.items(): + if n_idx in self.allowed_out: + self.allowed_out[n_idx] = [ + o for o in self.allowed_out[n_idx] if o in out_set + ] + + # 5. flow producers (== add_output_input_consistent_constraint): for each + # consumer-arg, the set of (cluster-resolved) producers feeding it. + authoritative: dict[tuple[int, int], set] = {} + for node in opt.graph.nodes: + if node.op == "output" or node not in opt.node_map: + continue + p_idx = opt.node_map[node] + p_linked = p_idx in cluster_linked + p_root = nroot(p_idx) + for user in node.users: + if user.op == "output" or user not in opt.node_map: + continue + u_idx = opt.node_map[user] + if p_linked and u_idx in cluster_linked: + continue + ain = opt._all_input_nodes(user) + argi = next((i for i, x in enumerate(ain) if x is node), None) + if argi is None: + continue + ispecs = opt.strats[user].strategies[0].input_specs + if argi < len(ispecs) and ispecs[argi] is None: + continue + authoritative.setdefault((nroot(u_idx), argi), set()).add(p_root) + + return paired_edges, authoritative + + def _axis_restrict_from_log(self): + """out_idx restrictions implied by add_node_axis_constraint calls, + replayed from _constraint_log → {root_node_idx: set(out_idx)}. + + This is how the approximate solver honors propagated per-axis pins: keep + only the strategies whose output placement matches the pinned axis, + exactly like ShardingOptimizer.add_node_axis_constraint. It works whether + the pin was applied as a PuLP row ("constraint") or as variable bounds + ("fix", which leaves no row to parse) and in the lite (no-PuLP) build.""" + opt = self.opt + node_root = dict(opt.cluster_links) # node-level: copy idx -> root idx + restrict: dict[int, set] = {} + for fname, kwargs in getattr(opt, "_constraint_log", []): + if fname != "add_node_axis_constraint": + continue + node = next( + (nd for nd in opt.nodes if nd.name == kwargs["node_name"]), None + ) + if node is None or node not in opt.strats: + continue + mesh_dim, placement = kwargs["mesh_dim"], kwargs["placement"] + out_set = set() + for i, s in enumerate(opt.strats[node].strategies): + specs = s.output_specs + if isinstance(specs, DTensorSpec): + spec = specs + elif isinstance(specs, (list, tuple)): + spec = next((x for x in specs if isinstance(x, DTensorSpec)), None) + else: + spec = None + if spec is not None and spec.placements[mesh_dim] == placement: + out_set.add(i) + r = node_root.get(opt.node_map[node], opt.node_map[node]) + restrict[r] = restrict.get(r, out_set) & out_set + return restrict + + def _is_forbidden(self, key) -> bool: + """A strategy edge is forbidden if a constraint ruled it out OR it was + pruned for infinite cost. Pruning removes such keys from decision_vars + entirely (see ShardingOptimizer._build_decision_vars), so a key missing + from decision_vars is just as forbidden as one in ``self.forbidden``.""" + return key in self.forbidden or key not in self.opt.decision_vars + + def _surviving_dv(self, v, argi, o): + """A DecisionVar for (v, argi, o, *) using any inp_idx that survived + pruning, or None if every edge for that (arg, out) was pruned. + compute_cost / input_spec are identical across inp_idx for a fixed out.""" + strat = self.opt.strats[self.opt.nodes[v]].strategies[o] + n_inp = ( + len(strat.redistribute_cost[argi]) + if argi < len(strat.redistribute_cost) + else 1 + ) + for inp in range(n_inp): + dv = self.opt.decision_vars.get((v, argi, o, inp)) + if dv is not None: + return dv + return None + + def _out_fully_forbidden(self, v, node, o): + strat = self.opt.strats[node].strategies[o] + for argi, costs in enumerate(strat.redistribute_cost): + if all(self._is_forbidden((v, argi, o, inp)) for inp in range(len(costs))): + return True + return False + + def _build_groups(self, paired_edges, flow_couplings): + opt = self.opt + n = len(opt.nodes) + uf = _UnionFind(n) + # cluster_links is node-level: (copy node idx, root node idx) pairs. + cluster_pairs = set(opt.cluster_links.items()) + for li, ri in cluster_pairs: + uf.union(li, ri) + for a, b, _ in paired_edges: + uf.union(a, b) + + allow: dict[tuple, dict[int, set]] = defaultdict(lambda: defaultdict(set)) + adj: dict[int, set] = defaultdict(set) + for li, ri in cluster_pairs: + for o in self.allowed_out.get(ri, []): + allow[(ri, li)][o].add(o) + for o in self.allowed_out.get(li, []): + allow[(li, ri)][o].add(o) + adj[li].add(ri) + adj[ri].add(li) + for a, b, pairs in paired_edges: + for oa, ob in pairs: + allow[(a, b)][oa].add(ob) + allow[(b, a)][ob].add(oa) + adj[a].add(b) + adj[b].add(a) + # Flow couplings: producers feeding a clustered consumer's shared inp are + # forced to the same out_idx (same-index coupling, star to the rep). + for producers in flow_couplings: + ps = sorted(producers) + rep = ps[0] + for q in ps[1:]: + uf.union(rep, q) + for o in self.allowed_out.get(rep, []): + allow[(rep, q)][o].add(o) + for o in self.allowed_out.get(q, []): + allow[(q, rep)][o].add(o) + adj[rep].add(q) + adj[q].add(rep) + + comps: dict[int, list[int]] = defaultdict(list) + for node in opt.strats: + if node.op == "output": + continue + v = opt.node_map[node] + comps[uf.find(v)].append(v) + + cost_bearing_set = set(self.cost_bearing) + self.groups = [] + self.node_to_group = {} + for members in comps.values(): + members.sort() + group = _Group(members=members) + group.cost_bearing = [m for m in members if m in cost_bearing_set] + group.choices = self._enumerate_choices(members, allow, adj) + if not group.choices: + raise RuntimeError( + f"No feasible joint choice for group {members}; " + "constraints are contradictory." + ) + gid = len(self.groups) + self.groups.append(group) + for m in members: + self.node_to_group[m] = gid + + def _enumerate_choices(self, members, allow, adj): + if len(members) == 1: + v = members[0] + return [{v: o} for o in self.allowed_out.get(v, [])] + member_set = set(members) + # BFS order from a representative so every member after the first is + # adjacent to an already-assigned one; coupling then propagates + # deterministically (no spurious K-way branching that explodes the + # domain for large cluster+paired groups). + order = [] + seen = set() + for start in members: + if start in seen: + continue + queue = [start] + seen.add(start) + while queue: + m = queue.pop(0) + order.append(m) + for nb in adj[m]: + if nb in member_set and nb not in seen: + seen.add(nb) + queue.append(nb) + results: list[dict[int, int]] = [] + limit = self.group_domain_limit + + def candidates(m, assign): + cand = None + for nb in adj[m]: + if nb in assign and nb in member_set: + allowed = allow[(nb, m)].get(assign[nb], set()) + cand = allowed if cand is None else (cand & allowed) + cand = ( + set(self.allowed_out.get(m, [])) + if cand is None + else (cand & set(self.allowed_out.get(m, []))) + ) + return cand + + def dfs(i, assign): + if len(results) >= limit: + return + if i == len(order): + results.append(dict(assign)) + return + m = order[i] + for val in sorted(candidates(m, assign)): + assign[m] = val + dfs(i + 1, assign) + del assign[m] + if len(results) >= limit: + return + + dfs(0, {}) + if len(results) >= limit: + logger.warning( + "Approximate solver: group of %d nodes hit group_domain_limit=%d.", + len(members), + limit, + ) + return results + + def _prune_candidates(self): + if self.candidate_limit is None: + return + for group in self.groups: + if len(group.members) != 1 or len(group.choices) <= self.candidate_limit: + continue + v = group.members[0] + node = self.opt.nodes[v] + lbs = sorted( + (self._choice_lower_bound(v, node, c[v]), ci) + for ci, c in enumerate(group.choices) + ) + keep = {ci for _, ci in lbs[: self.candidate_limit]} + group.choices = [group.choices[ci] for ci in sorted(keep)] + + def _choice_lower_bound(self, v, node, o): + opt = self.opt + strat = opt.strats[node].strategies[o] + mult = self.node_mult[v] + dv0 = self._surviving_dv(v, 0, o) + if dv0 is None: + return INF # every edge for this output strategy was pruned + lb = dv0.compute_cost * len(strat.redistribute_cost) + lb *= mult + for argi, _p in self.input_edges.get(v, []): + best = INF + for inp in range(len(strat.redistribute_cost[argi])): + key = (v, argi, o, inp) + if self._is_forbidden(key): + continue + dv = opt.decision_vars[key] + best = min(best, dv.comm_cost + dv.sharding_transition_cost) + if math.isfinite(best): + lb += mult * best + return lb + + # ------------------------------------------------------------------ # + # Memory constraint (ratios, budget, tight-budget param pinning) + # ------------------------------------------------------------------ # + def _build_memory_info(self): + opt = self.opt + factors = None + for fname, kwargs in getattr(opt, "_constraint_log", []): + if fname == "add_parameter_memory_constraint": + factors = kwargs + if factors is None: + return + try: + from torch._functorch._aot_autograd.fx_utils import get_param_nodes + + param_nodes = get_param_nodes(opt.graph) + except Exception: + return + + low_f, high_f = factors["memory_factor_low"], factors["memory_factor_high"] + budget_low = budget_high = 0.0 + param_idxs, ratios = [], {} + for node in param_nodes: + v = opt.node_map[node] + param_idxs.append(v) + r = {o: self._param_ratio(v, node, o) for o in self.allowed_out.get(v, [])} + ratios[v] = r + best = min(r.values()) + budget_low += max(best, low_f) + budget_high += max(best, high_f) + + tight = abs(budget_high - budget_low) < 1e-9 + if tight: + # Σ ratio == Σ min(ratio) forces every param to a min-ratio choice. + for v in param_idxs: + r = ratios[v] + mn = min(r.values()) + self.allowed_out[v] = [ + o for o in self.allowed_out[v] if r[o] <= mn + 1e-12 + ] + self._memory = { + "param_idxs": param_idxs, + "ratios": ratios, + "budget_low": budget_low, + "budget_high": budget_high, + "tight": tight, + } + + def _param_ratio(self, v, node, o): + spec = self._surviving_dv(v, 0, o).input_spec + new_shape, _ = _get_sharded_shape_stride(spec) + return math.prod(new_shape) / math.prod(spec.tensor_meta.shape) + + # ------------------------------------------------------------------ # + # Factor graph (numpy unary + pairwise matrices over groups) + # ------------------------------------------------------------------ # + def _build_factors(self): + G = len(self.groups) + # per member, its out_idx across its group's choices + member_vals = [] + for group in self.groups: + mv = {} + for m in group.cost_bearing: + mv[m] = np.array([c[m] for c in group.choices], dtype=np.int64) + # also predecessors that are non-cost-bearing but in this group + for m in group.members: + if m not in mv: + mv[m] = np.array([c[m] for c in group.choices], dtype=np.int64) + member_vals.append(mv) + + self.g_unary = [np.zeros(g.domain) for g in self.groups] + for gid, group in enumerate(self.groups): + for m in group.cost_bearing: + vals = member_vals[gid][m] + self.g_unary[gid] += self.node_mult[m] * self._self_cost_vec(m, vals) + + C: dict[tuple, np.ndarray] = {} + nbr_set: list[set] = [set() for _ in range(G)] + for v in self.cost_bearing: + gv = self.node_to_group[v] + mult = self.node_mult[v] + for argi, p in self.input_edges[v]: + gp = self.node_to_group[p] + R = self._edge_matrix(v, argi, p) # (Kv, Kp) raw, BIG if forbidden + av = member_vals[gv][v] + bp = member_vals[gp][p] + contrib = mult * R[np.ix_(av, bp)] # (D_gv, D_gp) + if gv == gp: + self.g_unary[gv] += np.diagonal(contrib) + else: + a, b = (gv, gp) if gv < gp else (gp, gv) + mat = contrib if gv < gp else contrib.T + if (a, b) in C: + C[(a, b)] += mat + else: + C[(a, b)] = mat.copy() + nbr_set[a].add(b) + nbr_set[b].add(a) + self.C = C + self.nbrs = [sorted(s) for s in nbr_set] + + def _self_cost_vec(self, m, out_indices): + """Vectorized self-cost (compute + producer-less arg costs) for node m + over an array of out_idx.""" + opt = self.opt + node = opt.nodes[m] + prod = self._arg_prod.get(m, {}) + out = np.empty(len(out_indices)) + for i, o in enumerate(out_indices): + strat = opt.strats[node].strategies[o] + n_args = len(strat.redistribute_cost) + dv0 = self._surviving_dv(m, 0, o) + if dv0 is None: # whole output strategy pruned + out[i] = BIG + continue + c = dv0.compute_cost * n_args + # Args with no flow edge (constructors / None-spec) are scored at + # inp=0 here; args with a producer are charged via the pairwise edges. + for argi in range(n_args): + if argi in prod: + continue + key = (m, argi, o, 0) + if self._is_forbidden(key): + c = BIG + break + dv = opt.decision_vars[key] + c += dv.comm_cost + dv.sharding_transition_cost + out[i] = c + return out + + def _edge_matrix(self, v, argi, p): + """Raw (Kv, Kp) edge cost matrix R[o_v][o_p] = comm + transition, BIG when + the (o_v, o_p) combination is forbidden. Only entries that can actually be + indexed by the group choices are filled; the rest are BIG.""" + opt = self.opt + Kv = len(opt.strats[opt.nodes[v]].strategies) + Kp = len(opt.strats[opt.nodes[p]].strategies) + R = np.full((Kv, Kp), BIG) + gv = self.node_to_group[v] + gp = self.node_to_group[p] + ov_vals = sorted({c[v] for c in self.groups[gv].choices}) + op_vals = sorted({c[p] for c in self.groups[gp].choices}) + for ov in ov_vals: + for op in op_vals: + key = (v, argi, ov, op) + if self._is_forbidden(key): + continue + dv = opt.decision_vars[key] + R[ov, op] = dv.comm_cost + dv.sharding_transition_cost + return R + + def _pair_matrix(self, g, h): + """Pairwise cost oriented as (x_g, x_h).""" + if g < h: + return self.C[(g, h)] + return self.C[(h, g)].T + + # ------------------------------------------------------------------ # + # Energy (fast, numpy) + # ------------------------------------------------------------------ # + def _fast_group_energy(self, gid, ci): + e = self.g_unary[gid][ci] + for h in self.nbrs[gid]: + ch = self.groups[h].current + e += self.C[(gid, h)][ci, ch] if gid < h else self.C[(h, gid)][ch, ci] + return e + + def _fast_total_energy(self): + total = 0.0 + for gid, g in enumerate(self.groups): + total += self.g_unary[gid][g.current] + for (a, b), mat in self.C.items(): + total += mat[self.groups[a].current, self.groups[b].current] + return total + + # ------------------------------------------------------------------ # + # Belief propagation (min-sum) + decode + # ------------------------------------------------------------------ # + def _belief_propagation(self, deadline=None): + """Sequential tree-reweighted message passing (TRW-S). + + Plain loopy min-sum BP settles into globally-inconsistent fixed points on + this MRF (empirically 5-16% above the optimum). TRW-S optimizes a convex + upper bound over a tree decomposition (here: monotonic chains induced by a + node ordering), so on the integral sharding problem it converges to the + exact MAP. Node g is reweighted by 1/(chains through g) = 1/max(in,out)deg + under the ordering; forward and backward half-sweeps send only along edges + oriented with the pass. We decode each sweep and keep the best assignment.""" + G = len(self.groups) + if G == 0: + return + unary = self.g_unary + nbrs = self.nbrs + + order = sorted(range(G), key=lambda g: min(self.groups[g].members)) + pos = [0] * G + for i, g in enumerate(order): + pos[g] = i + gamma = np.ones(G) + for g in range(G): + indeg = sum(1 for h in nbrs[g] if pos[h] < pos[g]) + outdeg = sum(1 for h in nbrs[g] if pos[h] > pos[g]) + gamma[g] = 1.0 / max(indeg, outdeg, 1) + + msg: dict[tuple, np.ndarray] = {} + for g in range(G): + for h in nbrs[g]: + msg[(g, h)] = np.zeros(len(unary[h])) + + # We decode every sweep and keep the best assignment. The decoded energy + # converges in long, irregular plateaus (it can sit at a high value for + # ~100 sweeps, drop, plateau again, then drop to the optimum), so neither + # an energy-plateau counter nor a message-delta threshold detects true + # convergence without stopping on a false plateau. We therefore run a + # fixed sweep budget (bounded by the time deadline), which is enough for + # the slowest converger observed, and an exact fixed point ends early. + best_e = INF + best_snap = None + for sweep in range(self.bp_iters): + max_delta = 0.0 + for forward in (True, False): + for g in order if forward else order[::-1]: + if not nbrs[g]: + continue + wp = unary[g].copy() + for r in nbrs[g]: + wp += msg[(r, g)] + wp *= gamma[g] + for h in nbrs[g]: + if (pos[h] > pos[g]) != forward: + continue + P = self._pair_matrix(g, h) # (D_g, D_h) + m = ((wp - msg[(h, g)])[:, None] + P).min(axis=0) + m -= m.min() + d = np.abs(m - msg[(g, h)]).max() + if d > max_delta: + max_delta = d + msg[(g, h)] = m + self._decode(msg) + e = self._fast_total_energy() + if e < best_e: + best_e, best_snap = e, [grp.current for grp in self.groups] + self._bp_last_iter = sweep + 1 + self._bp_last_delta = max_delta + if max_delta == 0.0 or ( + deadline is not None and time.perf_counter() > deadline + ): + break + + if best_snap is not None: + for gid, ci in enumerate(best_snap): + self._set_group(gid, ci) + + def _decode(self, msg): + """Sequential topological decode: fix each group to the argmin of its + belief conditioned on already-decoded neighbors (exact pairwise cost) and + BP messages for the rest. Produces a consistent, forbidden-avoiding + assignment, unlike independent argmin on a loopy graph.""" + G = len(self.groups) + order = sorted(range(G), key=lambda g: min(self.groups[g].members)) + decided: dict[int, int] = {} + for g in order: + b = self.g_unary[g].copy() + for h in self.nbrs[g]: + if h in decided: + b = b + self._pair_matrix(g, h)[:, decided[h]] + else: + b = b + msg[(h, g)] + ci = int(np.argmin(b)) + decided[g] = ci + self._set_group(g, ci) + + # ------------------------------------------------------------------ # + # Local search + # ------------------------------------------------------------------ # + def _set_group(self, gid, ci): + group = self.groups[gid] + group.current = ci + for m, o in group.choices[ci].items(): + self.cur_out[m] = o + + def _coordinate_descent(self, deadline): + for _ in range(self.max_sweeps): + if time.perf_counter() > deadline: + break + improved = False + for gid in range(len(self.groups)): + if self.groups[gid].domain <= 1: + continue + cur = self.groups[gid].current + best_i, best_e = cur, self._fast_group_energy(gid, cur) + for ci in range(self.groups[gid].domain): + if ci == cur: + continue + e = self._fast_group_energy(gid, ci) + if e < best_e - 1e-6 and self._memory_ok_after(gid, ci): + best_i, best_e = ci, e + if best_i != cur: + self._set_group(gid, best_i) + improved = True + if not improved: + break + + def _star_block_search(self, deadline): + ranked = sorted( + ( + (len(self.nbrs[g]), g) + for g in range(len(self.groups)) + if len(self.nbrs[g]) >= 2 and self.groups[g].domain > 1 + ), + reverse=True, + ) + for _ in range(self.star_passes): + if time.perf_counter() > deadline: + break + improved = False + for _deg, gid in ranked: + if time.perf_counter() > deadline: + break + if self._optimize_star(gid): + improved = True + if not improved: + break + + def _optimize_star(self, gid): + children = [h for h in self.nbrs[gid] if self.groups[h].domain > 1] + child_costs = sorted( + ((self._fast_group_energy(h, self.groups[h].current), h) for h in children), + reverse=True, + ) + child_ids = [h for _e, h in child_costs[: self.max_star_children]] + if not child_ids: + return False + block = [gid, *child_ids] + base = self._block_energy(block) + best_energy = base + best_center = self.groups[gid].current + best_children = {h: self.groups[h].current for h in child_ids} + for ci in range(self.groups[gid].domain): + self._set_group(gid, ci) + if not self._memory_ok_after(gid, ci): + continue + chosen = {} + for h in child_ids: + b_i, b_e = self.groups[h].current, INF + for hi in range(self.groups[h].domain): + e = self._fast_group_energy(h, hi) + if e < b_e: + b_i, b_e = hi, e + self._set_group(h, b_i) + chosen[h] = b_i + energy = self._block_energy(block) + if energy < best_energy - 1e-6 and self._block_memory_ok(): + best_energy = energy + best_center = ci + best_children = dict(chosen) + self._set_group(gid, best_center) + for h, hi in best_children.items(): + self._set_group(h, hi) + return best_energy < base - 1e-6 + + def _block_energy(self, gids): + total = 0.0 + seen_edges = set() + for g in gids: + total += self.g_unary[g][self.groups[g].current] + for h in self.nbrs[g]: + key = (g, h) if g < h else (h, g) + if key in seen_edges: + continue + seen_edges.add(key) + a, b = key + total += self.C[key][self.groups[a].current, self.groups[b].current] + return total + + # ------------------------------------------------------------------ # + # Memory repair + # ------------------------------------------------------------------ # + def _current_memory(self): + if self._memory is None: + return 0.0 + return sum( + self._memory["ratios"][v][self.cur_out[v]] + for v in self._memory["param_idxs"] + ) + + def _memory_ok_after(self, gid, ci): + if self._memory is None or self._memory.get("tight") or not self._mem_enforce: + return True + ratios = self._memory["ratios"] + choice = self.groups[gid].choices[ci] + delta = sum( + ratios[m][o] - ratios[m][self.cur_out[m]] + for m, o in choice.items() + if m in ratios + ) + mem = self._current_memory() + delta + return ( + self._memory["budget_low"] - 1e-6 + <= mem + <= self._memory["budget_high"] + 1e-6 + ) + + def _block_memory_ok(self): + if self._memory is None or self._memory.get("tight") or not self._mem_enforce: + return True + mem = self._current_memory() + return ( + self._memory["budget_low"] - 1e-6 + <= mem + <= self._memory["budget_high"] + 1e-6 + ) + + def _memory_repair(self): + if self._memory is None or self._memory.get("tight"): + return + low, high = self._memory["budget_low"], self._memory["budget_high"] + ratios = self._memory["ratios"] + param_groups = { + self.node_to_group[v] + for v in self._memory["param_idxs"] + if v in self.node_to_group + } + for _ in range(2 * max(1, len(param_groups))): + mem = self._current_memory() + if low - 1e-6 <= mem <= high + 1e-6: + return + over = mem > high + best = None + for gid in param_groups: + group = self.groups[gid] + cur_e = self._fast_group_energy(gid, group.current) + for ci in range(group.domain): + if ci == group.current: + continue + choice = group.choices[ci] + dmem = sum( + ratios[m][choice[m]] - ratios[m][self.cur_out[m]] + for m in choice + if m in ratios + ) + if (dmem < -1e-9) != over and abs(dmem) > 1e-9: + continue + if abs(dmem) <= 1e-9: + continue + score = (self._fast_group_energy(gid, ci) - cur_e) / abs(dmem) + if best is None or score < best[0]: + best = (score, gid, ci) + if best is None: + logger.warning( + "Approximate solver: memory repair stuck at %.4f " + "(budget=[%.4f,%.4f]).", + mem, + low, + high, + ) + return + self._set_group(best[1], best[2]) + + # ------------------------------------------------------------------ # + # Lagrangian memory-constrained solve + # ------------------------------------------------------------------ # + def _build_mem_unary(self): + """Per-group vector mem_unary[gid][ci] = Σ_{param member} ratio[member][ci], + i.e. the memory term as a node-separable unary so it folds into the + Lagrangian objective with no change to the pairwise structure.""" + self._mem_unary = [np.zeros(g.domain) for g in self.groups] + if self._memory is None: + return + ratios = self._memory["ratios"] + for v in self._memory["param_idxs"]: + gid = self.node_to_group.get(v) + if gid is None: + continue + r = ratios[v] + self._mem_unary[gid] += np.array( + [r[c[v]] for c in self.groups[gid].choices] + ) + + def _run_search(self, deadline): + self._belief_propagation(deadline) + self._coordinate_descent(deadline) + self._star_block_search(deadline) + + def solve_lagrangian( + self, + budget_low, + budget_high, + deadline=None, + max_iter=30, + lam_tol=1e-9, + verbose=False, + ): + """Memory-constrained solve via Lagrangian relaxation. + + The budget Σ_param ratio[v][x_v] ∈ [low, high] is a single linear, + node-separable coupling. Penalizing it by λ folds λ·ratio into each param + node's unary and leaves the pairwise MRF untouched, so TRW-S + polish + solves the penalized problem directly. a(λ) := Σ ratio at the optimum is + monotone non-increasing in λ (larger λ ⇒ more sharding ⇒ less memory), so + a scalar bisection on λ ≥ 0 drives a(λ) into the budget. The cheapest + feasible assignment seen is kept; the existing greedy repair only closes + any small residual from integrality. + + Leaves the solver at the chosen assignment (does not write back) and + returns a dict: objective (true), memory (achieved a), lam, feasible, + iters.""" + if not self._mem_unary: + self._build_mem_unary() + t_start = time.perf_counter() + if deadline is None: + deadline = t_start + self.max_time_s + # Reserve the tail of the budget for the constrained polish below. + bisect_deadline = t_start + 0.6 * (deadline - t_start) + base = [u.copy() for u in self.g_unary] + prev_enforce = self._mem_enforce + self._mem_enforce = False + eps = 1e-6 + + best = {"objective": INF, "snapshot": None, "memory": None, "lam": None} + # Closest over-budget assignment (smallest excess memory) — the seed the + # repair step nudges down into the budget to recover integer solutions + # that lie inside the (memory, cost) hull and so no lambda can reach. + seed = {"memory": INF, "snapshot": None} + + def evaluate(lam): + for gid in range(len(self.groups)): + self.g_unary[gid] = base[gid] + lam * self._mem_unary[gid] + self._run_search(bisect_deadline) + a = self._current_memory() + obj = self.total_objective() + feasible = budget_low - eps <= a <= budget_high + eps + if feasible and obj < best["objective"]: + best.update( + objective=obj, + snapshot=[g.current for g in self.groups], + memory=a, + lam=lam, + ) + if budget_high + eps < a < seed["memory"]: + seed.update(memory=a, snapshot=[g.current for g in self.groups]) + if verbose: + logger.info( + "lagrangian: lam=%.6g memory=%.5f obj=%.2f feasible=%s", + lam, + a, + obj, + feasible, + ) + return a + + a0 = evaluate(0.0) + iters = 1 + if a0 <= budget_high + eps: + lam = 0.0 # unconstrained optimum already fits the budget + else: + lo_lam, hi_lam = 0.0, 1.0 + while evaluate(hi_lam) > budget_high + eps and iters < max_iter: + lo_lam, hi_lam = hi_lam, hi_lam * 2.0 + iters += 1 + while iters < max_iter and hi_lam - lo_lam > lam_tol: + mid = 0.5 * (lo_lam + hi_lam) + a = evaluate(mid) + iters += 1 + if a > budget_high + eps: + lo_lam = mid # still over budget, penalize harder + else: + hi_lam = mid # feasible, try to relax toward the cheaper side + lam = hi_lam + + for gid in range(len(self.groups)): + self.g_unary[gid] = base[gid] + self._mem_enforce = prev_enforce + + # Constrained polish (under the base unary, budget enforced). No single λ + # recovers integer solutions inside the (memory, cost) hull; coordinate + + # star search restricted to the budget can climb from an over-sharded + # point back up to a cheaper intermediate-memory one. We polish both the + # bisection's feasible point and the repaired closest-over-budget seed and + # keep the cheapest feasible result. + def polish(snapshot): + for gid, ci in enumerate(snapshot): + self._set_group(gid, ci) + self._memory_repair() + self._coordinate_descent(deadline) + self._star_block_search(deadline) + a = self._current_memory() + if budget_low - eps <= a <= budget_high + eps: + obj = self.total_objective() + if obj < best["objective"]: + best.update( + objective=obj, + snapshot=[g.current for g in self.groups], + memory=a, + lam=lam, + ) + + for snap in (best["snapshot"], seed["snapshot"]): + if snap is not None: + polish(snap) + + if best["snapshot"] is not None: + for gid, ci in enumerate(best["snapshot"]): + self._set_group(gid, ci) + else: + # Nothing landed in [low, high]; repair the last assignment in place. + self._memory_repair() + + a = self._current_memory() + return { + "objective": self.total_objective(), + "memory": a, + "lam": lam, + "feasible": budget_low - eps <= a <= budget_high + eps, + "iters": iters, + } + + # ------------------------------------------------------------------ # + # Write-back + # ------------------------------------------------------------------ # + def total_objective(self): + """Exact objective of the current assignment via decision_vars (for + verification); equals pulp.value(prob.objective) after write-back.""" + total = 0.0 + for v in self.cost_bearing: + node = self.opt.nodes[v] + o = self.cur_out[v] + strat = self.opt.strats[node].strategies[o] + prod = self._arg_prod.get(v, {}) + n_args = len(strat.redistribute_cost) + c = 0.0 + for argi in range(n_args): + p = prod.get(argi) + inp = self.cur_out[p] if p is not None else 0 + key = (v, argi, o, inp) + if self._is_forbidden(key): + return INF + c += self.opt.decision_vars[key].cost + total += self.node_mult[v] * c + return total + + def _write_back(self): + opt = self.opt + has_pulp = bool(opt.pulp_variables) + if has_pulp: + for var in opt.pulp_variables.values(): + var.varValue = 0 + selected = [] + feasible = True + for v in self.cost_bearing: + node = opt.nodes[v] + o = self.cur_out[v] + strat = opt.strats[node].strategies[o] + prod = self._arg_prod.get(v, {}) + for argi in range(len(strat.redistribute_cost)): + p = prod.get(argi) + inp = self.cur_out[p] if p is not None else 0 + key = (v, argi, o, inp) + if self._is_forbidden(key): + feasible = False + # A pruned key has no PuLP variable; the infeasible flag above + # already records it (and raises in _solve). + if has_pulp and key in opt.pulp_variables: + opt.pulp_variables[key].varValue = 1 + selected.append(key) + opt.selected_keys = list(selected) + for rk in selected: + opt.selected_keys.extend(opt._linked_option_keys(rk)) + # Populate prob.objective (when a PuLP problem exists) so callers can also + # score via pulp.value(prob.objective); the returned value uses the + # equivalent but cheaper total_objective(). In the lite (no-PuLP) build, + # there is no problem to populate. + if opt.prob is not None: + opt.prob.status = pulp.LpStatusOptimal + opt.prob.sol_status = pulp.LpSolutionOptimal + opt._set_objective() + return INF if not feasible else self.total_objective() diff --git a/autoparallel/graph_passes/graph_clustering.py b/autoparallel/graph_passes/graph_clustering.py index c01a09a3..a8efafea 100644 --- a/autoparallel/graph_passes/graph_clustering.py +++ b/autoparallel/graph_passes/graph_clustering.py @@ -65,18 +65,17 @@ def _prepare_op_strategy(op_strategy): return str(op_strategy) -def _hash_node(node, strategies, input_pickler): +def _hash_node(node, strategies, input_pickler, op_str): + # op_str caches _prepare_op_strategy(strategies[n]) per node: each node's + # (large, 3D-mesh) strategy string is otherwise rebuilt once as self plus + # once per consumer, dominating clustering time on deep models. key = ( str(node.target), node.meta.get("partitioner_tag"), node.meta.get("stack_trace"), _normalize_args(node), - _prepare_op_strategy(strategies[node]), - tuple( - _prepare_op_strategy(strategies[s]) - for s in node.all_input_nodes - if s in strategies - ), + op_str[node], + tuple(op_str[s] for s in node.all_input_nodes if s in strategies), ) return sha256_hash(input_pickler.dumps(key)) @@ -107,6 +106,7 @@ def get_identical_regions( hash_to_duplicates: dict[str, IdenticalNodes] = defaultdict(list) node_to_duplicates: dict[Node, IdenticalNodes] = {} t = time.time() + op_str = {n: _prepare_op_strategy(s) for n, s in strategies.items()} for node in graph.nodes: if node.op == "placeholder": continue @@ -115,7 +115,9 @@ def get_identical_regions( # HOP submodule get_attr nodes are not in strategies. continue - duplicates = hash_to_duplicates[_hash_node(node, strategies, input_pickler)] + duplicates = hash_to_duplicates[ + _hash_node(node, strategies, input_pickler, op_str) + ] duplicates.append(node) node_to_duplicates[node] = duplicates logger.debug(f"Hashed nodes in {time.time() - t} s") diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 6b72878b..c3832c87 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -69,9 +69,11 @@ runtime cost while satisfying all constraints. """ +import contextlib import logging import math import operator +import os import tempfile import time from collections import defaultdict @@ -107,6 +109,87 @@ logger = logging.getLogger(__name__) +# Strategy enumeration fills each OpSpec's redistribute_cost via torch's +# generate_redistribute_costs (an expensive per-strategy redistribute-plan +# computation, the dominant cost of build on large/3D meshes). But +# _build_decision_vars overwrites every edge with the NCCL-aware +# estimate_strategy_comms_cost, and nothing reads the enumeration costs in +# between (remove_invalid_configs / keep_unique_configs select on placements/ +# shapes only). So during enumeration we replace torch's redistribute_cost with +# a structure-preserving dummy to skip the wasted work; the final decision_vars +# are byte-identical. Autoparallel's own cost model uses a separate +# redistribute_cost (collective_runtime_estimation) and is unaffected. Escape +# hatch for A/B verification: AP_FAST_BUILD=0. +_FAST_BUILD = os.environ.get("AP_FAST_BUILD", "1") == "1" + + +@contextlib.contextmanager +def _skip_enumeration_redistribute_cost(): + if not _FAST_BUILD: + yield + return + import torch.distributed.tensor._ops.utils as _dt_utils + + orig = _dt_utils.redistribute_cost + _dt_utils.redistribute_cost = lambda *args, **kwargs: 0.0 + try: + yield + finally: + _dt_utils.redistribute_cost = orig + + +# Number of fork workers for the per-edge cost computation in _build_decision_vars +# (the dominant cost of build on large/3D meshes). 1 = serial (use for A/B +# verification); default scales with cores. The computation is per-node +# independent and deterministic, so the parallel result is byte-identical. +_PARALLEL_BUILD_WORKERS = int( + os.environ.get("AP_PARALLEL_BUILD", str(min(32, (os.cpu_count() or 1)))) +) + +# Set to the optimizer before forking cost workers; the workers read it from the +# fork-inherited address space (no pickling of the mesh / strategy graph). +_FORK_OPT: "ShardingOptimizer | None" = None + + +def _par_node_edge_costs(node_idx): + """Worker: compute the per-edge (comm, transition) costs and the per-strategy + compute cost for one root node, reading the fork-inherited optimizer. Pure — + it reads strats and mutates nothing; the parent assembles DecisionVars from + these primitives. Returns (node_idx, out_data) where + out_data[out_idx] = (per_arg_compute, arg_rows) and + arg_rows[argi][inp_idx] = (comm_cost, transition_cost).""" + opt = _FORK_OPT + node = opt.nodes[node_idx] + op_strategy = opt.strats[node] + num_args = len(op_strategy.strategies[0].input_specs) + all_input_nodes = opt._all_input_nodes(node) + producer_strategies = [opt.strats[n] for n in all_input_nodes] + out_data = [] + for output_strategy in op_strategy.strategies: + per_arg_compute = ( + estimate_strategy_runtime_cost(node, output_strategy) / num_args + ) + arg_rows = [] + for argi, redist_costs in enumerate(output_strategy.redistribute_cost): + producer_strategy = ( + producer_strategies[argi] if argi < len(producer_strategies) else None + ) + arg_rows.append( + [ + opt._compute_edge_costs( + node, + output_strategy, + argi, + inp_idx, + default_comm_cost, + producer_strategy, + ) + for inp_idx, default_comm_cost in enumerate(redist_costs) + ] + ) + out_data.append((per_arg_compute, arg_rows)) + return node_idx, out_data + def concretize_symint(val): """Concretize a SymInt to a plain int, pass through other values. @@ -188,10 +271,13 @@ def concretize_gm(gm): return concrete_gm, orig_to_concrete, concrete_to_orig -@dataclass +@dataclass(slots=True) class DecisionVar: """A decision variable in the ILP, representing one (node, arg, output_placement, - input_placement) choice with its associated costs and strategy metadata.""" + input_placement) choice with its associated costs and strategy metadata. + + slots=True: there are millions of these on large/3D meshes, so dropping the + per-instance __dict__ materially cuts both build time and memory.""" var: Any # pulp.LpVariable cost: float @@ -203,6 +289,70 @@ class DecisionVar: input_spec: Any # DTensorSpec +@dataclass +class LPRelaxationResult: + objective: float + status: str + solve_s: float + total_s: float + + +@dataclass +class DPTopology: + nodes: list[torch.fx.Node] + predecessors: dict[torch.fx.Node, list[torch.fx.Node]] + node_to_index: dict[torch.fx.Node, int] + + +class DPBasedShardingSolver: + """EXPERIMENTAL / incomplete — not part of the supported solver path. + + Only reachable when ``ShardingOptimizer`` is built with the non-default + ``solver_backend="dp"`` (not exposed through ``AutoParallel``), and today it + only builds a topological order: :meth:`get_solution` raises + ``NotImplementedError``. Kept for in-progress work; do not rely on it. + """ + + def __init__(self, optimizer): + self.optimizer = optimizer + self.topology: Optional[DPTopology] = None + + def build_topological_order(self): + nodes = [node for node in self.optimizer.nodes if node.op != "output"] + node_to_index = {node: i for i, node in enumerate(nodes)} + predecessors = {} + + for node in nodes: + node_predecessors = self.optimizer._all_input_nodes(node) + predecessors[node] = node_predecessors + node_index = node_to_index[node] + for pred in node_predecessors: + pred_index = node_to_index.get(pred) + if pred_index is None: + raise RuntimeError( + f"Predecessor {pred} for node {node} is missing from " + "the DP topology" + ) + if pred_index >= node_index: + raise RuntimeError( + f"Predecessor {pred} for node {node} does not appear " + "before it in topological order" + ) + + self.topology = DPTopology( + nodes=nodes, + predecessors=predecessors, + node_to_index=node_to_index, + ) + return self.topology + + def get_solution(self, verbose=False): + raise NotImplementedError( + "DP-based sharding solver only builds topological order today; " + "strategy selection is not implemented yet." + ) + + def _assert_has_tensor_meta(spec_or_specs, node, label): """Assert that all DTensorSpecs in a spec (possibly a tuple) have tensor_meta.""" if isinstance(spec_or_specs, (list, tuple)): @@ -224,8 +374,23 @@ def __init__( mesh, force_grad_reduce_in_higher_precision=False, repeated_subgraphs=False, + solver_backend="ilp", + build_pulp=True, ): self.orig_gm = gm + if solver_backend not in {"ilp", "dp"}: + raise ValueError( + f"Unsupported solver_backend={solver_backend!r}; " + "expected 'ilp' or 'dp'" + ) + self.solver_backend = solver_backend + # When False, skip creating PuLP variables and constraints entirely. + # decision_var costs + strategies + cluster_links are still built, which + # is all the approximate solver needs (it derives the constraint topology + # directly). This avoids constructing millions of PuLP objects on large / + # 3D meshes, where that dominates build time. + self.build_pulp = build_pulp + self.prob = None # The optimizer works on a concretized copy of the graph where all # symbolic shapes are replaced with their concrete hint values. This # centralizes dynamic-shape handling: the optimization pipeline @@ -245,20 +410,92 @@ def __init__( # so that _apply_memory_constraint can exclude constrained params and # remove_constraints can keep this in sync. self._node_constraint_names: dict[str, str] = {} + # Maps node_name → list of (mesh_dim, placement) per-axis constraints. + # A per-axis constraint keeps a param in the memory budget (unlike a full + # node constraint) but restricts which strategies it can use, so the + # budget must compute its best achievable memory ratio over only the + # strategies that satisfy these constraints. + self._node_axis_constraints: dict[ + str, list[tuple[int, Placement]] + ] = defaultdict(list) + # Variables pinned to 0 by axis constraints applied with method="fix". + # Stored so they can be restored by remove_constraints / for re-solving. + self._fixed_vars: list = [] self._name_counters: dict[str, int] = {} + # Set by _build_decision_vars: the (node, arg, out, inp) keys whose + # strategy edge has finite cost. Invalid (infinite-cost) edges are + # pruned and get no variable. None means "no pruning filter". + self._valid_keys: set[tuple] | None = None + self.profile: dict[str, Any] = { + "mesh": self._profile_mesh(), + "model": self._profile_model(), + "timings": {}, + } + t_init_start = time.perf_counter() t0 = time.perf_counter() self.strats = self.build_sharding_metadata() + t_strategy = time.perf_counter() - t0 + self.profile["timings"]["strategy_enumeration_s"] = t_strategy + self.profile["strategies"] = self._profile_strategies() + logger.info( + "ShardingOptimizer phase profile: phase=strategy_enumeration " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s model_params=%s " + "graph_nodes=%s strategy_options=%s option_tuples=%s elapsed=%.3fs", + self.profile["mesh"]["shape"], + self.profile["mesh"]["dim_names"], + self.profile["mesh"]["size"], + self._format_billions(self.profile["model"]["parameter_numel"]), + self.profile["model"]["graph_nodes"], + self.profile["strategies"]["strategy_options"], + self.profile["strategies"]["option_tuples"], + t_strategy, + ) # nodes/node_map are derived from strats (not graph.nodes) so that # shape-computation nodes skipped by build_sharding_metadata don't # appear and indices stay consistent. self.nodes = list(self.strats.keys()) self.node_map = {node: i for i, node in enumerate(self.nodes)} - logger.debug("Placement options took %.3fs", time.perf_counter() - t0) + logger.debug("Placement options took %.3fs", t_strategy) from autoparallel.shardings.placement_options import get_placement_options_timer get_placement_options_timer().report() - self.cluster_links: dict[tuple, tuple] = {} + # Node-level: cluster-copy node idx -> root node idx (option indices are + # identical between copy and root; resolved on demand via + # _cluster_root_key / _linked_option_keys). + self.cluster_links: dict[int, int] = {} + self._root_to_copies: dict[int, list[int]] = defaultdict(list) + if self.solver_backend == "dp": + t0 = time.perf_counter() + self.solver = DPBasedShardingSolver(self) + topology = self.solver.build_topological_order() + t1 = time.perf_counter() + self.profile["dp"] = { + "topology_nodes": len(topology.nodes), + "topology_edges": sum( + len(preds) for preds in topology.predecessors.values() + ), + } + self.profile["timings"].update( + { + "topology_construction_s": t1 - t0, + "init_total_s": t1 - t_init_start, + } + ) + logger.info( + "ShardingOptimizer phase profile: phase=dp_topology " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s model_params=%s " + "topology_nodes=%s topology_edges=%s elapsed=%.3fs", + self.profile["mesh"]["shape"], + self.profile["mesh"]["dim_names"], + self.profile["mesh"]["size"], + self._format_billions(self.profile["model"]["parameter_numel"]), + self.profile["dp"]["topology_nodes"], + self.profile["dp"]["topology_edges"], + t1 - t0, + ) + return + if repeated_subgraphs: t = time.time() clusters = get_identical_regions(self.gm.graph, self.strats) @@ -268,13 +505,78 @@ def __init__( t0 = time.perf_counter() self.decision_vars = self._build_decision_vars() t1 = time.perf_counter() + logger.info( + "ShardingOptimizer phase profile: phase=decision_vars " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s model_params=%s " + "unique_ilp_vars=%s logical_decision_vars=%s " + "cluster_copied_decision_vars=%s pulp_var_creation=%.3fs " + "compute_cost=%.3fs edge_cost=%.3fs cost_estimation=%.3fs " + "elapsed=%.3fs", + self.profile["mesh"]["shape"], + self.profile["mesh"]["dim_names"], + self.profile["mesh"]["size"], + self._format_billions(self.profile["model"]["parameter_numel"]), + self._decision_var_profile["unique_pulp_variables"], + self._decision_var_profile["logical_decision_variables"], + self._decision_var_profile["cluster_copied_decision_variables"], + self._decision_var_profile["pulp_var_creation_s"], + self._decision_var_profile["compute_cost_estimation_s"], + self._decision_var_profile["edge_cost_estimation_s"], + self._decision_var_profile["cost_estimation_s"], + t1 - t0, + ) self.validate() t2 = time.perf_counter() - self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize) - self.add_default_constraints() + if self.build_pulp: + self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize) + self.add_default_constraints() t3 = time.perf_counter() + decision_var_build_s = t1 - t0 + cost_estimation_s = self._decision_var_profile["cost_estimation_s"] + decision_var_overhead_s = max( + decision_var_build_s + - self._decision_var_profile["pulp_var_creation_s"] + - cost_estimation_s, + 0.0, + ) + self.profile["timings"].update( + { + "decision_var_build_s": decision_var_build_s, + "decision_var_overhead_s": decision_var_overhead_s, + "validation_s": t2 - t1, + "constraint_construction_s": t3 - t2, + "ilp_construction_s": ( + self._decision_var_profile["pulp_var_creation_s"] + + decision_var_overhead_s + + (t3 - t2) + ), + "init_total_s": t3 - t_init_start, + } + ) n_unique_vars = len(self.pulp_variables) - n_constraints = len(self.prob.constraints) + n_constraints = len(self.prob.constraints) if self.prob is not None else 0 + self.profile["ilp"] = { + "unique_variables": n_unique_vars, + "logical_decision_variables": self._decision_var_profile[ + "logical_decision_variables" + ], + "cluster_copied_decision_variables": self._decision_var_profile[ + "cluster_copied_decision_variables" + ], + "constraints": n_constraints, + } + logger.info( + "ShardingOptimizer phase profile: phase=constraints " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s model_params=%s " + "unique_ilp_vars=%s constraints=%s elapsed=%.3fs", + self.profile["mesh"]["shape"], + self.profile["mesh"]["dim_names"], + self.profile["mesh"]["size"], + self._format_billions(self.profile["model"]["parameter_numel"]), + n_unique_vars, + n_constraints, + t3 - t2, + ) logger.debug( "ILP construction took %.3fs " "(decision_vars=%.3fs, validate=%.3fs, constraints=%.3fs)", @@ -289,6 +591,157 @@ def __init__( len(self.decision_vars), n_constraints, ) + self._log_init_profile() + + def _profile_mesh(self): + try: + mesh_shape = tuple(int(d) for d in self.mesh.shape) + except Exception: + mesh_shape = tuple() + try: + mesh_size = int(self.mesh.size()) + except Exception: + mesh_size = math.prod(mesh_shape) if mesh_shape else None + return { + "ndim": getattr(self.mesh, "ndim", len(mesh_shape)), + "shape": mesh_shape, + "dim_names": getattr(self.mesh, "mesh_dim_names", None), + "size": mesh_size, + } + + def _profile_model(self): + graph_nodes = list(self.graph.nodes) + op_counts = defaultdict(int) + tensor_nodes = 0 + for node in graph_nodes: + op_counts[node.op] += 1 + if _produces_tensor(node.meta.get("val")): + tensor_nodes += 1 + + param_numel = 0 + param_bytes = 0 + unknown_param_nodes = 0 + try: + param_nodes = get_param_nodes(self.graph) + except Exception: + param_nodes = [] + unknown_param_nodes = None + + for node in param_nodes: + val = node.meta.get("val") + if not isinstance(val, torch.Tensor): + unknown_param_nodes += 1 + continue + numel = self._safe_tensor_numel(val) + if numel is None: + unknown_param_nodes += 1 + continue + param_numel += numel + try: + param_bytes += numel * val.element_size() + except Exception: + pass + + return { + "graph_nodes": len(graph_nodes), + "tensor_nodes": tensor_nodes, + "op_counts": dict(op_counts), + "parameter_nodes": len(param_nodes), + "parameter_numel": param_numel, + "parameter_bytes": param_bytes, + "unknown_parameter_nodes": unknown_param_nodes, + } + + @staticmethod + def _safe_tensor_numel(tensor): + try: + numel = tensor.numel() + if isinstance(numel, int): + return numel + return int(numel) + except Exception: + pass + + shape = getattr(tensor, "shape", None) + if shape is None: + return None + + total = 1 + for dim in shape: + dim = concretize_symint(dim) + if not isinstance(dim, int): + return None + total *= dim + return total + + def _profile_strategies(self): + strategy_options = 0 + option_tuples = 0 + max_strategies_per_node = 0 + for node in self.strats: + if node.op == "output" or not hasattr(self.strats[node], "strategies"): + continue + strategies = self.strats[node].strategies + strategy_options += len(strategies) + max_strategies_per_node = max(max_strategies_per_node, len(strategies)) + option_tuples += sum(1 for _ in self.walk_over_options(node)) + return { + "nodes": len(self.strats), + "strategy_options": strategy_options, + "option_tuples": option_tuples, + "max_strategies_per_node": max_strategies_per_node, + } + + @staticmethod + def _format_billions(count): + if count is None: + return "unknown" + if count >= 1_000_000_000: + return f"{count / 1_000_000_000:.2f}B" + if count >= 1_000_000: + return f"{count / 1_000_000:.2f}M" + return str(count) + + @staticmethod + def _safe_float(value): + try: + return float(value) + except Exception: + return math.nan + + def _log_init_profile(self): + mesh = self.profile["mesh"] + model = self.profile["model"] + strategies = self.profile["strategies"] + ilp = self.profile["ilp"] + timings = self.profile["timings"] + logger.info( + "ShardingOptimizer init profile: " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s " + "model_params=%s param_nodes=%s graph_nodes=%s tensor_nodes=%s " + "strategy_options=%s option_tuples=%s " + "unique_ilp_vars=%s logical_decision_vars=%s constraints=%s " + "timings={strategy_enumeration=%.3fs,cost_estimation=%.3fs," + "ilp_construction=%.3fs,validation=%.3fs,total=%.3fs}", + mesh["shape"], + mesh["dim_names"], + mesh["size"], + self._format_billions(model["parameter_numel"]), + model["parameter_nodes"], + model["graph_nodes"], + model["tensor_nodes"], + strategies["strategy_options"], + strategies["option_tuples"], + ilp["unique_variables"], + ilp["logical_decision_variables"], + ilp["constraints"], + timings["strategy_enumeration_s"], + timings["cost_estimation_s"], + timings["ilp_construction_s"], + timings["validation_s"], + timings["init_total_s"], + ) + logger.debug("ShardingOptimizer init profile detail: %s", self.profile) def _get_next_name(self, prefix): idx = self._name_counters.setdefault(prefix, 0) @@ -305,77 +758,95 @@ def _normalize_node(self, node): def build_sharding_metadata(self): strats = {} - for node in self.graph.nodes: - if node.op in ("placeholder", "get_attr"): - val = node.meta.get("val") - if isinstance(val, torch.Tensor): - strats[node] = _create_all_options(self.mesh, val.shape, tensor=val) - elif node.op == "placeholder": - # Non-tensor placeholders (e.g. baked-in booleans/strings): - # keep them in strats with empty-shape replicate options - # so the constraint system can reference them. - strats[node] = _create_all_options(self.mesh, ()) + # Enumeration's redistribute_cost matrices are overwritten with real + # costs in _build_decision_vars, so skip computing them here (see + # _skip_enumeration_redistribute_cost). + with _skip_enumeration_redistribute_cost(): + for node in self.graph.nodes: + if node.op in ("placeholder", "get_attr"): + val = node.meta.get("val") + if isinstance(val, torch.Tensor): + strats[node] = _create_all_options( + self.mesh, val.shape, tensor=val + ) + elif node.op == "placeholder": + # Non-tensor placeholders (e.g. baked-in booleans/strings): + # keep them in strats with empty-shape replicate options + # so the constraint system can reference them. + strats[node] = _create_all_options(self.mesh, ()) + else: + # Non-tensor get_attr: GraphModule submodules used by + # HOPs — not added to strats, invisible to the ILP. + # _all_input_nodes filters them. + assert node.op == "get_attr" + assert any( + isinstance(u.target, torch._ops.HigherOrderOperator) + or "local_map" in u.name + for u in node.users + ), f"Non-tensor get_attr {node} is not used by a HOP" + elif node.op == "call_function": + if not _produces_tensor(node.meta.get("val")): + # Shape-computation nodes (sym_size, operator.mul, etc.) + # produce scalars, not tensors — skip sharding. + continue + user_strats = tree_map_only( + torch.fx.Node, + lambda x: strats.get(x, x.meta.get("val")), + node.args, + ) + user_args = tree_map_only( + torch.fx.Node, lambda x: x.meta.get("val"), node.args + ) + user_kwargs = tree_map_only( + torch.fx.Node, lambda x: x.meta.get("val"), node.kwargs + ) + strats[node] = get_placement_options_for_node( + self.mesh, node, user_strats, user_args, user_kwargs + ) + elif node.op == "output": + user_strats = tree_map_only( + torch.fx.Node, lambda x: strats[x], node.args + ) + strats[node] = user_strats else: - # Non-tensor get_attr: GraphModule submodules used by - # HOPs — not added to strats, invisible to the ILP. - # _all_input_nodes filters them. - assert node.op == "get_attr" - assert any( - isinstance(u.target, torch._ops.HigherOrderOperator) - or "local_map" in u.name - for u in node.users - ), f"Non-tensor get_attr {node} is not used by a HOP" - elif node.op == "call_function": - if not _produces_tensor(node.meta.get("val")): - # Shape-computation nodes (sym_size, operator.mul, etc.) - # produce scalars, not tensors — skip sharding. - continue - user_strats = tree_map_only( - torch.fx.Node, - lambda x: strats.get(x, x.meta.get("val")), - node.args, - ) - user_args = tree_map_only( - torch.fx.Node, lambda x: x.meta.get("val"), node.args - ) - user_kwargs = tree_map_only( - torch.fx.Node, lambda x: x.meta.get("val"), node.kwargs - ) - strats[node] = get_placement_options_for_node( - self.mesh, node, user_strats, user_args, user_kwargs - ) - elif node.op == "output": - user_strats = tree_map_only( - torch.fx.Node, lambda x: strats[x], node.args - ) - strats[node] = user_strats - else: - raise ValueError(f"Unexpected node op: {node.op}") + raise ValueError(f"Unexpected node op: {node.op}") return strats def create_cluster_links(self, clusters): - """Create a mapping between identical optimization nodes to reduce the - optimization space. If cluster_links[key1] == key2, the optimization - problem uses key2's variable in place of key1.""" + """Map each cluster-copy node to its root node (node-level). The optimizer + reuses the root's decision variable for every copy, and the per-(arg, out, + inp) option index is identical between a copy and its root, so we store + only the node->node map and reconstruct option keys on demand (see + _cluster_root_key / _linked_option_keys). Materializing one dict entry per + option-tuple instead costs tens of millions of entries (and seconds of + build time) on large/3D meshes.""" for cluster_group in clusters: cluster0 = cluster_group[0] for cluster_i in cluster_group[1:]: for n0, ni in zip(cluster0, cluster_i): - idx0 = self.node_map[n0] - idx1 = self.node_map[ni] - options_n0 = list(self.walk_over_options(n0)) - options_ni = list(self.walk_over_options(ni)) - assert options_n0 == options_ni, ( - f"Problem with graph clustering: {n0} and {ni} don't have the same number " - "of input/output placements. Please report a bug" + assert len(self.strats[n0].strategies) == len( + self.strats[ni].strategies + ), ( + f"Problem with graph clustering: {n0} and {ni} don't have " + "the same number of strategies. Please report a bug" ) - for argi, out_idx, inp_idx in options_n0: - self.cluster_links[(idx1, argi, out_idx, inp_idx)] = ( - idx0, - argi, - out_idx, - inp_idx, - ) + self.cluster_links[self.node_map[ni]] = self.node_map[n0] + + def _cluster_root_key(self, key): + """Resolve an option key to its cluster-root option key, or return it + unchanged when the node is not a cluster copy. The (arg, out, inp) indices + are identical between a copy and its root.""" + root_idx = self.cluster_links.get(key[0]) + return key if root_idx is None else (root_idx, key[1], key[2], key[3]) + + def _linked_option_keys(self, root_key): + """The option keys on the cluster copies of root_key's node (each a mirror + of root_key with the copy's node index). A copy mirrors its root's + per-option validity, so callers pass valid root keys only.""" + copies = self._root_to_copies.get(root_key[0]) + if not copies: + return () + return [(c, root_key[1], root_key[2], root_key[3]) for c in copies] def _all_input_nodes(self, node): """Variant of node.all_input_nodes that preserves duplicate nodes. @@ -409,15 +880,26 @@ def walk_over_options(self, node, constrain_arg=None): for inp_idx in range(len(strategy.redistribute_cost[argi])): yield argi, out_idx, inp_idx - def _create_pulp_variables(self): - """Create PuLP binary variables for all decision points, resolving - cluster links so that identical nodes share the same variable. + def _create_pulp_variables(self, variable_category=pulp.LpBinary): + """Create PuLP variables for all decision points, resolving cluster + links so that identical nodes share the same variable. Returns a dict mapping root (node_idx, argi, out_idx, inp_idx) keys to their PuLP variables. Linked keys are not stored; use _get_pulp_variable() to resolve them through cluster_links. + + Keys whose strategy is invalid (infinite cost) are pruned: if + self._valid_keys is set, only those keys get a variable. These + variables would otherwise be forced to zero by an inf-cost + constraint, so skipping them shrinks the ILP without changing the + optimum (see _build_decision_vars). """ - cluster_linked_node_idxs = {key[0] for key in self.cluster_links} + if variable_category not in {pulp.LpBinary, pulp.LpContinuous}: + raise ValueError( + f"Unsupported variable_category={variable_category!r}; " + "expected pulp.LpBinary or pulp.LpContinuous" + ) + cluster_linked_node_idxs = set(self.cluster_links) pulp_variables = {} for node, _ in self.strats.items(): @@ -428,20 +910,31 @@ def _create_pulp_variables(self): continue for argi, out_idx, inp_idx in self.walk_over_options(node): key = (node_idx, argi, out_idx, inp_idx) + if self._valid_keys is not None and key not in self._valid_keys: + continue root_node = self.nodes[node_idx] + bounds = ( + {"lowBound": 0, "upBound": 1} + if variable_category == pulp.LpContinuous + else {} + ) pulp_variables[key] = pulp.LpVariable( f"n={root_node},s={node_idx},arg={argi}," f"output_p={out_idx},input_p={inp_idx}", - cat=pulp.LpBinary, + cat=variable_category, + **bounds, ) return pulp_variables def _get_pulp_variable(self, key): """Look up the PuLP variable for a key, resolving through - cluster_links if the key belongs to a linked node.""" - root_key = self.cluster_links.get(key, key) - return self.pulp_variables[root_key] + cluster_links if the key belongs to a linked node. + + Returns None if the key was pruned (invalid/infinite-cost strategy). + """ + root_key = self._cluster_root_key(key) + return self.pulp_variables.get(root_key) def _compute_edge_costs( self, @@ -480,74 +973,94 @@ def _compute_edge_costs( def _build_decision_vars(self): """Build DecisionVar entries for every (node_idx, argi, out_idx, inp_idx) - combination in the strategy space.""" - t_pulp_start = time.perf_counter() - self.pulp_variables = self._create_pulp_variables() - t_pulp_end = time.perf_counter() - + combination in the strategy space. + + Strategy edges whose total cost is infinite (invalid redistributions) + are pruned: no variable is created for them. Such a variable would be + forced to zero by an inf-cost constraint anyway, so dropping it leaves + the optimum unchanged while removing ~30% of the variables and the + corresponding ~80% of constraints that are pure ``var == 0`` bounds. + + When ``build_pulp`` is False (approximate solver only) no PuLP variables + are created (``DecisionVar.var`` is None); the valid-key set is still + built so the approximate solver can treat a key absent from + ``decision_vars`` as forbidden. + """ # Precompute which node indices are cluster-linked so we can # copy costs from the root instead of recomputing them. - self._cluster_linked_node_idxs = {key[0] for key in self.cluster_links} + self._cluster_linked_node_idxs = set(self.cluster_links) t_compute = 0.0 t_edge = 0.0 n_vars = 0 + n_pruned = 0 n_cluster_copied = 0 + t_pulp_start = time.perf_counter() + self.pulp_variables = {} + self._valid_keys = set() decision_vars = {} strats_items = [ (self.node_map[node], node, strat) for node, strat in self.strats.items() ] - # Build DVs for root nodes only (not cluster-linked). - for node_idx, node, op_strategy in strats_items: - if node.op == "output": - continue - if node_idx in self._cluster_linked_node_idxs: - continue - - num_args = len(op_strategy.strategies[0].input_specs) - - for out_idx, output_strategy in enumerate(op_strategy.strategies): - tc0 = time.perf_counter() - compute_cost = estimate_strategy_runtime_cost(node, output_strategy) - tc1 = time.perf_counter() - t_compute += tc1 - tc0 - per_arg_compute = compute_cost / num_args - + # Phase A: compute every root node's per-edge costs. This (the comm-cost + # estimate over millions of edges) dominates build, is per-node + # independent, and mutates nothing, so it runs across forked workers. + root_idxs = [ + node_idx + for node_idx, node, _ in strats_items + if node.op != "output" and node_idx not in self._cluster_linked_node_idxs + ] + tc0 = time.perf_counter() + node_results = self._compute_node_edge_costs(root_idxs) + t_edge = time.perf_counter() - tc0 + + # Phase B: assemble decision vars (and PuLP variables) from the computed + # costs. Serial because PuLP vars and DecisionVars hold parent-owned + # strategy objects; byte-identical to computing the costs inline. This + # also writes the real costs back into each strat's redistribute_cost + # (overwriting the enumeration dummies) for the cluster batch-copy and + # _compute_solution_cost readers below. + for node_idx, out_data in node_results: + node = self.nodes[node_idx] + op_strategy = self.strats[node] + for out_idx, (per_arg_compute, arg_rows) in enumerate(out_data): + output_strategy = op_strategy.strategies[out_idx] for argi, redist_costs in enumerate(output_strategy.redistribute_cost): - for inp_idx, default_comm_cost in enumerate(redist_costs): - key = (node_idx, argi, out_idx, inp_idx) - - all_input_nodes = self._all_input_nodes(node) - producer_strategy = ( - self.strats[all_input_nodes[argi]] - if all_input_nodes - else None - ) - te0 = time.perf_counter() - comm_cost, transition_cost = self._compute_edge_costs( - node, - output_strategy, - argi, - inp_idx, - default_comm_cost, - producer_strategy, - ) - te1 = time.perf_counter() - t_edge += te1 - te0 - + input_spec = output_strategy.input_specs[argi] + for inp_idx, (comm_cost, transition_cost) in enumerate( + arg_rows[argi] + ): redist_costs[inp_idx] = comm_cost + cost = comm_cost + per_arg_compute + transition_cost + # Prune invalid (infinite-cost) edges: no variable, no + # DecisionVar. A key absent from decision_vars is treated + # as forbidden by both the ILP and the approximate solver. + if not math.isfinite(cost): + n_pruned += 1 + continue + key = (node_idx, argi, out_idx, inp_idx) + if self.build_pulp: + var = pulp.LpVariable( + f"n={node},s={node_idx},arg={argi}," + f"output_p={out_idx},input_p={inp_idx}", + cat=pulp.LpBinary, + ) + self.pulp_variables[key] = var + else: + var = None + self._valid_keys.add(key) decision_vars[key] = DecisionVar( - var=self.pulp_variables[key], - cost=comm_cost + per_arg_compute + transition_cost, + var=var, + cost=cost, compute_cost=per_arg_compute, comm_cost=comm_cost, sharding_transition_cost=transition_cost, strategy=output_strategy, output_spec=output_strategy.output_specs, - input_spec=output_strategy.input_specs[argi], + input_spec=input_spec, ) n_vars += 1 @@ -555,10 +1068,7 @@ def _build_decision_vars(self): # The root pass above updated redistribute_cost in place with # edge-computed costs; linked strats need the same values for # _compute_solution_cost and other readers. - linked_node_to_root_node: dict[int, int] = {} - for linked_key, root_key in self.cluster_links.items(): - linked_node_to_root_node[linked_key[0]] = root_key[0] - for linked_node_idx, root_node_idx in linked_node_to_root_node.items(): + for linked_node_idx, root_node_idx in self.cluster_links.items(): linked_node = self.nodes[linked_node_idx] root_node = self.nodes[root_node_idx] linked_op = self.strats[linked_node] @@ -570,34 +1080,88 @@ def _build_decision_vars(self): list(costs) for costs in root_spec.redistribute_cost ] n_cluster_copied = len(self.cluster_links) - n_vars += n_cluster_copied - self._root_to_linked: dict[tuple, list[tuple]] = defaultdict(list) - for linked_key, root_key in self.cluster_links.items(): - self._root_to_linked[root_key].append(linked_key) + # Root node idx -> [copy node idxs]. Option keys are reconstructed on + # demand (see _linked_option_keys); a copy mirrors its root's per-option + # validity, so no per-option filtering is needed here. + self._root_to_copies = defaultdict(list) + for copy_idx, root_idx in self.cluster_links.items(): + self._root_to_copies[root_idx].append(copy_idx) + t_pulp_end = time.perf_counter() logger.debug( - "_build_decision_vars breakdown (%d vars, %d cluster-copied): " - "pulp_vars=%.3fs, compute_cost=%.3fs, edge_cost=%.3fs", - n_vars, + "_build_decision_vars breakdown (%d vars, %d pruned-inf, %d cluster-copied): " + "build=%.3fs, compute_cost=%.3fs, edge_cost=%.3fs", + len(decision_vars), + n_pruned, n_cluster_copied, t_pulp_end - t_pulp_start, t_compute, t_edge, ) + self._decision_var_profile = { + "logical_decision_variables": n_vars, + "cluster_copied_decision_variables": n_cluster_copied, + "unique_pulp_variables": len(self.pulp_variables), + "pulp_var_creation_s": t_pulp_end - t_pulp_start, + "compute_cost_estimation_s": t_compute, + "edge_cost_estimation_s": t_edge, + "cost_estimation_s": t_compute + t_edge, + } + self.profile["timings"].update( + { + "pulp_var_creation_s": t_pulp_end - t_pulp_start, + "compute_cost_estimation_s": t_compute, + "edge_cost_estimation_s": t_edge, + "cost_estimation_s": t_compute + t_edge, + } + ) return decision_vars + def _compute_node_edge_costs(self, root_idxs): + """Phase A of _build_decision_vars: per-root-node edge costs. Parallel + across forked workers when enabled; workers read this optimizer from the + fork-inherited address space (no pickling of the mesh / strategy graph) + and return only primitive cost tuples. The computation is deterministic, + so the parallel result is byte-identical to the serial path.""" + global _FORK_OPT + _FORK_OPT = self + try: + # Forking a process that has already initialized CUDA crashes the + # workers ("Cannot re-initialize CUDA in forked subprocess") once they + # touch the NCCL cost model. Real-GPU runs (examples, torchrun) and + # any test that has touched CUDA hit this, so fall back to the + # (byte-identical) serial path whenever CUDA is live. + if ( + _PARALLEL_BUILD_WORKERS <= 1 + or len(root_idxs) < 64 + or torch.cuda.is_initialized() + ): + return [_par_node_edge_costs(ni) for ni in root_idxs] + import multiprocessing as mp + + ctx = mp.get_context("fork") + with ctx.Pool(_PARALLEL_BUILD_WORKERS) as pool: + # imap (ordered), not imap_unordered: results come back in + # root_idxs order so decision_vars is assembled in the same node + # order as the serial path. This keeps the PuLP objective's + # lpSum term order identical too, so even the ILP path is + # bit-for-bit unchanged (float addition is not associative). + return list(pool.imap(_par_node_edge_costs, root_idxs, chunksize=4)) + finally: + _FORK_OPT = None + def _resolve_decision_var(self, key): """Return a DecisionVar for key, reconstructing on the fly for linked keys.""" dv = self.decision_vars.get(key) if dv is not None: return dv - root_key = self.cluster_links[key] + root_key = self._cluster_root_key(key) root_dv = self.decision_vars[root_key] node_idx, argi, out_idx, _ = key strategy = self.strats[self.nodes[node_idx]].strategies[out_idx] return DecisionVar( - var=self._get_pulp_variable(key), + var=self._get_pulp_variable(key) if self.pulp_variables else None, cost=root_dv.cost, compute_cost=root_dv.compute_cost, comm_cost=root_dv.comm_cost, @@ -607,6 +1171,28 @@ def _resolve_decision_var(self, key): input_spec=strategy.input_specs[argi], ) + def _find_decision_var(self, node_idx, argi, out_idx): + """Return a DecisionVar for any surviving inp_idx of (node, arg, out), + or None if every edge for that output strategy was pruned. + + compute_cost is identical across inp_idx for a given out_idx, so callers + that only need per-strategy costs can use whichever edge survived. + """ + strategy = self.strats[self.nodes[node_idx]].strategies[out_idx] + n_inp = ( + len(strategy.redistribute_cost[argi]) if strategy.redistribute_cost else 1 + ) + for inp_idx in range(n_inp): + key = (node_idx, argi, out_idx, inp_idx) + if key in self.decision_vars: + return self._resolve_decision_var(key) + if ( + key[0] in self.cluster_links + and self._cluster_root_key(key) in self.decision_vars + ): + return self._resolve_decision_var(key) + return None + def _collect_vars(self, node, node_idx, argi, group_by, resolve_clusters=False): """Collect PuLP variables for a node's options, grouped by strategy index. @@ -619,12 +1205,14 @@ def _collect_vars(self, node, node_idx, argi, group_by, resolve_clusters=False): result = {} for _, out_idx, inp_idx in self.walk_over_options(node, argi): key = (node_idx, argi, out_idx, inp_idx) - if key in self.cluster_links: + if key[0] in self.cluster_links: if not resolve_clusters: continue - var = self.pulp_variables[self.cluster_links[key]] + var = self.pulp_variables.get(self._cluster_root_key(key)) else: - var = self.pulp_variables[key] + var = self.pulp_variables.get(key) + if var is None: # pruned (invalid/infinite-cost) strategy edge + continue group_key = out_idx if group_by == "out_idx" else inp_idx result.setdefault(group_key, []).append(var) return result @@ -635,6 +1223,11 @@ def validate(self): continue if node not in self.strats: continue + # Cluster copies are structurally identical to their root (same + # strategies and input structure, asserted in create_cluster_links), + # so validating the root covers them. + if self.node_map[node] in self.cluster_links: + continue strat = self.strats[node] strat0 = strat.strategies[0] all_input_nodes = self._all_input_nodes(node) @@ -679,7 +1272,9 @@ def add_unique_decision_constraint(self): arg_vars = {} for argi, out_idx, inp_idx in self.walk_over_options(node): key = (node_idx, argi, out_idx, inp_idx) - var = self.pulp_variables[key] + var = self.pulp_variables.get(key) + if var is None: # pruned (invalid) strategy edge + continue arg_vars.setdefault(argi, []).append(var) for eqs in arg_vars.values(): self.prob += ( @@ -703,20 +1298,24 @@ def add_same_output_across_args_constraint(self): continue if len(self._all_input_nodes(node)) <= 1: continue - vars_per_output = {} + # Group vars by (argi, out_idx). Pruning can leave an arg with no + # vars for a given out_idx, so we key explicitly by out_idx rather + # than relying on positional alignment: a missing entry means an + # empty sum (== 0), which correctly forbids that output strategy. + num_args = len(self._all_input_nodes(node)) + vars_per_output: dict[tuple[int, int], list] = {} for argi, out_idx, inp_idx in self.walk_over_options(node): key = (node_idx, argi, out_idx, inp_idx) - var = self.pulp_variables[key] + var = self.pulp_variables.get(key) + if var is None: # pruned (invalid) strategy edge + continue vars_per_output.setdefault((argi, out_idx), []).append(var) - eqs_per_arg = [[] for _ in self._all_input_nodes(node)] - for (argi, out_idx), value in vars_per_output.items(): - eqs_per_arg[argi].append(pulp.lpSum(value)) - arg0 = eqs_per_arg[0] - for arg_eqs in eqs_per_arg[1:]: - assert len(arg0) == len(arg_eqs) - for i in range(len(arg0)): + all_out_idxs = {oi for (_, oi) in vars_per_output} + for out_idx in all_out_idxs: + arg0_eq = pulp.lpSum(vars_per_output.get((0, out_idx), [])) + for argi in range(1, num_args): self.prob += ( - arg0[i] == arg_eqs[i], + arg0_eq == pulp.lpSum(vars_per_output.get((argi, out_idx), [])), self._get_next_name("same_across_args"), ) @@ -790,13 +1389,15 @@ def add_output_input_consistent_constraint(self): ) continue - assert ( - vars_producer.keys() == vars_consumer.keys() - ), f"{vars_producer}, {vars_consumer}" - - for k in vars_producer: + # Pruning can leave a producer output strategy with no matching + # consumer var (the consumer cannot accept that placement) or + # vice versa. Iterate the union and treat a missing side as an + # empty sum (== 0): this forbids the unmatched output strategy, + # exactly as the old inf-cost (== 0) variables did. + for k in vars_producer.keys() | vars_consumer.keys(): self.prob += ( - pulp.lpSum(vars_producer[k]) == pulp.lpSum(vars_consumer[k]), + pulp.lpSum(vars_producer.get(k, [])) + == pulp.lpSum(vars_consumer.get(k, [])), self._get_next_name("output_input_consistent"), ) @@ -805,6 +1406,11 @@ def add_inf_cost_constraint(self): are forced to zero. ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0 + + Freshly built optimizers prune these edges in _build_decision_vars, so + no variable exists and this is a no-op. It still runs for optimizers + loaded from save files produced before pruning was introduced, whose + decision_vars may still contain infinite-cost entries. """ for key, dv in self.decision_vars.items(): if not math.isfinite(dv.cost): @@ -877,28 +1483,142 @@ def apply_prefetch_discount(self, scale=0.0): # ---- Solution ---- def _set_objective(self): - """Add the cost minimization objective to the ILP.""" + """Add the cost minimization objective to the ILP. + + Idempotent: a no-op if the objective has already been set. This lets the + approximate solver populate ``prob.objective`` (so its assignment can be + scored with ``pulp.value(prob.objective)``) without clobbering or + double-adding it, and keeps repeated get_solution() calls safe. + """ + if self.prob.objective is not None: + return terms = [] for key, dv in self.decision_vars.items(): - multiplier = 1 + len(self._root_to_linked.get(key, [])) + multiplier = 1 + len(self._root_to_copies.get(key[0], ())) terms.append(dv.var * dv.cost * multiplier) self.prob += pulp.lpSum(terms) + def get_lower_bound(self, verbose=False): + """Solve the LP relaxation and return a lower bound on the ILP objective. + + This relaxes the existing binary PuLP variables to continuous variables + in [0, 1], solves the current problem with all constraints already added, + then restores the optimizer state. The result is a certificate only: + fractional LP values are not valid sharding placements. + """ + if self.solver_backend == "dp": + raise NotImplementedError( + "LP relaxation is only available for the PuLP-backed optimizer" + ) + + t0 = time.perf_counter() + old_objective = self.prob.objective + old_status = self.prob.status + old_sol_status = getattr(self.prob, "sol_status", None) + old_selected_keys_marker = object() + old_selected_keys = getattr(self, "selected_keys", old_selected_keys_marker) + var_states = { + var: (var.cat, var.lowBound, var.upBound, var.varValue) + for var in self.pulp_variables.values() + } + + try: + if self.prob.objective is None: + self._set_objective() + # The relaxation must include the parameter-memory constraint, or it + # is a lower bound on a different (unconstrained) problem and can fall + # below the true ILP optimum. + self._apply_memory_constraint() + + for var in self.pulp_variables.values(): + var.cat = pulp.LpContinuous + var.lowBound = 0 + var.upBound = 1 + var.varValue = None + + solver = pulp.PULP_CBC_CMD(msg=verbose) + t_solve0 = time.perf_counter() + with tempfile.TemporaryDirectory() as tmpdir: + solver.tmpDir = tmpdir + self.prob.solve(solver) + solve_s = time.perf_counter() - t_solve0 + + status = pulp.LpStatus.get(self.prob.status, self.prob.status) + objective = self._safe_float(pulp.value(self.prob.objective)) + result = LPRelaxationResult( + objective=objective, + status=status, + solve_s=solve_s, + total_s=time.perf_counter() - t0, + ) + self.profile["last_lp_relaxation"] = { + "objective": result.objective, + "status": result.status, + "solve_s": result.solve_s, + "total_s": result.total_s, + } + logger.info( + "ShardingOptimizer LP relaxation profile: " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s model_params=%s " + "unique_ilp_vars=%s constraints=%s status=%s objective=%.4f " + "timings={solve=%.3fs,total=%.3fs}", + self.profile["mesh"]["shape"], + self.profile["mesh"]["dim_names"], + self.profile["mesh"]["size"], + self._format_billions(self.profile["model"]["parameter_numel"]), + len(self.pulp_variables), + len(self.prob.constraints), + result.status, + result.objective, + result.solve_s, + result.total_s, + ) + return result + finally: + for var, (cat, low_bound, up_bound, value) in var_states.items(): + var.cat = cat + var.lowBound = low_bound + var.upBound = up_bound + var.varValue = value + self.prob.objective = old_objective + self.prob.status = old_status + if old_sol_status is None: + if hasattr(self.prob, "sol_status"): + delattr(self.prob, "sol_status") + else: + self.prob.sol_status = old_sol_status + if old_selected_keys is old_selected_keys_marker: + if hasattr(self, "selected_keys"): + delattr(self, "selected_keys") + else: + self.selected_keys = old_selected_keys + def _solve(self, verbose=False): self._apply_memory_constraint() - solver = pulp.PULP_CBC_CMD(msg=verbose) + # The sharding ILP has a near-totally-unimodular (flow-like) structure: + # CBC's LP relaxation is naturally integral, so it solves in seconds + # with zero branch-and-bound. CBC's integer *preprocessing* (probing, + # substitutions over hundreds of thousands of binary columns) is then + # pure overhead — it dominates the solve. Disabling it (correctness is + # unaffected; CBC still does full branch-and-bound if the relaxation is + # fractional) makes the solve ~10x faster on large graphs. + # Pass as a single string: PuLP prefixes each options entry with "-", + # so this becomes the CBC flag "-preprocess off". + solver = pulp.PULP_CBC_CMD(msg=verbose, options=["preprocess off"]) # Use a dedicated temp directory for PuLP's intermediate files (.mps, # .sol, etc.) so they are always cleaned up, even if the process is # killed. Without this, leftover files can fill up /tmp (tmpfs). + t0 = time.perf_counter() with tempfile.TemporaryDirectory() as tmpdir: solver.tmpDir = tmpdir self.prob.solve(solver) + solve_s = time.perf_counter() - t0 self.selected_keys = [ key for key, dv in self.decision_vars.items() if dv.var.value() == 1 ] for root_key in list(self.selected_keys): - self.selected_keys.extend(self._root_to_linked.get(root_key, [])) + self.selected_keys.extend(self._linked_option_keys(root_key)) if self.prob.status == -1: logger.warning(self.get_violated_constraints_log()) @@ -910,6 +1630,132 @@ def _solve(self, verbose=False): "constraints, and consider relaxing input/output constraints or " "using a larger mesh." ) + return solve_s + + def _log_solve_profile( + self, + solve_kind, + objective_value, + objective_s, + solve_s, + extract_s, + total_s, + ): + # Optimizers loaded from a save file skip init-time profiling; there is + # nothing to extend, and the phase timings below are absent. + profile = getattr(self, "profile", None) + if not profile or "init_total_s" not in profile.get("timings", {}): + return + mesh = self.profile["mesh"] + model = self.profile["model"] + timings = self.profile["timings"] + status = pulp.LpStatus.get(self.prob.status, self.prob.status) + pipeline_total_s = timings["init_total_s"] + total_s + logger.info( + "ShardingOptimizer %s profile: " + "mesh_shape=%s mesh_dim_names=%s mesh_size=%s model_params=%s " + "unique_ilp_vars=%s constraints=%s status=%s objective=%.4f " + "timings={strategy_enumeration=%.3fs,cost_estimation=%.3fs," + "ilp_construction=%.3fs,objective=%.3fs,solve=%.3fs," + "extract=%.3fs,total_solve_call=%.3fs,total_pipeline=%.3fs}", + solve_kind, + mesh["shape"], + mesh["dim_names"], + mesh["size"], + self._format_billions(model["parameter_numel"]), + len(self.pulp_variables), + len(self.prob.constraints), + status, + objective_value, + timings["strategy_enumeration_s"], + timings["cost_estimation_s"], + timings["ilp_construction_s"], + objective_s, + solve_s, + extract_s, + total_s, + pipeline_total_s, + ) + self.profile["last_solve"] = { + "kind": solve_kind, + "objective": objective_value, + "status": status, + "constraints": len(self.prob.constraints), + "unique_variables": len(self.pulp_variables), + "objective_s": objective_s, + "solve_s": solve_s, + "extract_s": extract_s, + "total_s": total_s, + "pipeline_total_s": pipeline_total_s, + } + logger.debug("ShardingOptimizer solve profile detail: %s", self.profile) + + def solve_lp_relaxation(self, verbose=False, frac_tol=1e-6, extract=False): + """Solve the continuous relaxation of the ILP (binary variables relaxed + to [0, 1]) and report diagnostics, restoring the binary categories on + exit so a later ILP solve is unaffected. + + Returns a dict with the relaxation objective (a lower bound on the ILP + optimum), the solve time, the number/fraction of decision variables that + came out fractional, and the solver status. This is the lens for + understanding why constraints (e.g. propagated annotations) speed up the + ILP: a relaxation that is tighter (objective closer to the ILP optimum) + and less fractional leaves branch-and-bound far less work. + + For this sharding problem the relaxation is empirically integral, so the + relaxation optimum equals the ILP optimum. With ``extract=True`` and an + integral solution, the dict also contains a ``"solution"`` key with the + per-node strategy dict (same form as :meth:`get_solution`) — i.e. the LP + relaxation can be used as a much cheaper exact solve, skipping + branch-and-bound. ``"solution"`` is ``None`` when the relaxation came + out fractional. + + Requires the objective to have been set (e.g. via a prior get_solution, + or _set_objective). + """ + variables = self.prob.variables() + original_cats = [v.cat for v in variables] + self._apply_memory_constraint() + t0 = time.perf_counter() + try: + for v in variables: + v.cat = pulp.LpContinuous # bounds are already [0, 1] for binaries + solver = pulp.PULP_CBC_CMD(msg=verbose) + with tempfile.TemporaryDirectory() as tmpdir: + solver.tmpDir = tmpdir + self.prob.solve(solver) + solve_time = time.perf_counter() - t0 + objective = pulp.value(self.prob.objective) + n_fractional = 0 + n_vars = 0 + for v in variables: + val = v.value() + if val is None: + continue + n_vars += 1 + if min(val, 1.0 - val) > frac_tol: + n_fractional += 1 + solution = None + if extract and n_fractional == 0: + self.selected_keys = [ + key + for key, dv in self.decision_vars.items() + if dv.var.value() is not None and dv.var.value() > 0.5 + ] + for root_key in list(self.selected_keys): + self.selected_keys.extend(self._linked_option_keys(root_key)) + solution = self._to_orig_solution(self._extract_and_validate_solution()) + finally: + for v, cat in zip(variables, original_cats): + v.cat = cat + return { + "objective": objective, + "solve_time": solve_time, + "n_fractional": n_fractional, + "n_vars": n_vars, + "status": pulp.LpStatus[self.prob.status], + "solution": solution, + } def _extract_and_validate_solution(self): """Validate the ILP solution and return the optimal strategy per node.""" @@ -953,14 +1799,30 @@ def _to_concrete_solution(self, solution): return {self._orig_to_concrete[node]: spec for node, spec in solution.items()} def get_solution(self, verbose=False): + if self.solver_backend == "dp": + return self.solver.get_solution(verbose=verbose) + t0 = time.perf_counter() + t_objective0 = time.perf_counter() self._set_objective() - self._solve(verbose) - obj_value = pulp.value(self.prob.objective) + t_objective1 = time.perf_counter() + solve_s = self._solve(verbose) + obj_value = self._safe_float(pulp.value(self.prob.objective)) + t_extract0 = time.perf_counter() + solution = self._to_orig_solution(self._extract_and_validate_solution()) + t_extract1 = time.perf_counter() logger.debug( "ILP solve took %.3fs (objective=%.4f)", time.perf_counter() - t0, obj_value ) - return self._to_orig_solution(self._extract_and_validate_solution()) + self._log_solve_profile( + "solve", + obj_value, + t_objective1 - t_objective0, + solve_s, + t_extract1 - t_extract0, + t_extract1 - t0, + ) + return solution def resolve(self, verbose=False): """Re-solve the ILP after adding or removing constraints. @@ -969,14 +1831,25 @@ def resolve(self, verbose=False): be called multiple times after modifying constraints. """ t0 = time.perf_counter() - self._solve(verbose) - obj_value = pulp.value(self.prob.objective) + solve_s = self._solve(verbose) + obj_value = self._safe_float(pulp.value(self.prob.objective)) + t_extract0 = time.perf_counter() + solution = self._to_orig_solution(self._extract_and_validate_solution()) + t_extract1 = time.perf_counter() logger.debug( "ILP re-solve took %.3fs (objective=%.4f)", time.perf_counter() - t0, obj_value, ) - return self._to_orig_solution(self._extract_and_validate_solution()) + self._log_solve_profile( + "re-solve", + obj_value, + 0.0, + solve_s, + t_extract1 - t_extract0, + t_extract1 - t0, + ) + return solution def remove_constraints(self, names): """Remove constraints by name, allowing re-solve to revert to the @@ -1072,8 +1945,12 @@ def _compute_solution_cost(self, solution): # Use pre-computed costs from decision vars instead of # estimate_strategy_runtime_cost, which needs node.meta["val"] - # (absent on loaded optimizers). - dv = self._resolve_decision_var((node_idx, 0, out_idx, 0)) + # (absent on loaded optimizers). The (.,0,out_idx,0) edge may be + # pruned, so find any surviving inp_idx for arg 0 (compute_cost is + # identical across inp_idx for a given out_idx). + dv = self._find_decision_var(node_idx, 0, out_idx) + if dv is None: + continue num_args = max(len(strategy.input_specs), 1) total_compute += dv.compute_cost * num_args @@ -1113,6 +1990,8 @@ def _compute_solution_cost(self, solution): # ---- Logging ---- def get_violated_constraints_log(self): + if self.prob is None: + return "Violated constraints: [] (no PuLP problem; lite build)" violated_constraints = [ (k, c) for k, c in self.prob.constraints.items() if not c.valid() ] @@ -1158,10 +2037,8 @@ def get_json(self): # Build node-level cluster mapping: linked_node -> root_node cluster_roots: dict[torch.fx.Node, torch.fx.Node] = {} - for linked_key, root_key in self.cluster_links.items(): - linked_node = self.nodes[linked_key[0]] - root_node = self.nodes[root_key[0]] - cluster_roots[linked_node] = root_node + for copy_idx, root_idx in self.cluster_links.items(): + cluster_roots[self.nodes[copy_idx]] = self.nodes[root_idx] _normalize_cluster_layer(cluster_roots) @@ -1408,6 +2285,8 @@ def _add_node_constraint( for argi, out_idx, inp_idx in self.walk_over_options(node): if out_idx in output_constraint_indices: var = self._get_pulp_variable((node_idx, argi, out_idx, inp_idx)) + if var is None: # pruned (invalid) strategy edge + continue vars_per_arg.setdefault(argi, []).append(var) names = [] for eqs in vars_per_arg.values(): @@ -1435,8 +2314,10 @@ def _add_paired_output_constraint(self, node_a, node_b, constraint_name): # This placement exists in node_a but not in node_b. # Disable it: force sum of its decision variables to 0. v_a = [ - self._get_pulp_variable((idx_a, 0, out_idx, inp_idx)) + v for inp_idx in range(num_inp_a) + if (v := self._get_pulp_variable((idx_a, 0, out_idx, inp_idx))) + is not None ] self.prob += ( pulp.lpSum(v_a) == 0, @@ -1445,12 +2326,16 @@ def _add_paired_output_constraint(self, node_a, node_b, constraint_name): continue out_idx_b = strat_b.index(sp) v_a = [ - self._get_pulp_variable((idx_a, 0, out_idx, inp_idx)) + v for inp_idx in range(num_inp_a) + if (v := self._get_pulp_variable((idx_a, 0, out_idx, inp_idx))) + is not None ] v_b = [ - self._get_pulp_variable((idx_b, 0, out_idx_b, inp_idx)) + v for inp_idx in range(num_inp_b) + if (v := self._get_pulp_variable((idx_b, 0, out_idx_b, inp_idx))) + is not None ] self.prob += ( pulp.lpSum(v_b) == pulp.lpSum(v_a), @@ -1634,6 +2519,8 @@ def _apply_memory_constraint(self): """ if self._memory_constraint is None: return + if self.prob is None: + return # approx (lite) build reads the factors from _constraint_log memory_factor_low, memory_factor_high = self._memory_constraint # Remove previous memory constraints before rebuilding @@ -1651,9 +2538,18 @@ def _apply_memory_constraint(self): continue node_idx = self.node_map[node] num_out_strat = len(self.strats[node].strategies) + # Per-axis constraints restrict which strategies this param may use, + # which raises its best achievable memory ratio (e.g. a param pinned + # to Replicate on the tensor axis can no longer be sharded there). + # The budget must reflect that, or it would under-allocate and make + # the problem spuriously infeasible. + axis_constraints = self._node_axis_constraints.get(node.name, []) ratios: list[float] = [] + allowed_ratios: list[float] = [] for out_idx in range(num_out_strat): - dv = self._resolve_decision_var((node_idx, 0, out_idx, 0)) + dv = self._find_decision_var(node_idx, 0, out_idx) + if dv is None: # every edge for this strategy was pruned + continue spec: DTensorSpec = dv.input_spec assert spec.tensor_meta is not None tensor_shape: torch.Size = spec.tensor_meta.shape @@ -1663,7 +2559,12 @@ def _apply_memory_constraint(self): ratio = new_size / old_size ratios.append(ratio) elms.append(dv.var * ratio) - best_ratio: float = min(ratios) + out_spec = self.strats[node].strategies[out_idx].output_specs + if isinstance(out_spec, DTensorSpec) and all( + out_spec.placements[m] == p for m, p in axis_constraints + ): + allowed_ratios.append(ratio) + best_ratio: float = min(allowed_ratios) if allowed_ratios else min(ratios) budget_low += max(best_ratio, memory_factor_low) budget_high += max(best_ratio, memory_factor_high) @@ -1708,6 +2609,8 @@ def add_node_constraint(self, node, placement=None, constraint_name=None): raise RuntimeError( f"Couldn't find appropriate constraint {node} {constraint_name} {placement}" ) + if self.prob is None: + return [] # approx (lite) build replays this from _constraint_log names = self._add_node_constraint( node, output_constraint_indices=output_constraint_indices, @@ -1717,6 +2620,90 @@ def add_node_constraint(self, node, placement=None, constraint_name=None): self._node_constraint_names[name] = node.name return names + def add_node_axis_constraint( + self, node, mesh_dim, placement, constraint_name=None, method="constraint" + ): + """Force a node's output placement on a single mesh axis, leaving the + other axes free for the ILP. + + This is the per-mesh-axis analogue of :meth:`add_node_constraint` and is + what sharding propagation emits: it can pin the tensor-parallel axis of a + weight while leaving the data axis open for FSDP. Unlike + :meth:`add_node_constraint` it does *not* register the node in + ``_node_constraint_names``, so a partially-constrained parameter is still + counted by the memory budget and can be sharded on its free axes. + + ``method`` controls how the pin is enforced: + + * ``"constraint"`` adds an ``== 1`` equality over the matching decision + variables (removable by name via :meth:`remove_constraints`). + * ``"fix"`` instead sets the upper bound of the *non-matching* decision + variables to 0. This shrinks the problem (the solver's presolve drops + fixed columns) rather than adding a row, which scales much better on + large meshes where adding thousands of equality rows otherwise slows + the solve. It is not removable by constraint name. + + For nodes with tuple output_specs the placement is matched against the + first DTensorSpec element, matching :meth:`add_node_constraint`. + """ + node = self._normalize_node(node) + if constraint_name is None: + constraint_name = "axis_constraint" + self._constraint_log.append( + ( + "add_node_axis_constraint", + { + "node_name": node.name, + "mesh_dim": mesh_dim, + "placement": placement, + "constraint_name": constraint_name, + "method": method, + }, + ) + ) + assert node in self.strats, (node, self.strats.keys()) + strat = self.strats[node] + output_constraint_indices = [] + for i, s in enumerate(strat.strategies): + specs = s.output_specs + spec = None + if isinstance(specs, DTensorSpec): + spec = specs + elif isinstance(specs, (list, tuple)): + spec = next((x for x in specs if isinstance(x, DTensorSpec)), None) + if spec is not None and spec.placements[mesh_dim] == placement: + output_constraint_indices.append(i) + if len(output_constraint_indices) == 0: + raise RuntimeError( + f"Couldn't find a strategy for {node} with {placement} on mesh " + f"dim {mesh_dim} (constraint {constraint_name})" + ) + self._node_axis_constraints[node.name].append((mesh_dim, placement)) + if method == "fix": + self._fix_node_output_indices(node, set(output_constraint_indices)) + return [] + if self.prob is None: + return [] # approx (lite) build replays this from _constraint_log + return self._add_node_constraint( + node, + output_constraint_indices=output_constraint_indices, + constraint_name=constraint_name, + ) + + def _fix_node_output_indices(self, node, keep_out_idxs): + """Pin a node's output strategy by fixing every decision variable whose + out_idx is not in ``keep_out_idxs`` to 0 (upper bound).""" + node_idx = self.node_map[node] + for argi, out_idx, inp_idx in self.walk_over_options(node): + if out_idx in keep_out_idxs: + continue + var = self._get_pulp_variable((node_idx, argi, out_idx, inp_idx)) + if var is None: # pruned (invalid) strategy edge, or lite (no-PuLP) build + continue + if var.upBound != 0: + var.upBound = 0 + self._fixed_vars.append(var) + def _add_io_placement_constraints( self, nodes_dict, diff --git a/autoparallel/propagation.py b/autoparallel/propagation.py new file mode 100644 index 00000000..ae1e5366 --- /dev/null +++ b/autoparallel/propagation.py @@ -0,0 +1,487 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shardy-like sharding propagation to seed and shrink the ILP search space. + +The ILP in :mod:`optimize_sharding` enumerates, for every node, every valid +combination of input/output placements and lets the solver pick the global +optimum. For large models this search space is enormous even though, in +practice, a handful of user decisions ("these weights are tensor-parallel", +"the batch is data-parallel") already pin down the strategy for the vast +majority of the graph. + +This module lets the user attach a small number of *sharding annotations* and +then propagates them through the graph the way `Shardy +`_ does: it pushes each known sharding along +edges that require no resharding, narrowing every node's set of candidate +strategies until the unambiguous nodes are fully determined. Determined nodes +are turned into ILP constraints, which collapses the search space and the solve +time while leaving the genuinely ambiguous decisions (and where to place the +necessary collectives) to the ILP. + +Key design points that mirror Shardy: + +* **Per-mesh-axis propagation.** A placement is propagated one mesh axis at a + time. This is what lets, e.g., the tensor-parallel sharding of a weight flow + through a matmul on the ``tp`` axis while the ``dp`` axis is independently + resolved (data-parallel batch, with FSDP all-gathers left to the ILP). It is + the analogue of Shardy projecting tensor shardings onto per-factor axes. +* **Conservative, reshard-free propagation.** Along an edge we only narrow a + consumer to the placements it can take *without* a reshard from the producer + (zero ``redistribute_cost``). At a genuine reshard boundary (a necessary + collective, e.g. an all-reduce or all-gather) no zero-cost option exists, so + propagation stops there and the ILP decides the collective. This never + empties a domain. +* **Priority rounds.** Annotations carry a priority (lower = applied first, + matching Shardy). Data/activation annotations propagate before weight + annotations so that, where they compete for the same mesh axis (the ``dp`` + axis of a matmul), the data-parallel sharding wins and the weight is + all-gathered rather than the activation being resharded. +""" + +import logging +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import Optional + +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement + +logger = logging.getLogger(__name__) + +# A per-axis placement value; ``None`` means "open" (unconstrained on that axis). +AxisPlacement = Optional[Placement] + + +@dataclass(frozen=True) +class ShardingAnnotation: + """A user-provided sharding hint for one tensor (graph node). + + Args: + placements: one entry per mesh dimension. Each entry is a + :class:`Placement` (e.g. ``Shard(0)``, ``Replicate()``) or ``None`` + to leave that mesh axis open for propagation / the ILP to decide. + Leaving an axis open is the common case for weights: the user pins + the tensor-parallel axis and lets FSDP on the data axis be chosen by + the optimizer. + priority: lower numbers are propagated first. Activation/IO hints + should have a smaller priority than weight hints so the + data-parallel axis wins shared-axis conflicts. + """ + + placements: tuple[AxisPlacement, ...] + priority: int = 0 + + +# Micro-strategy: a single strategy projected onto one mesh axis. +# ``in_reqs`` is the per-axis input placement required for each tensor argument +# (``None`` for non-tensor / undefined args); ``out`` is the per-axis output +# placement produced. +@dataclass(frozen=True) +class _Micro: + in_reqs: tuple[AxisPlacement, ...] + out: AxisPlacement + + +@dataclass +class PropagationResult: + """Summary of a propagation run, for logging and tests.""" + + determined: dict = field(default_factory=dict) # node -> [(mesh_dim, placement)] + strategies_before: int = 0 + strategies_after: int = 0 + nodes_touched: int = 0 + nodes_determined: int = 0 + axis_constraints: int = 0 + + @property + def reduction(self) -> float: + if self.strategies_before == 0: + return 0.0 + return 1.0 - self.strategies_after / self.strategies_before + + +class ShardingPropagator: + """Propagates sharding annotations over an optimizer's strategy graph. + + The propagator works on the optimizer's concrete graph and reuses its + per-node ``OpStrategy`` list (``optimizer.strats``) as the per-op sharding + rules. It maintains, for every single-output node and every mesh axis, the + set of still-feasible per-axis (input-requirement, output) micro-strategies + and shrinks them to a fixed point. + """ + + def __init__(self, optimizer): + self.opt = optimizer + self.mesh = optimizer.mesh + self.ndim = optimizer.mesh.ndim + + # node -> list (indexed by mesh dim) of list[_Micro] + self.micros: dict = {} + # node -> list (indexed by mesh dim) of set[int] (feasible micro indices) + self.dom: dict = {} + # nodes whose domain has been narrowed below the initial full set + self.touched: set = set() + self._initial_strategy_count: dict = {} + + self._build_micros() + + # ---- construction ---- + + def _build_micros(self): + for node, op_strat in self.opt.strats.items(): + if node.op == "output": + continue + strategies = op_strat.strategies + if not strategies: + continue + # Multi-output nodes (tuple output_specs, e.g. SDPA) are propagation + # barriers: there is no single output placement to project, so we + # neither narrow them nor propagate across them. Their getitem + # users are single-output and handled normally. + if not isinstance(strategies[0].output_specs, DTensorSpec): + continue + + args = self.opt._all_input_nodes(node) + n_args = len(args) + self._initial_strategy_count[node] = len(strategies) + + per_axis_index: list = [dict() for _ in range(self.ndim)] + per_axis_micros: list = [[] for _ in range(self.ndim)] + for s in strategies: + out_pl = s.output_specs.placements + in_pls = [] + for a in range(n_args): + isp = s.input_specs[a] if a < len(s.input_specs) else None + in_pls.append( + isp.placements if isinstance(isp, DTensorSpec) else None + ) + for m in range(self.ndim): + in_reqs = tuple(None if pl is None else pl[m] for pl in in_pls) + micro = _Micro(in_reqs=in_reqs, out=out_pl[m]) + idx = per_axis_index[m] + if micro not in idx: + idx[micro] = len(per_axis_micros[m]) + per_axis_micros[m].append(micro) + self.micros[node] = per_axis_micros + self.dom[node] = [ + set(range(len(per_axis_micros[m]))) for m in range(self.ndim) + ] + + # ---- accessors ---- + + def _out_set(self, node, m) -> set: + micros = self.micros[node][m] + return {micros[i].out for i in self.dom[node][m]} + + def _in_req_set(self, node, m, a) -> set: + micros = self.micros[node][m] + return {micros[i].in_reqs[a] for i in self.dom[node][m]} + + def _consumer_edges(self, node): + """Yield (consumer, arg_index) for each tensor edge out of ``node``.""" + for user in node.users: + if user not in self.dom: + continue + in_nodes = self.opt._all_input_nodes(user) + for a, src in enumerate(in_nodes): + if src is node: + yield user, a + + # ---- seeding ---- + + def seed(self, node, placements: tuple) -> bool: + node = self.opt._normalize_node(node) + if node not in self.dom: + logger.debug("seed: %s is not a single-output node, ignoring", node) + return False + changed = False + for m in range(self.ndim): + want = placements[m] if m < len(placements) else None + if want is None: + continue + micros = self.micros[node][m] + # Seeding is authoritative: recompute from the full strategy set so a + # user annotation overrides any earlier (lower-priority) propagation + # that may have narrowed this axis away from the annotated value. + keep = {i for i in range(len(micros)) if micros[i].out == want} + if not keep: + available = {micros[i].out for i in range(len(micros))} + raise ValueError( + f"Annotation {placements} is not achievable for node " + f"{node} on mesh dim {m}: this op only supports " + f"{available} on that axis" + ) + if keep != self.dom[node][m]: + self.dom[node][m] = keep + changed = True + if changed: + self.touched.add(node) + return changed + + # ---- narrowing ---- + + def _narrow_from_producers(self, node) -> bool: + """Narrow ``node`` (as a consumer) toward reshard-free inputs.""" + changed = False + args = self.opt._all_input_nodes(node) + for a, producer in enumerate(args): + if producer not in self.dom: + continue # barrier or non-tensor producer + for m in range(self.ndim): + prod_outs = self._out_set(producer, m) + cur = self.dom[node][m] + micros = self.micros[node][m] + keep = {i for i in cur if micros[i].in_reqs[a] in prod_outs} + # Only tighten when a zero-reshard option exists; an empty keep + # means this edge is a genuine reshard boundary -> leave it to + # the ILP. + if keep and keep != cur: + self.dom[node][m] = keep + changed = True + return changed + + def _narrow_from_consumer(self, node) -> bool: + """Narrow ``node`` (as a producer) toward what its single consumer wants. + + Restricted to single-consumer producers: a multi-consumer value (e.g. a + residual stream) may legitimately be resharded for some consumers, so we + do not let one consumer dictate it. + + Placeholders (parameters, buffers, graph inputs) are never narrowed this + way: their placement is the *stored* sharding, which legitimately differs + from the *compute* sharding the consumer needs by a reshard (e.g. an FSDP + all-gather on the data axis). Inferring the storage sharding from the + consumer would wrongly pin, e.g., a weight to Replicate on the data axis + and defeat FSDP. A placeholder's sharding comes only from its own + annotation; everything else about it is left to the ILP. + """ + if node.op in ("placeholder", "get_attr"): + return False + edges = list(self._consumer_edges(node)) + if len(edges) != 1: + return False + consumer, a = edges[0] + changed = False + for m in range(self.ndim): + cons_reqs = self._in_req_set(consumer, m, a) + cur = self.dom[node][m] + micros = self.micros[node][m] + keep = {i for i in cur if micros[i].out in cons_reqs} + if keep and keep != cur: + self.dom[node][m] = keep + changed = True + return changed + + def _narrow_node(self, node) -> bool: + c1 = self._narrow_from_producers(node) + c2 = self._narrow_from_consumer(node) + changed = c1 or c2 + if changed: + self.touched.add(node) + return changed + + def propagate(self): + """Run the worklist narrowing to a fixed point.""" + wl = deque(self.dom.keys()) + inq = set(self.dom.keys()) + steps = 0 + while wl: + node = wl.popleft() + inq.discard(node) + steps += 1 + if not self._narrow_node(node): + continue + # Re-enqueue neighbors whose domains may now narrow further. + neighbors = list(self.opt._all_input_nodes(node)) + neighbors += [u for u in node.users] + for nb in neighbors: + if nb in self.dom and nb not in inq: + wl.append(nb) + inq.add(nb) + logger.debug("propagation fixpoint reached in %d worklist steps", steps) + + # ---- results ---- + + def determined(self) -> dict: + """node -> list[(mesh_dim, placement)] for every determined axis of a + node that propagation actually touched.""" + res = {} + for node in self.dom: + if node not in self.touched: + continue + axes = [] + for m in range(self.ndim): + outs = self._out_set(node, m) + if len(outs) == 1: + axes.append((m, next(iter(outs)))) + if axes: + res[node] = axes + return res + + def _feasible_strategy_count(self, node, determined_axes) -> int: + """How many of ``node``'s strategies satisfy all determined axes.""" + strategies = self.opt.strats[node].strategies + count = 0 + for s in strategies: + spec = s.output_specs + if not isinstance(spec, DTensorSpec): + count += 1 + continue + if all(spec.placements[m] == p for m, p in determined_axes): + count += 1 + return count + + def run(self, annotations) -> dict: + """Seed ``annotations`` in priority order and propagate to a fixed point. + + ``annotations`` is a list of ``(node, ShardingAnnotation)``. Returns the + ``determined()`` mapping. + """ + by_priority: dict = defaultdict(list) + for node, ann in annotations: + by_priority[ann.priority].append((node, ann)) + for priority in sorted(by_priority): + for node, ann in by_priority[priority]: + self.seed(node, ann.placements) + self.propagate() + return self.determined() + + def _paired_boundary_nodes(self) -> set: + """Backward nodes tied to a forward node by a forward/backward + consistency constraint: parameter gradients, input gradients, and output + tangents. These must be left to the pairing (which mirrors the forward + decision onto them); constraining them independently can contradict it. + """ + from torch._functorch._aot_autograd.fx_utils import ( + get_param_and_grad_nodes, + get_plain_input_and_grad_nodes, + get_plain_output_and_tangent_nodes, + ) + + graph = self.opt.graph + nodes = set() + for _p, grad in get_param_and_grad_nodes(graph).values(): + if grad is not None: + nodes.add(grad) + for _i, grad in get_plain_input_and_grad_nodes(graph).values(): + if grad is not None: + nodes.add(grad) + for _o, tangent in get_plain_output_and_tangent_nodes(graph).values(): + if tangent is not None: + nodes.add(tangent) + return nodes + + def _backward_node_set(self) -> set: + """Nodes belonging to the backward pass: everything reachable from a + tangent (incoming-gradient) placeholder. + + Propagation does not constrain these. Their sharding is tied to the + forward pass by the optimizer's forward/backward consistency constraints + (param<->grad, input<->grad, output<->tangent), so constraining them + independently risks contradicting that pairing (e.g. forcing a weight's + gradient to a placement its parameter cannot take). Leaving them to the + ILP keeps the problem feasible while the forward constraints already + collapse most of the backward search space through the pairing. + """ + seeds = [ + n + for n in self.opt.graph.nodes + if n.op == "placeholder" and n.name.startswith("tangents") + ] + backward = set() + stack = list(seeds) + while stack: + n = stack.pop() + for u in n.users: + if u not in backward: + backward.add(u) + stack.append(u) + return backward + + def _total_strategy_count(self) -> int: + total = 0 + for node, op_strat in self.opt.strats.items(): + if node.op == "output": + continue + total += len(op_strat.strategies) + return total + + def apply_to_optimizer( + self, forward_only=False, aggressive=False, method="fix" + ) -> PropagationResult: + """Emit per-axis constraints for every determined axis of every touched + node and return a summary of the search-space reduction. + + Nodes the user already constrained explicitly are skipped, as are the + forward/backward *paired boundary* nodes (parameter/input gradients and + output tangents), whose sharding is decided by the pairing rather than + propagation. When ``forward_only`` is set, all backward-pass nodes are + skipped (more conservative; only the forward graph is constrained). A + node is also skipped if its determined axes do not co-occur in any single + strategy (a safety net, not expected in practice). + + By default (``aggressive=False``) an axis is only pinned when it is a + genuine ``Shard``. A Shard encodes the tensor-parallel structure the + annotations describe and is invariant in the optimum. ``Replicate`` and + ``Partial`` are deliberately *not* pinned: + + * Pinning ``Replicate`` would forbid the ILP from instead sharding that + axis (e.g. choosing sequence parallelism on the residual stream). + * ``Partial`` is a pending reduction whose collective (all-reduce / + reduce-scatter) the ILP places; pinning it fixes where the reduction + happens and can even be infeasible (a Partial value cannot be added to + a Replicate residual without first reducing it). + + Both are genuine cost tradeoffs, so leaving them open keeps the optimum + reachable while costing little search-space reduction. + + ``method`` is forwarded to :meth:`ShardingOptimizer.add_node_axis_constraint`: + ``"fix"`` (default) removes the ruled-out decision variables so the + problem actually shrinks, ``"constraint"`` adds equality rows instead. + """ + determined = self.determined() + already = set(self.opt._node_constraint_names.values()) + excluded = self._paired_boundary_nodes() + if forward_only: + excluded |= self._backward_node_set() + + result = PropagationResult(determined=determined) + result.strategies_before = self._total_strategy_count() + result.nodes_touched = len(self.touched) + + strategies_saved = 0 + for node, axes in determined.items(): + if node.name in already or node in excluded: + continue + pin_axes = [(m, p) for m, p in axes if aggressive or p.is_shard()] + if not pin_axes: + continue + full = len(self.opt.strats[node].strategies) + feasible = self._feasible_strategy_count(node, pin_axes) + if feasible == 0 or feasible == full: + continue + for m, p in pin_axes: + self.opt.add_node_axis_constraint( + node, m, p, constraint_name="propagated", method=method + ) + result.axis_constraints += 1 + result.nodes_determined += 1 + strategies_saved += full - feasible + + result.strategies_after = result.strategies_before - strategies_saved + logger.info( + "propagation: touched %d nodes, constrained %d nodes with %d " + "per-axis constraints; output-strategy choices %d -> %d (%.1f%% " + "reduction)", + result.nodes_touched, + result.nodes_determined, + result.axis_constraints, + result.strategies_before, + result.strategies_after, + 100.0 * result.reduction, + ) + return result diff --git a/autoparallel/serialization.py b/autoparallel/serialization.py index 46cb3fde..c9474746 100644 --- a/autoparallel/serialization.py +++ b/autoparallel/serialization.py @@ -135,7 +135,7 @@ def save_optimizer(opt, path): # Re-key strats by node name, saving only root nodes (non-linked). # Linked nodes share identical strats with their root and are # reconstructed on load from cluster_links. - linked_node_names = {opt.nodes[lk[0]].name for lk in opt.cluster_links} + linked_node_names = {opt.nodes[c].name for c in opt.cluster_links} strats_by_name = { node.name: strat for node, strat in opt.strats.items() @@ -193,8 +193,7 @@ def save_optimizer(opt, path): "dv_costs_keys": dv_costs_keys, "dv_costs_vals": dv_costs_vals, "cluster_links_node_by_name": { - opt.nodes[lk[0]].name: opt.nodes[rk[0]].name - for lk, rk in opt.cluster_links.items() + opt.nodes[c].name: opt.nodes[r].name for c, r in opt.cluster_links.items() }, "constraint_log": opt._constraint_log, "selected_keys_by_name": selected_keys_by_name, @@ -257,35 +256,45 @@ def load_optimizer(cls, path): opt.strats = strats opt.nodes = list(strats.keys()) opt.node_map = {node: i for i, node in enumerate(opt.nodes)} + opt.solver_backend = "ilp" opt.force_grad_reduce_in_higher_precision = save_dict[ "force_grad_reduce_in_higher_precision" ] opt._constraint_log = [] opt._memory_constraint = None opt._node_constraint_names = {} + opt._node_axis_constraints = defaultdict(list) + opt._fixed_vars = [] opt._name_counters = {} - - # Reconstruct cluster_links by expanding the node-level mapping over - # all (argi, out_idx, inp_idx) combinations. - opt.cluster_links = {} - for linked_name, root_name in cluster_links_node_by_name.items(): - linked_node = nodes_by_name[linked_name] - root_node = nodes_by_name[root_name] - linked_idx = opt.node_map[linked_node] - root_idx = opt.node_map[root_node] - for argi, out_idx, inp_idx in opt.walk_over_options(linked_node): - opt.cluster_links[(linked_idx, argi, out_idx, inp_idx)] = ( - root_idx, - argi, - out_idx, - inp_idx, - ) - opt._cluster_linked_node_idxs = {key[0] for key in opt.cluster_links} + # Loaded optimizers rebuild the PuLP problem below but carry no init-time + # profiling; an empty profile lets solve-time profile writes/guards no-op. + opt.build_pulp = True + opt.profile = {"timings": {}} + + # cluster_links is node-level: copy node idx -> root node idx. + opt.cluster_links = { + opt.node_map[nodes_by_name[linked_name]]: opt.node_map[nodes_by_name[root_name]] + for linked_name, root_name in cluster_links_node_by_name.items() + } + opt._cluster_linked_node_idxs = set(opt.cluster_links) # Mesh placeholder — provides shape/dim_names for get_json() and ndim # for add_node_constraint() default placement, without needing a PG opt.mesh = _MeshPlaceholder(save_dict["mesh_shape"], save_dict["mesh_dim_names"]) + # Map saved decision-var keys to loaded node indices. Only these keys had + # a finite-cost (valid) strategy edge at save time; invalid edges were + # pruned and must not get a variable, so seed _valid_keys before creating + # the PuLP variables (see ShardingOptimizer._build_decision_vars). + save_node_names = save_dict["dv_costs_node_names"] + keys_t = save_dict["dv_costs_keys"].tolist() + vals_t = save_dict["dv_costs_vals"].tolist() + mapped_keys = [ + (opt.node_map[nodes_by_name[save_node_names[k[0]]]], k[1], k[2], k[3]) + for k in keys_t + ] + opt._valid_keys = set(mapped_keys) + # Rebuild PuLP variables and decision vars from saved costs. t2 = time.perf_counter() opt.pulp_variables = opt._create_pulp_variables() @@ -296,19 +305,14 @@ def load_optimizer(cls, path): len(opt.pulp_variables), ) # Reconstruct decision_vars from compact tensors. - save_node_names = save_dict["dv_costs_node_names"] - keys_t = save_dict["dv_costs_keys"].tolist() - vals_t = save_dict["dv_costs_vals"].tolist() opt.decision_vars = {} - for (save_node_idx, argi, out_idx, inp_idx), ( + for key, ( compute_cost, comm_cost, transition_cost, - ) in zip(keys_t, vals_t): - node_name = save_node_names[save_node_idx] - node = nodes_by_name[node_name] - node_idx = opt.node_map[node] - key = (node_idx, argi, out_idx, inp_idx) + ) in zip(mapped_keys, vals_t): + node_idx, argi, out_idx, inp_idx = key + node = opt.nodes[node_idx] strategy = opt.strats[node].strategies[out_idx] opt.decision_vars[key] = DecisionVar( var=opt.pulp_variables[key], @@ -329,9 +333,9 @@ def load_optimizer(cls, path): len(opt.decision_vars), ) - opt._root_to_linked = defaultdict(list) - for linked_key, root_key in opt.cluster_links.items(): - opt._root_to_linked[root_key].append(linked_key) + opt._root_to_copies = defaultdict(list) + for copy_idx, root_idx in opt.cluster_links.items(): + opt._root_to_copies[root_idx].append(copy_idx) opt.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize) opt.add_default_constraints() @@ -384,7 +388,7 @@ def _restore_solution(opt, selected_keys_by_name, nodes_by_name): # Expand cluster links for root_key in list(opt.selected_keys): - opt.selected_keys.extend(opt._root_to_linked.get(root_key, [])) + opt.selected_keys.extend(opt._linked_option_keys(root_key)) def save_placements(opt, path): diff --git a/docs/README.md b/docs/README.md index 9299286f..4aa2dc2d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -23,5 +23,6 @@ If you're new to the project, use the reading order below. ## Advanced usage +- [Sharding Annotations and Shardy-like Propagation](sharding_annotations.md) - [Using `local_map` for MoE and Custom Communication Patterns](local_map_and_moe.md) - [Saving and Loading Optimizer State](save_load.md) diff --git a/docs/codebase_pipeline.md b/docs/codebase_pipeline.md new file mode 100644 index 00000000..533c4c09 --- /dev/null +++ b/docs/codebase_pipeline.md @@ -0,0 +1,593 @@ +# AutoParallel Codebase Pipeline + +This document is a code-oriented guide for new contributors. It explains the +main pipeline, the important modules, and how data moves from a user model to a +parallelized module. + +AutoParallel is experimental and tightly coupled to PyTorch internals such as +FX, Dynamo export, AOTAutograd, DTensor, and Inductor. The best mental model is: + +```text +user model + -> fake/global tracing + -> joint forward/backward FX graph + -> per-node sharding strategy enumeration + -> ILP optimization + -> graph lowering with redistributions + -> parallel nn.Module with sharded params/buffers + -> optional torch.compile backend passes +``` + +## Public Entry Points + +The public API is exported from `autoparallel/__init__.py`: + +- `auto_parallel(...)`: simple wrapper for common usage. +- `AutoParallel(...)`: context-manager API for debugging and custom constraints. +- `autoparallel_backend(...)`: `torch.compile` backend wrapper for activation + checkpointing and communication/compute overlap passes. +- `with_sharding_constraint(...)`: model-level constraint helper. + +The main implementation lives in `autoparallel/api.py`. + +## End-to-End Pipeline + +### 1. User Defines Model, Mesh, and Example Inputs + +Users provide: + +- an `nn.Module`, often built on the `meta` device, +- a PyTorch `DeviceMesh`, +- example inputs, +- output placement constraints, +- optionally mixed precision and parameter memory constraints. + +The simple API accepts real tensors or DTensors as `sample_inputs`. DTensor +inputs are important because their placements become input constraints. Regular +tensors are treated as replicated on every mesh dimension. + +Relevant files: + +- `autoparallel/api.py` +- `autoparallel/input_validation.py` +- `docs/api_walkthrough.md` +- `examples/example_autoparallel.py` +- `examples/example_hf.py` + +### 2. Input Metadata Is Normalized + +In `auto_parallel(...)`, sample inputs are converted into metadata: + +- global shapes, +- dtypes, +- devices, +- input placement tuples, +- pytree structure. + +This is handled by `_extract_input_info(...)` and `_make_input_fn(...)` in +`autoparallel/input_validation.py`. + +The generated `input_fn()` creates fresh tensors with the same global metadata. +It is called later inside `FakeTensorMode`, so the tensors become fake tensors +instead of real allocations. + +### 3. AutoParallel Context Setup + +`AutoParallel.__init__` prepares the optimization environment: + +- deep-copies the user model so tracing and dtype wrappers do not mutate it, +- canonicalizes and applies mixed precision wrappers if requested, +- moves meta parameters and buffers into fake tensors on the mesh device, +- stores the mesh, cost model, and dynamic-shape setting, +- optionally creates a `ShapeEnv` for symbolic shapes. + +`AutoParallel.__enter__` then: + +- configures the NCCL topology cost model, +- enters the `DeviceMesh` context, +- traces the model into a joint graph, +- disables Inductor comprehensive padding while AutoParallel is active, +- constructs a `ShardingOptimizer`. + +Relevant files: + +- `autoparallel/api.py` +- `autoparallel/tracing.py` +- `autoparallel/cast_parametrization.py` +- `autoparallel/cost_models/nccl_cost_model.py` +- `autoparallel/cost_models/collective_runtime_estimation.py` + +### 4. Model Is Traced Into a Joint FX Graph + +Tracing happens in `build_joint_graph(...)` in `autoparallel/api.py`. + +The flow is: + +1. Call `input_fn()` under `FakeTensorMode`. +2. Optionally convert fake inputs to symbolic dynamic inputs. +3. Capture a forward graph with Dynamo export. +4. Restore model state after capture. +5. Add unused params and buffers so they still appear in the parameter specs. +6. Use AOTAutograd to export a joint forward/backward graph. +7. Clean up and normalize the graph. +8. Optionally replace `view -> mm -> view` patterns with `einsum`. +9. Add alias nodes to expose more optimization opportunities. + +The resulting graph is a single FX graph containing forward computation, +backward computation, parameter nodes, gradients, tangents, and outputs. +AutoParallel optimizes this joint graph rather than optimizing only the forward +path. + +Relevant files: + +- `autoparallel/api.py` +- `autoparallel/tracing.py` +- `autoparallel/graph_passes/graph_utils.py` +- `autoparallel/graph_passes/extract_forward.py` + +## Sharding Strategy Generation + +### 5. The Optimizer Builds Placement Options + +`ShardingOptimizer` is implemented in `autoparallel/optimize_sharding.py`. + +It first creates a concrete copy of the graph with symbolic dimensions replaced +by their hinted concrete values. The optimizer uses this concrete graph for +strategy enumeration, cost estimation, graph clustering, and ILP construction. +The original graph is kept for `apply_sharding`, which may still need symbolic +shape metadata. + +For each tensor-producing node, `build_sharding_metadata()` creates an +`OpStrategy`. An `OpStrategy` is a list of possible `OpSpec` choices. Each +`OpSpec` describes: + +- expected input DTensor specs, +- produced output DTensor specs, +- redistribution costs from predecessor placements. + +Placeholders and parameters start with all valid placements generated by +`_create_all_options(...)`. Call-function nodes get strategies from +`get_placement_options_for_node(...)`. + +Relevant files: + +- `autoparallel/optimize_sharding.py` +- `autoparallel/shardings/placement_options.py` +- `autoparallel/shardings/propagation_rules.py` + +### 6. Placement Rules Come From DTensor Plus AutoParallel Overrides + +`autoparallel/shardings/placement_options.py` dispatches strategy generation. + +For normal ops: + +- if AutoParallel has a custom rule in `_op_rules`, it uses that, +- otherwise it asks PyTorch DTensor for an op strategy through helper wrappers. + +AutoParallel adds custom rules in `autoparallel/shardings/propagation_rules.py`. +These rules cover cases where the default DTensor propagation is missing, +too strict, or not shaped for AutoParallel's optimizer. + +Important examples: + +- view and reshape-like ops, +- `operator.getitem`, +- pointwise behavior, +- tensor factory ops, +- matmul/einsum behavior, +- local-map and MoE-related higher-order ops, +- flex attention higher-order ops. + +After strategies are generated, AutoParallel: + +- propagates tensor metadata, +- fills missing redistribution costs, +- removes invalid shardings where tensor dimensions are too small for the mesh, +- deduplicates equivalent configurations, +- caches repeated placement-option lookups. + +## Cost Model + +### 7. Compute Cost + +Compute cost is estimated in `autoparallel/cost_models/compute_estimation.py`. + +The broad idea is: + +- count FLOPs when possible, +- estimate memory read/write time, +- estimate compute time from device throughput, +- use the max of memory time and compute time, +- apply a small launch floor for tiny kernels, +- treat pure view-like shape operations as cheap or free. + +The module contains hardware limit tables for several GPU families and a flop +counter extension for `einsum`. + +### 8. Communication Cost + +Communication cost is estimated in +`autoparallel/cost_models/collective_runtime_estimation.py`. + +The key transition types are: + +- `Shard -> Replicate`: all-gather, +- `Partial -> Replicate`: all-reduce, +- `Partial -> Shard`: reduce-scatter, +- `Shard(dim_a) -> Shard(dim_b)`: all-to-all, +- `Replicate -> Shard`: local narrowing, usually no collective. + +By default, `AutoParallel.__enter__` detects an NCCL topology config and the +cost model dispatches to `autoparallel/cost_models/nccl_cost_model.py`. This is +important because intra-node and inter-node collectives have very different +costs. + +Redistribution cost also includes penalties for non-contiguous layouts and +non-dim-0 shard reshuffling, because those cases need extra memory movement. + +### 9. Transition Cost + +The optimizer also adds a small sharding-transition penalty when a producer and +consumer use different placements. This is a tie-breaker that encourages +placement stability when communication and compute costs are otherwise similar. + +## ILP Optimization + +### 10. Decision Variables + +The ILP is built in `ShardingOptimizer`. + +A decision variable represents: + +```text +(node, argument index, output strategy index, producer input strategy index) +``` + +Each variable has: + +- total cost, +- compute cost, +- communication cost, +- transition cost, +- selected `OpSpec`, +- input and output DTensor specs. + +For repeated subgraphs, graph clustering can link equivalent decision variables +so the ILP is smaller. + +Relevant files: + +- `autoparallel/optimize_sharding.py` +- `autoparallel/graph_passes/graph_clustering.py` + +### 11. Default Constraints + +The optimizer adds these constraints before solving: + +- uniqueness: each node argument selects exactly one choice, +- same-output consistency: all tensor arguments of a multi-input op agree on + one output strategy, +- flow consistency: producer output placement matches consumer input placement, +- invalid-cost constraints: impossible configurations cannot be selected, +- forward/backward consistency constraints, +- gradient-reduce dtype constraints. + +User-facing constraints are layered on top: + +- `add_input_constraints(...)`, +- `add_output_constraints(...)`, +- `add_parameter_memory_constraint(...)`, +- node constraints through optimizer helpers, +- model-embedded `with_sharding_constraint(...)`. + +### 12. Solving + +`get_solution(...)` sets the objective and solves the ILP with PuLP's CBC +solver. The objective minimizes total estimated runtime cost across the joint +graph: + +```text +compute cost + communication cost + transition cost +``` + +The result is a mapping: + +```text +FX node -> chosen OpSpec +``` + +Public debugging helpers include: + +- `get_log(...)`, +- `print_costs_for_node(...)`, +- `explain_placement(...)`, +- `diff_solutions(...)`, +- `save(...)` and `load(...)`, +- `save_placements(...)` and `load_placements(...)`, +- `get_json(...)`. + +## Applying the Solution + +### 13. Lowering the Graph to Local Execution + +`apply_placement(...)` calls `apply_sharding_to_model(...)` in +`autoparallel/apply_sharding.py`. + +The important class is `ApplyShardingInterpreter`, an FX interpreter that walks +the original joint graph and inserts the behavior implied by the chosen +placements. + +For each operation, it: + +- looks up the producer specs and target input specs, +- redistributes local tensors when placements differ, +- handles `operator.getitem` specially for tuple outputs, +- localizes shape arguments for tensor factories and view ops, +- wraps view inputs in DTensor in static mode when DTensor should perform + global-to-local shape conversion, +- executes the original op, +- converts DTensor outputs back to local tensors. + +The output is a parallel FX graph that operates on local tensors and explicit +collective/redistribution behavior. + +Relevant files: + +- `autoparallel/apply_sharding.py` +- `autoparallel/shardings/ordered_sharding.py` + +### 14. Parameters and Buffers Are Sharded + +`_shard_params_and_buffers(...)` builds DTensor parameters and buffers from the +solved placements. It uses the original graph's named parameter and buffer +descriptors to map FQNs to FX nodes. + +The returned dictionaries are: + +```text +fqn -> sharded Parameter +fqn -> sharded buffer DTensor +``` + +`make_parallel_module(...)` then constructs the final module. + +Relevant files: + +- `autoparallel/apply_sharding.py` +- `autoparallel/module_construction.py` + +### 15. Parallel Module Construction + +`autoparallel/module_construction.py` creates a new module class that mirrors +the user's original model class. + +It preserves: + +- user-defined instance attributes, +- nested module structure, +- `ModuleDict`-like containers when possible, +- parameter aliases, +- buffer aliases, +- module aliases, +- orphan submodules needed by initialization code. + +It also replaces the module's `forward` with the AutoParallel-generated +function and wraps `init_weights` if the model has one. + +### 16. Runtime Forward + +The generated `forward` in `AutoParallel.apply_placement(...)`: + +1. Flattens user inputs. +2. Validates local runtime shapes and dtypes against traced expectations. +3. Reads DTensor parameters and buffers from the module. +4. Converts parameters and buffers to local tensors. +5. Boxes params, buffers, and runtime inputs into the AOTAutograd-compiled + function. +6. Uses the joint forward/backward function when gradients are enabled. +7. Uses a forward-only extracted graph under `torch.no_grad()`. + +The returned parallel module expects local per-rank tensors at runtime, not +global tensors. + +### 17. Initialization and Loading + +A common workflow is: + +```python +with torch.device("meta"): + model = MyModel(...) + +parallel_model = auto_parallel(...) +parallel_model.to_empty(device="cuda") +parallel_model.init_weights() +``` + +`autoparallel/init_weights.py` makes typical single-GPU initialization code +work with sharded DTensor parameters. It intercepts parameter and buffer +assignments during `init_weights` and copies the assigned full tensor into the +existing DTensor placement. + +Save/load support lives in: + +- `autoparallel/serialization.py` +- `docs/save_load.md` + +## Optional Compilation Pipeline + +The eager parallel module can be passed to: + +```python +torch.compile(parallel_model, backend=autoparallel_backend()) +``` + +`autoparallel/compile.py` wraps Inductor and can enable: + +- activation checkpointing joint pass, +- collective bucketing, +- overlap scheduling, +- insertion of overlap dependencies, +- prefetch limits. + +Activation checkpointing logic is in: + +- `autoparallel/graph_passes/activation_checkpointing.py` + +Other graph and scheduling passes live under: + +- `autoparallel/graph_passes/` +- `autoparallel/graph_passes/async_tp/` +- `autoparallel/graph_passes/autobucketing_inductor/` + +## Important Supporting Areas + +### Custom Ops and Constraints + +`autoparallel/collectives.py` exposes sharding constraints and related +collective helpers. Model authors can use `with_sharding_constraint(...)` inside +model code to force an intermediate placement. + +`autoparallel/ops.py` contains registered AutoParallel-specific operations. + +### Local Map and MoE + +AutoParallel has special handling for `local_map` and MoE-style communication. +Placement options for local-map higher-order ops are generated in +`placement_options.py`, while user-facing examples and explanations are in: + +- `docs/hc_and_moe.md` +- `examples/example_local_map.py` +- `examples/example_dcp.py` +- `examples/native_ds3/` + +### Dynamic Shapes + +When `dynamic=True`, `AutoParallel` traces with symbolic dimensions. The +optimizer still works on a concretized graph, but `apply_sharding` preserves the +original symbolic graph and recreates local fake inputs with fresh symbols for +lowering. Runtime input validation allows dimensions marked dynamic to vary. + +Relevant files: + +- `autoparallel/api.py` +- `autoparallel/optimize_sharding.py` +- `autoparallel/apply_sharding.py` +- `autoparallel/input_validation.py` +- `tests/test_dynamic_shapes.py` + +### JSON and Visualization + +The optimizer can export strategy decisions to JSON with `get_json()`. + +Relevant files: + +- `autoparallel/export_json.py` +- `autoparallel/visualizer/build_display_from_json.py` +- `tests/test_export_json.py` + +## Directory Map + +```text +autoparallel/ + api.py public APIs and orchestration + tracing.py fake tensor conversion and decomposition setup + input_validation.py sample input metadata and runtime checks + optimize_sharding.py ILP optimizer and debugging helpers + apply_sharding.py graph lowering and sharded param creation + module_construction.py final parallel module construction + init_weights.py DTensor-aware init_weights wrapper + compile.py torch.compile backend wrapper + collectives.py sharding constraints and collective helpers + ops.py custom operator registrations + serialization.py optimizer and placement save/load + export_json.py visualization/export format + +autoparallel/shardings/ + placement_options.py per-node strategy generation + propagation_rules.py custom DTensor propagation rules + dtensor_sharding_helpers.py wrappers around DTensor strategy APIs + ordered_sharding.py optimized redistribution ordering + +autoparallel/cost_models/ + compute_estimation.py operation runtime estimates + collective_runtime_estimation.py redistribution cost estimates + nccl_cost_model.py NCCL topology-aware cost model + +autoparallel/graph_passes/ + graph_utils.py graph cleanup and helper analysis + graph_clustering.py repeated-subgraph detection + activation_checkpointing.py recomputation/AC tagging and pass + extract_forward.py forward-only graph extraction + auto_bucketing.py bucketing helpers + async_tp/ async tensor-parallel passes + autobucketing_inductor/ Inductor-oriented bucketing passes + +docs/ user and contributor documentation +examples/ runnable examples +tests/ behavior and regression tests +``` + +## How to Read the Code + +For a first pass, read in this order: + +1. `docs/basic_concepts.md` +2. `docs/api_walkthrough.md` +3. `autoparallel/api.py` +4. `autoparallel/optimize_sharding.py` +5. `autoparallel/shardings/placement_options.py` +6. `autoparallel/shardings/propagation_rules.py` +7. `autoparallel/apply_sharding.py` +8. `autoparallel/module_construction.py` +9. `autoparallel/compile.py` + +Then use tests to understand edge cases: + +- `tests/test_api.py` +- `tests/test_auto_parallel_simple.py` +- `tests/test_optimize_placement.py` +- `tests/test_propagation_rules.py` +- `tests/test_apply_sharding.py` +- `tests/test_dynamic_shapes.py` +- `tests/test_flex_attention.py` +- `tests/test_inference_path.py` + +## Debugging Workflow + +When investigating a model or optimizer decision: + +1. Start with the full `AutoParallel` API instead of `auto_parallel(...)`. +2. Add explicit input and output constraints. +3. Add a parameter memory constraint if you expect FSDP-like sharding. +4. Call `optimize_placement(verbose=True)`. +5. Read the optimizer log for chosen placements and cost breakdowns. +6. Use `print_costs_for_node(...)` for a suspicious node. +7. Use `explain_placement(...)` to compare a target placement with the chosen + placement. +8. Temporarily add a node constraint and compare with `diff_solutions(...)`. +9. Inspect the parallel graph emitted by structured logs or `parallel_gm`. + +Common symptoms: + +- Replicated parameters: missing or loose parameter memory constraint. +- Infeasible ILP: contradictory input/output/node constraints or shard dim too + small for the mesh. +- Unexpected all-gather/all-reduce: producer and consumer placements disagree. +- Shape mismatch at runtime: passing global tensors to a module that expects + local tensors. +- Dynamic-shape compile failure: check whether symbolic dims were concretized + too early or local shape args were not localized. + +## Contributor Notes + +- Prefer existing DTensor strategy APIs before adding custom propagation rules. +- Add custom rules only when the default rule is missing or does not preserve + the metadata AutoParallel needs. +- Keep optimizer constraints explicit; hidden state makes debugging ILP failures + difficult. +- Add focused tests when touching strategy enumeration, cost modeling, + constraints, or graph lowering. +- Be careful with aliases: parameters, buffers, and modules can share identity, + and the code intentionally preserves those relationships. +- The traced graph uses global shapes; the returned module executes on local + tensors. Many bugs come from mixing those two worlds. diff --git a/docs/sharding_annotations.md b/docs/sharding_annotations.md new file mode 100644 index 00000000..b9248cb9 --- /dev/null +++ b/docs/sharding_annotations.md @@ -0,0 +1,183 @@ +# Sharding Annotations and Shardy-like Propagation + +By default AutoParallel hands the entire sharding decision to the ILP: every +node enumerates every valid placement and the solver picks the global optimum. +That is the right default for a fresh model, but at scale the search space is +large even though the user often already knows the high-level plan — "the +attention and MLP projections are tensor-parallel; the batch is data-parallel". + +This page describes how to express that plan as a few **sharding annotations** +and have AutoParallel **propagate** them through the graph the way +[Shardy](https://github.com/openxla/shardy) does, turning the unambiguous part +of the graph into ILP constraints. This shrinks the search space and the solve +time while leaving the genuine cost tradeoffs to the solver. With a typical +tensor-parallel annotation on LLaMA-3 it reaches the *same* objective as the +full ILP on a noticeably smaller problem. + +If you are new to the project, start with +[Getting Started](getting_started.md) and +[How AutoParallel Chooses a Strategy](how_autoparallel_chooses_a_strategy.md). + +## The annotation API + +Annotations are added on the `AutoParallel` context manager, after the input / +output constraints and before `optimize_placement`: + +```python +with AutoParallel(model, input_fn, mesh) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + + # Annotate the tensor-parallel plan. A glob matches the weight in every + # layer at once. Only the tp axis is pinned; the data axis is left open. + column_parallel = (None, Shard(0)) # shard the output dim + row_parallel = (None, Shard(1)) # shard the input dim + for proj in ["wq", "wk", "wv"]: + autop.annotate_parameter(f"layers.*.attention.{proj}.weight", column_parallel) + autop.annotate_parameter("layers.*.attention.wo.weight", row_parallel) + for proj in ["w1", "w3"]: + autop.annotate_parameter(f"layers.*.feed_forward.{proj}.weight", column_parallel) + autop.annotate_parameter("layers.*.feed_forward.w2.weight", row_parallel) + + autop.propagate_annotations() # propagate + constrain + sharding = autop.optimize_placement() +``` + +A placement is a tuple with one entry per mesh dimension. Each entry is a +`Placement` (`Shard(d)`, `Replicate()`, ...) or **`None`** to leave that mesh +axis *open* for propagation / the ILP to decide. Leaving the data axis open is +the common case for weights: you pin the tensor-parallel axis and let the +optimizer choose FSDP vs DDP on the data axis. + +The available annotation methods are: + +- `annotate_parameter(fqn, placements, priority=1)` — `fqn` is a parameter + fully-qualified name or a glob pattern (e.g. `"layers.*.attention.wq.weight"`). +- `annotate_input(idx, placements, priority=0)` / + `annotate_output(idx, placements, priority=0)` — graph input/output by index. +- `annotate_node(node, placements, priority=0)` — an arbitrary FX node. + +`priority` controls the order annotations propagate (lower first, matching +Shardy). Activations/IO default to a higher priority than weights so that where +they compete for the same mesh axis (the data axis of a matmul) the +data-parallel sharding wins and the weight is all-gathered, rather than the +activation being resharded. + +`propagate_annotations()` returns a `PropagationResult` summarizing the +reduction (`nodes_determined`, `axis_constraints`, `reduction`). + +## How propagation works + +Propagation mirrors the structure of Shardy's propagation, expressed over +AutoParallel's existing per-node strategy lists (which already encode each op's +sharding rule): + +- **Per-mesh-axis.** A placement is propagated one mesh axis at a time. This is + what lets a weight's tensor-parallel sharding flow through a matmul on the + `tp` axis while the `dp` axis is resolved independently (data-parallel batch, + with FSDP all-gathers left to the ILP). It is the analogue of Shardy + projecting tensor shardings onto per-factor axes. + +- **Reshard-free.** Along an edge a consumer is only narrowed to the placements + it can take *without* a reshard from the producer (zero redistribution cost). + At a genuine reshard boundary — a necessary collective such as an all-reduce + or all-gather — no zero-cost option exists, so propagation stops there and the + ILP decides the collective. + +- **To a fixed point.** A worklist re-examines a node's neighbors whenever its + set of candidate shardings shrinks, until nothing changes. + +- **Priority rounds.** Annotations propagate in priority order; later rounds + cannot override what an earlier round determined. + +Once propagation reaches a fixed point, every mesh axis of a node whose sharding +became unambiguous is turned into a per-axis ILP constraint +(`add_node_axis_constraint`), which constrains that one axis and leaves the rest +of the node free. + +### What is and isn't pinned + +Propagation deliberately only pins genuine **`Shard`** placements — the +tensor-parallel structure the annotations describe, which is invariant in the +optimum. It does *not* pin: + +- **`Replicate`** — pinning it would forbid the ILP from instead sharding that + axis (for example choosing sequence parallelism on the residual stream). +- **`Partial`** — a pending reduction whose collective the ILP places; pinning + it fixes where the reduction happens and can even be infeasible (a `Partial` + value cannot be added to a `Replicate` residual without first reducing it). + +Both are genuine cost tradeoffs, so leaving them open keeps the optimum +reachable at little cost to the reduction. + +Two more correctness rules keep the constraint set feasible and faithful: + +- **Parameters are sources only.** A parameter's placement is its *stored* + sharding, which legitimately differs from the *compute* sharding a consumer + needs by a reshard (an FSDP all-gather). Propagation never infers a + parameter's sharding from its consumers, so an open data axis stays free for + FSDP, and a per-axis parameter constraint still counts toward the memory + budget on its free axes. +- **Backward pass via the pairing.** The forward/backward consistency + constraints already tie each gradient to its forward tensor, so the + parameter/input gradients and output tangents are left for the pairing to + decide; the rest of the backward graph is constrained normally (and the + forward annotations are mirrored onto the gradients to drive that). + +## How a pin is applied: variable fixing vs constraints + +`propagate_annotations(method=...)` (forwarded to +`ShardingOptimizer.add_node_axis_constraint`) controls how each determined axis +is committed to the ILP: + +- **`"fix"` (default)** sets the upper bound of the ruled-out decision variables + to 0, so the solver's presolve drops those columns and the problem actually + shrinks. +- **`"constraint"`** adds an `== 1` equality row over the matching variables. + It is removable by name, but on a large mesh adding thousands of rows without + removing any columns can *slow* the solve. + +Variable fixing is strictly better for solve time (and never worse for the +objective), which is why it is the default. + +## Solver performance and the LP relaxation + +`ShardingOptimizer.solve_lp_relaxation()` solves the continuous relaxation +(binaries relaxed to `[0, 1]`) and reports the objective, solve time, and how +many variables came out fractional. It exposes two facts that matter for +performance: + +1. **The relaxation is integral.** On LLaMA-3 (2D and 3D meshes), with and + without annotations, the LP relaxation comes out with *zero* fractional + variables and an integrality gap of 0% — its optimum already *is* the integer + optimum. So `solve_lp_relaxation(extract=True)` returns a valid optimal + per-node strategy dict (same form as `get_solution`) while skipping + branch-and-bound, which is several times faster than the MILP solve (e.g. on + the 16-layer 2D model, ~10s vs ~50s; on a 2M-variable 3D problem, ~45s vs + ~160s). This is the single biggest available speedup and is exact whenever + the relaxation is integral (it falls back to `None` when it is not). + +2. **Where annotations help the MILP.** Because the relaxation is integral, + there is little branch-and-bound to cut, so the annotation speedup is + scale-dependent: on a ~400k-variable problem the MILP overhead is a large + fraction and pinning the TP structure gives ~1.7–1.8×; on a ~2M-variable + problem the solve is dominated by the relaxation/model size itself, so the + speedup shrinks toward ~1× even though the *search space* shrinks more (the + extra mesh axis gives more axes to pin — e.g. −59% strategy choices on 3D vs + −36% on 2D). The annotation speedup on the *LP* solve is correspondingly + modest (~1.1–1.4×). The takeaway: annotations reduce the search space and + keep the optimum exact, but for raw solve time on this (integral) problem the + larger lever is solving the relaxation directly. + +A separate, orthogonal cost is that building the ILP for a 3-axis mesh is slow: +per-node strategy enumeration grows with the number of mesh axes (it is cubic +for a 3-axis mesh, dominated by the 4D attention tensors), which is independent +of the solve and of annotations. + +## Example + +`examples/example_llama3_annotated.py` runs the full ILP and the +annotated+propagation path on a LLaMA-3-1B model on a 2D mesh and prints the +comparison: the annotated path reaches the same objective on a search space +reduced by roughly a third, with a correspondingly faster solve. diff --git a/examples/example_llama3_annotated.py b/examples/example_llama3_annotated.py new file mode 100644 index 00000000..7e1f1ecb --- /dev/null +++ b/examples/example_llama3_annotated.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Sharding annotations + Shardy-like propagation on LLaMA3-1B (2D mesh). + +By default AutoParallel hands the whole sharding decision to the ILP. At scale +a user usually already knows the tensor-parallel plan ("these projections are +column-parallel, those are row-parallel"). This example shows how to express +that plan as a few *annotations*, propagate it through the graph the way Shardy +does, and turn the unambiguous part of the graph into ILP constraints. + +The annotations pin only the **tensor-parallel (tp) axis** of the transformer +body weights. Everything else -- the data/FSDP axis, the residual stream +(replicate vs sequence-parallel), the vocab/embedding sharding, and where the +collectives go -- is left to the ILP. Propagation then determines the sharding +of the activations that *follow* from the plan with no resharding and constrains +them, which shrinks the search space and the solve time while leaving the +genuine cost tradeoffs to the solver. + +Run it (no GPUs needed -- uses a fake process group): + + python examples/example_llama3_annotated.py +""" + +import logging +import time + +import pulp +import torch +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.testing._internal.distributed.fake_pg import FakeStore + +from autoparallel._testing.models.llama3 import ( + Transformer, + TransformerModelArgs, + apply_ac, +) +from autoparallel.api import AutoParallel + +logging.basicConfig(level=logging.WARNING) + +world_size = 64 +fake_store = FakeStore() +torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size +) + +# 2D mesh: data/FSDP on dp, tensor-parallel on tp. +dp, tp = world_size // 8, 8 +mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", (dp, tp), mesh_dim_names=("dp", "tp") +) + +# Small-batch / long-sequence regime, where tensor parallelism is worthwhile. +vocab_size = 128256 +seqlen = 2048 +batch_size = 2 * dp + + +def model_fn(): + # LLaMA-3.2-1B-ish config. + return Transformer( + TransformerModelArgs( + dim=2048, + n_layers=16, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.5, + multiple_of=256, + rope_theta=500000, + vocab_size=vocab_size, + max_seq_len=seqlen, + ) + ) + + +def input_fn(): + return torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") + + +def annotate_tp_plan(autop): + """The 'conscious' tensor-parallel plan, as a handful of annotations. + + Only the tp axis is pinned (the data axis is left ``None`` = open). A glob + pattern annotates the matching weight in every layer at once. + """ + column_parallel = (None, Shard(0)) # shard the output dim (dim 0 of [out, in]) + row_parallel = (None, Shard(1)) # shard the input dim (dim 1 of [out, in]) + for proj in ["wq", "wk", "wv"]: + autop.annotate_parameter(f"layers.*.attention.{proj}.weight", column_parallel) + autop.annotate_parameter("layers.*.attention.wo.weight", row_parallel) + for proj in ["w1", "w3"]: + autop.annotate_parameter( + f"layers.*.feed_forward.{proj}.weight", column_parallel + ) + autop.annotate_parameter("layers.*.feed_forward.w2.weight", row_parallel) + + +with torch.device("meta"): + model = model_fn() +apply_ac(model, mode="full") + +with AutoParallel(model, input_fn, mesh, repeated_subgraphs=True) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Shard(2))]) # vocab-parallel logits + opt = autop.sharding_optimizer + print( + f"ILP: {len(opt.strats)} nodes, {len(opt.decision_vars)} decision variables " + f"on a ({dp}, {tp}) mesh" + ) + + # --- Baseline: full ILP, no annotations --- + t = time.perf_counter() + autop.optimize_placement(verbose=False) + t_baseline = time.perf_counter() - t + obj_baseline = pulp.value(opt.prob.objective) + print( + f"baseline full ILP : objective {obj_baseline:11.1f} solve {t_baseline:6.1f}s" + ) + + # --- Annotated: propagate the TP plan, then solve the reduced problem --- + annotate_tp_plan(autop) + result = autop.propagate_annotations(verbose=False) + t = time.perf_counter() + opt.resolve(verbose=False) + t_annotated = time.perf_counter() - t + obj_annotated = pulp.value(opt.prob.objective) + print( + f"annotated + propag: objective {obj_annotated:11.1f} solve {t_annotated:6.1f}s" + ) + + gap = 100 * (obj_annotated - obj_baseline) / obj_baseline + print( + f"\npropagation pinned {result.nodes_determined} nodes " + f"({result.axis_constraints} per-axis constraints), shrinking the " + f"output-strategy search space by {100 * result.reduction:.1f}% " + f"({result.strategies_before} -> {result.strategies_after})" + ) + print( + f"objective gap vs full ILP: {gap:+.2f}% " + f"solve speedup: {t_baseline / max(t_annotated, 1e-9):.1f}x" + ) diff --git a/pyproject.toml b/pyproject.toml index 31b0df19..3c5a55c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,3 +61,9 @@ exclude = [ "autoparallel/tools/overlap_simulator/repro_.*\\.py", "autoparallel/visualizer/build_display_from_json\\.py", ] + +[tool.pyrefly] +search-path = [ + "/home/wangkj/.conda/envs/pt-dev/lib/python3.12/site-packages", + "/data/users/wangkj/pytorch", +] diff --git a/tests/conftest.py b/tests/conftest.py index d5d23ea1..22af2357 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,6 +40,16 @@ def apply_cuda_patches(func): return func +@pytest.fixture(autouse=True) +def _reset_placement_options_cache(): + """The placement-options cache is a process-global; clear it before each test + so optimizer builds never reuse stale strategies from a prior test's model.""" + from autoparallel.shardings.placement_options import reset_placement_options_cache + + reset_placement_options_cache() + yield + + @pytest.fixture(scope="module", autouse=True) def init_pg(): world_size = 256 diff --git a/tests/test_approximate_sharding.py b/tests/test_approximate_sharding.py new file mode 100644 index 00000000..b75ebe52 --- /dev/null +++ b/tests/test_approximate_sharding.py @@ -0,0 +1,245 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math + +import pulp +import pytest +import torch +from conftest import apply_cuda_patches +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel +from autoparallel.approximate_sharding import ApproximateShardingSolver + + +def _fake_2d_mesh(): + return torch.distributed.device_mesh.init_device_mesh( + "cuda", (4, 2), mesh_dim_names=("dp", "tp") + ) + + +def _tiny_llama3_autop(mesh, solver="ilp"): + vocab_size = 128 + seq_len = 16 + batch_size = 2 * mesh.shape[0] + model_args = TransformerModelArgs( + dim=64, + n_layers=2, + n_heads=4, + n_kv_heads=2, + vocab_size=vocab_size, + multiple_of=32, + rope_theta=500000, + max_seq_len=seq_len, + ) + with torch.device("meta"): + model = Transformer(model_args) + + def input_fn(): + return torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + ) + return AutoParallel( + model, input_fn, mesh, mp_policy, repeated_subgraphs=True, solver=solver + ) + + +def _add_constraints(autop, mesh): + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([(Shard(0),) + (Replicate(),) * (mesh.ndim - 1)]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +@pytest.mark.filterwarnings("ignore:Overwriting previously set objective") +def test_approx_objective_close_to_ilp(): + """The approximate solver should be much faster than the ILP while staying + within a small objective gap on a tiny LLaMA3 block + 2D mesh.""" + mesh = _fake_2d_mesh() + with _tiny_llama3_autop(mesh) as autop: + _add_constraints(autop, mesh) + opt = autop.sharding_optimizer + + autop.optimize_placement(verbose=False, solver="approx") + approx_objective = pulp.value(opt.prob.objective) + # The approx assignment must be ILP-feasible (flow consistency etc.); + # an infeasible assignment can score artificially low and silently pass + # the objective bound below. + violated = [n for n, c in opt.prob.constraints.items() if not c.valid()] + assert not violated, f"approx violated {len(violated)} constraints" + + autop.optimize_placement(verbose=False, solver="ilp") + ilp_objective = pulp.value(opt.prob.objective) + + assert math.isfinite(approx_objective) + assert ilp_objective > 0 + assert approx_objective >= ilp_objective - 1e-6 # ILP is optimal + assert approx_objective <= ilp_objective * 1.20 + 1e-6, ( + f"approx={approx_objective} ilp={ilp_objective} " + f"gap={(approx_objective / ilp_objective - 1) * 100:.1f}%" + ) + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +@pytest.mark.filterwarnings("ignore:Overwriting previously set objective") +def test_approx_memory_constrained_matches_ilp(): + """A non-tight parameter-memory budget routes the approx solver through the + Lagrangian relaxation. The result must respect the budget and stay within a + small objective gap of the budget-constrained ILP optimum.""" + mesh = _fake_2d_mesh() + with _tiny_llama3_autop(mesh) as autop: + # high=0.5 > 1/world_size, so the budget is non-tight (params are not + # pinned at build time) and can bind the runtime-optimal placement. + autop.add_parameter_memory_constraint(low=0.0, high=0.5) + autop.add_input_constraints([(Shard(0),) + (Replicate(),) * (mesh.ndim - 1)]) + autop.add_output_constraints([(Shard(0), Shard(2))]) + opt = autop.sharding_optimizer + + autop.optimize_placement(verbose=False, solver="approx") + approx_objective = pulp.value(opt.prob.objective) + # Materialize the memory rows and check the approx assignment against ALL + # constraints, including the budget it was solved under. + opt._apply_memory_constraint() + violated = [n for n, c in opt.prob.constraints.items() if not c.valid()] + assert not violated, f"approx violated {len(violated)} constraints" + + autop.optimize_placement(verbose=False, solver="ilp") + ilp_objective = pulp.value(opt.prob.objective) + + assert math.isfinite(approx_objective) + assert approx_objective >= ilp_objective - 1e-6 # ILP is optimal + assert approx_objective <= ilp_objective * 1.05 + 1e-6, ( + f"approx={approx_objective} ilp={ilp_objective} " + f"gap={(approx_objective / ilp_objective - 1) * 100:.2f}%" + ) + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +@pytest.mark.filterwarnings("ignore:Overwriting previously set objective") +def test_lp_solver_matches_ilp(): + """The LP-relaxation solver returns an integral, ILP-feasible assignment whose + objective equals the exact ILP optimum (the relaxation is integral here).""" + mesh = _fake_2d_mesh() + with _tiny_llama3_autop(mesh) as autop: + _add_constraints(autop, mesh) + opt = autop.sharding_optimizer + + autop.optimize_placement(verbose=False, solver="lp") + lp_objective = pulp.value(opt.prob.objective) + violated = [n for n, c in opt.prob.constraints.items() if not c.valid()] + assert not violated, f"lp violated {len(violated)} constraints" + + autop.optimize_placement(verbose=False, solver="ilp") + ilp_objective = pulp.value(opt.prob.objective) + + assert math.isfinite(lp_objective) + assert lp_objective == pytest.approx(ilp_objective, rel=1e-6) + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +@pytest.mark.filterwarnings("ignore:Overwriting previously set objective") +def test_optimality_check_logs_certified_gap(caplog): + """optimality_check=True solves the LP lower bound and logs the certified gap.""" + mesh = _fake_2d_mesh() + with _tiny_llama3_autop(mesh) as autop: + _add_constraints(autop, mesh) + with caplog.at_level(logging.INFO, logger="autoparallel.api"): + autop.optimize_placement( + verbose=False, solver="approx", optimality_check=True + ) + assert any("optimality check" in r.message for r in caplog.records) + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +def test_approx_objective_is_faithful(): + """The solver's internal energy must equal the exact ILP objective evaluated + on its assignment (pulp.value), so comparisons against the ILP are valid.""" + mesh = _fake_2d_mesh() + with _tiny_llama3_autop(mesh) as autop: + _add_constraints(autop, mesh) + opt = autop.sharding_optimizer + + solver = ApproximateShardingSolver(opt) + solver.get_solution(verbose=False) + + pulp_objective = pulp.value(opt.prob.objective) + internal_energy = solver.total_objective() + assert math.isfinite(internal_energy) + assert internal_energy == pytest.approx(pulp_objective, rel=1e-6) + # No forbidden decision variable should be selected. + assert all(key not in solver.forbidden for key in opt.selected_keys) + # And every ILP constraint must hold (flow consistency, paired, memory). + violated = [n for n, c in opt.prob.constraints.items() if not c.valid()] + assert not violated, f"approx violated {len(violated)} constraints" + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +def test_approx_respects_input_output_constraints(): + """User input/output placement constraints must be honored by the solution.""" + mesh = _fake_2d_mesh() + x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) + out_sharding = (Shard(0), Shard(2)) + with _tiny_llama3_autop(mesh) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + + solution = autop.optimize_placement(verbose=False, solver="approx") + assert solution + + placements = { + spec.placements + for strat in solution.values() + for spec in ( + strat.output_specs + if isinstance(strat.output_specs, (list, tuple)) + else (strat.output_specs,) + ) + if isinstance(spec, DTensorSpec) + } + assert x_sharding in placements + assert out_sharding in placements + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +def test_lite_build_matches_full(): + """Building with solver="approx" skips PuLP variables/constraints (faster + setup); the resulting assignment must be byte-identical to running the + approximate solver on a full PuLP build.""" + mesh = _fake_2d_mesh() + + with _tiny_llama3_autop(mesh, solver="ilp") as autop: + _add_constraints(autop, mesh) + assert autop.sharding_optimizer.prob is not None + autop.optimize_placement(verbose=False, solver="approx") + obj_full = autop.sharding_optimizer.profile["approximate"]["objective"] + keys_full = set(autop.sharding_optimizer.selected_keys) + + with _tiny_llama3_autop(mesh, solver="approx") as autop: + _add_constraints(autop, mesh) + # Lite build: no PuLP problem or variables were constructed. + assert autop.sharding_optimizer.prob is None + assert not autop.sharding_optimizer.pulp_variables + solution = autop.optimize_placement(verbose=False) + obj_lite = autop.sharding_optimizer.profile["approximate"]["objective"] + keys_lite = set(autop.sharding_optimizer.selected_keys) + assert solution + + assert obj_lite == pytest.approx(obj_full, rel=1e-9) + assert keys_lite == keys_full diff --git a/tests/test_dp_solver.py b/tests/test_dp_solver.py new file mode 100644 index 00000000..3dbb2d10 --- /dev/null +++ b/tests/test_dp_solver.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +import operator + +import pytest +import torch +import torch.nn.functional as F + +from autoparallel.graph_passes.graph_utils import all_input_nodes +from autoparallel.optimize_sharding import DPBasedShardingSolver + + +class _FakeOptimizer: + def __init__(self, graph): + self.graph = graph + self.strats = {node: object() for node in graph.nodes} + self.nodes = list(self.strats.keys()) + + def _all_input_nodes(self, node): + return [ + input_node + for input_node in all_input_nodes(node) + if input_node in self.strats + ] + + +def _assert_predecessors_match_graph_indegrees(topology): + topology_nodes = set(topology.nodes) + assert set(topology.predecessors) == topology_nodes + assert set(topology.node_to_index) == topology_nodes + + for node in topology.nodes: + expected_predecessors = [ + input_node + for input_node in all_input_nodes(node) + if input_node in topology_nodes + ] + predecessors = topology.predecessors[node] + assert len(predecessors) == len(expected_predecessors) + assert predecessors == expected_predecessors + + +def test_dp_solver_builds_topological_order_for_merge_graph(): + class MergeModule(torch.nn.Module): + def forward(self, x, y): + a = x + y + b = x * 2 + return a + b + + graph = torch.fx.symbolic_trace(MergeModule()).graph + solver = DPBasedShardingSolver(_FakeOptimizer(graph)) + + topology = solver.build_topological_order() + + assert all(node.op != "output" for node in topology.nodes) + assert topology.nodes == [node for node in graph.nodes if node.op != "output"] + _assert_predecessors_match_graph_indegrees(topology) + + for node, predecessors in topology.predecessors.items(): + node_index = topology.node_to_index[node] + for pred in predecessors: + assert topology.node_to_index[pred] < node_index + + merge = topology.nodes[-1] + assert [pred.name for pred in topology.predecessors[merge]] == ["add", "mul"] + + +def test_dp_solver_preserves_duplicate_predecessors(): + class DuplicateInputModule(torch.nn.Module): + def forward(self, x): + return x + x + + graph = torch.fx.symbolic_trace(DuplicateInputModule()).graph + solver = DPBasedShardingSolver(_FakeOptimizer(graph)) + + topology = solver.build_topological_order() + _assert_predecessors_match_graph_indegrees(topology) + + add_node = next(node for node in topology.nodes if node.op == "call_function") + predecessors = topology.predecessors[add_node] + assert len(predecessors) == 2 + assert predecessors[0] is predecessors[1] + assert predecessors[0].name == "x" + + +def test_dp_solver_topology_for_tiny_transformer_forward(): + class TinyTransformerBlock(torch.nn.Module): + def __init__(self): + super().__init__() + self.q = torch.nn.Linear(8, 8) + self.k = torch.nn.Linear(8, 8) + self.v = torch.nn.Linear(8, 8) + self.o = torch.nn.Linear(8, 8) + self.ff1 = torch.nn.Linear(8, 16) + self.ff2 = torch.nn.Linear(16, 8) + + def forward(self, x): + q = self.q(x) + k = self.k(x) + v = self.v(x) + scores = q @ k.transpose(-2, -1) / math.sqrt(8) + attn = F.softmax(scores, dim=-1) + attn_out = attn @ v + x = x + self.o(attn_out) + hidden = F.relu(self.ff1(x)) + return x + self.ff2(hidden) + + block = TinyTransformerBlock() + assert block(torch.randn(2, 4, 8)).shape == (2, 4, 8) + + graph = torch.fx.symbolic_trace(block).graph + solver = DPBasedShardingSolver(_FakeOptimizer(graph)) + + topology = solver.build_topological_order() + _assert_predecessors_match_graph_indegrees(topology) + node_names = [node.name for node in topology.nodes] + + assert node_names == [ + "x", + "q", + "k", + "v", + "transpose", + "matmul", + "truediv", + "softmax", + "matmul_1", + "o", + "add", + "ff1", + "relu", + "ff2", + "add_1", + ] + + add_nodes = [node for node in topology.nodes if node.target is operator.add] + assert [node.name for node in add_nodes] == ["add", "add_1"] + assert [pred.name for pred in topology.predecessors[add_nodes[0]]] == ["x", "o"] + assert [pred.name for pred in topology.predecessors[add_nodes[1]]] == [ + "add", + "ff2", + ] + + +def test_dp_solver_solution_is_not_implemented(): + class SimpleModule(torch.nn.Module): + def forward(self, x): + return x + 1 + + graph = torch.fx.symbolic_trace(SimpleModule()).graph + solver = DPBasedShardingSolver(_FakeOptimizer(graph)) + + with pytest.raises(NotImplementedError, match="only builds topological order"): + solver.get_solution() diff --git a/tests/test_lp_relaxation.py b/tests/test_lp_relaxation.py new file mode 100644 index 00000000..1b03e6fe --- /dev/null +++ b/tests/test_lp_relaxation.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import pulp +import pytest +import torch +from conftest import apply_cuda_patches +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs +from autoparallel.api import AutoParallel + + +def _fake_dp4_tp4_mesh(): + return torch.distributed.device_mesh.init_device_mesh( + "cuda", + (4, 4), + mesh_dim_names=("dp", "tp"), + ) + + +def _llama3_example_autop(device_mesh): + vocab_size = 128 + seq_len = 16 + batch_size = 2 * device_mesh.shape[0] + model_args = TransformerModelArgs( + dim=64, + n_layers=1, + n_heads=4, + n_kv_heads=2, + vocab_size=vocab_size, + multiple_of=32, + rope_theta=500000, + max_seq_len=seq_len, + ) + with torch.device("meta"): + model = Transformer(model_args) + + def input_fn(): + return torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + ) + return AutoParallel( + model, + input_fn, + device_mesh, + mp_policy, + repeated_subgraphs=True, + ) + + +@apply_cuda_patches +@pytest.mark.filterwarnings("ignore:Constructing LpVariable") +@pytest.mark.filterwarnings("ignore:Using LpProblem.constraints") +def test_lp_relaxation_certifies_llama3_example_search(): + mesh = _fake_dp4_tp4_mesh() + with _llama3_example_autop(mesh) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + x_sharding = (Shard(0), Replicate()) + out_sharding = (Shard(0), Shard(2)) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + + opt = autop.sharding_optimizer + + binary_vars = list(opt.pulp_variables.values()) + assert binary_vars + assert all(var.cat == pulp.LpInteger for var in binary_vars) + assert all(var.lowBound == 0 and var.upBound == 1 for var in binary_vars) + + continuous_vars = opt._create_pulp_variables(pulp.LpContinuous) + assert continuous_vars + assert all(var.cat == pulp.LpContinuous for var in continuous_vars.values()) + assert all( + var.lowBound == 0 and var.upBound == 1 for var in continuous_vars.values() + ) + + lower_bound = opt.get_lower_bound() + assert lower_bound.status == "Optimal" + assert math.isfinite(lower_bound.objective) + assert lower_bound.objective >= 0 + + assert not hasattr(opt, "selected_keys") + assert opt.prob.objective is None + assert all(var.cat == pulp.LpInteger for var in opt.pulp_variables.values()) + + solution = opt.get_solution() + feasible_cost = pulp.value(opt.prob.objective) + certificate_gap = ( + feasible_cost - lower_bound.objective + ) / lower_bound.objective + assert solution + assert lower_bound.objective <= feasible_cost + 1e-5 + assert certificate_gap >= -1e-8 + assert math.isfinite(certificate_gap) diff --git a/tests/test_optimize_placement.py b/tests/test_optimize_placement.py index 59a4cf7c..20ac9d95 100644 --- a/tests/test_optimize_placement.py +++ b/tests/test_optimize_placement.py @@ -841,3 +841,36 @@ def input_fn(): # With memory budget enforced and no node constraint, the optimizer # should shard this param again assert solution[orig_node].output_specs.placements == (Shard(0),) + + +@apply_cuda_patches +def test_invalid_strategies_are_pruned(device_mesh_2d): + """Infinite-cost (invalid) strategy edges must not be materialized as + variables or constraints, and pruning them must not change the optimum.""" + import math + + mesh = device_mesh_2d + model_fn, input_fn = _make_model_and_input_fn(mesh, "transformer_block") + with torch.device("meta"): + model = model_fn() + + with AutoParallel(model, input_fn, mesh) as autop: + autop.add_input_constraints([(Shard(0), Replicate())]) + autop.add_output_constraints([(Shard(0), Replicate())]) + autop.add_parameter_memory_constraint(low=None, high=None) + opt = autop.sharding_optimizer + + # Invariant: every materialized decision var is finite-cost, and the + # PuLP variable set is exactly the set of valid (finite) keys. + assert all(math.isfinite(dv.cost) for dv in opt.decision_vars.values()) + assert set(opt.pulp_variables) == opt._valid_keys + assert all(k in opt._valid_keys for k in opt.decision_vars) + + # No inf-cost (== 0) constraints should be emitted any more. + assert not any(name.startswith("inf_cases") for name in opt.prob.constraints) + + # The pruned problem must still solve to a valid solution. + solution = autop.optimize_placement() + param_nodes = get_param_nodes(autop.gm.graph) + for node in param_nodes: + assert node in solution diff --git a/tests/test_propagation.py b/tests/test_propagation.py new file mode 100644 index 00000000..34bb7af2 --- /dev/null +++ b/tests/test_propagation.py @@ -0,0 +1,222 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import pulp +import pytest +import torch +import torch.nn.functional as F +from conftest import apply_cuda_patches +from torch import nn +from torch._functorch._aot_autograd.fx_utils import get_param_and_grad_nodes +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel.api import AutoParallel +from autoparallel.propagation import ShardingAnnotation, ShardingPropagator + + +class TPBlock(nn.Module): + """A minimal transformer block: attention + SwiGLU FFN, the structure a + column/row-parallel tensor-parallel plan applies to.""" + + def __init__(self, dim=512, hidden=1024, nheads=8): + super().__init__() + self.nheads = nheads + self.wq = nn.Linear(dim, dim, bias=False) + self.wk = nn.Linear(dim, dim, bias=False) + self.wv = nn.Linear(dim, dim, bias=False) + self.wo = nn.Linear(dim, dim, bias=False) + self.w1 = nn.Linear(dim, hidden, bias=False) + self.w2 = nn.Linear(hidden, dim, bias=False) + self.w3 = nn.Linear(dim, hidden, bias=False) + + def forward(self, x): + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + o = F.scaled_dot_product_attention(q, k, v) + o = o.permute(0, 2, 1, 3).flatten(-2) + h = self.wo(o) + x + return h + self.w2(F.silu(self.w1(h)) * self.w3(h)) + + +def _input_fn(): + bs = 32 + return torch.randn(bs, 128, 512, device="cuda", requires_grad=True) + + +def _enter_autop(mesh): + with torch.device("meta"): + model = TPBlock() + autop = AutoParallel(model, _input_fn, mesh) + autop.__enter__() + autop.add_parameter_memory_constraint(low=None, high=None) + x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + return autop + + +def _annotate_tp(autop): + col, row = (None, Shard(0)), (None, Shard(1)) + for proj in ["wq", "wk", "wv", "w1", "w3"]: + autop.annotate_parameter(f"{proj}.weight", col) + for proj in ["wo", "w2"]: + autop.annotate_parameter(f"{proj}.weight", row) + + +@apply_cuda_patches +def test_propagation_matches_full_ilp(device_mesh_2d): + """Annotating the TP plan and propagating shrinks the search space while the + reduced ILP reaches the same optimum as the full ILP.""" + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + autop.optimize_placement(verbose=False) + obj_full = pulp.value(opt.prob.objective) + + _annotate_tp(autop) + result = autop.propagate_annotations(verbose=False) + opt.resolve(verbose=False) + obj_annotated = pulp.value(opt.prob.objective) + + assert opt.prob.status == 1 # Optimal + # Same optimum (propagation only pins reshard-free, unambiguous sharding). + assert obj_annotated == pytest.approx(obj_full, rel=1e-6) + # And it actually pruned a meaningful chunk of the search space. + assert result.reduction > 0.1 + assert result.nodes_determined > 0 + finally: + autop.__exit__(None, None, None) + + +@apply_cuda_patches +def test_lp_relaxation_is_integral_and_exact(device_mesh_2d): + """The LP relaxation of the sharding ILP is integral here, so solving it is a + cheaper exact solve: same objective as the ILP, with an extractable solution.""" + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + autop.optimize_placement(verbose=False) + obj_ilp = pulp.value(opt.prob.objective) + + lp = opt.solve_lp_relaxation(extract=True) + assert lp["n_fractional"] == 0 # relaxation is integral + assert lp["objective"] == pytest.approx(obj_ilp, rel=1e-6) + assert lp["solution"] is not None + # one strategy per (single-output) decision node + assert len(lp["solution"]) > 0 + finally: + autop.__exit__(None, None, None) + + +@apply_cuda_patches +def test_axis_constraint_fix_method_matches_constraint(device_mesh_2d): + """Pinning an axis by fixing variables gives the same result as the equality + constraint, and is exact.""" + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + fqn = {d.target: n for d, (n, _) in get_param_and_grad_nodes(opt.graph).items()} + wq = fqn["wq.weight"] + opt.add_node_axis_constraint(wq, mesh_dim=1, placement=Shard(0), method="fix") + solution = autop.optimize_placement(verbose=False) + assert opt.prob.status == 1 + placements = solution[opt._concrete_to_orig.get(wq, wq)].output_specs.placements + assert placements[1] == Shard(0) + finally: + autop.__exit__(None, None, None) + + +@apply_cuda_patches +def test_add_node_axis_constraint_pins_one_axis(device_mesh_2d): + """A per-axis constraint pins the chosen mesh axis and leaves the other free.""" + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + fqn = {d.target: n for d, (n, _) in get_param_and_grad_nodes(opt.graph).items()} + wq = fqn["wq.weight"] + opt.add_node_axis_constraint(wq, mesh_dim=1, placement=Shard(0)) + solution = autop.optimize_placement(verbose=False) + placements = solution[opt._concrete_to_orig.get(wq, wq)].output_specs.placements + # tp axis pinned to Shard(0); dp axis decided by the ILP. + assert placements[1] == Shard(0) + finally: + autop.__exit__(None, None, None) + + +@apply_cuda_patches +def test_axis_constraint_keeps_param_shardable_for_fsdp(device_mesh_2d): + """A per-axis tp constraint must not exclude a parameter from the memory + budget: it should still be shardable on the (free) data axis for FSDP.""" + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + fqn = {d.target: n for d, (n, _) in get_param_and_grad_nodes(opt.graph).items()} + wq = fqn["wq.weight"] + # Column-parallel on tp; data axis left open. + opt.add_node_axis_constraint(wq, mesh_dim=1, placement=Shard(0)) + solution = autop.optimize_placement(verbose=False) + assert opt.prob.status == 1 # feasible despite the tight memory budget + placements = solution[opt._concrete_to_orig.get(wq, wq)].output_specs.placements + # FSDP shards the data axis too (tight 1/world_size budget). + assert placements[0] == Shard(0) + assert placements[1] == Shard(0) + finally: + autop.__exit__(None, None, None) + + +@apply_cuda_patches +def test_seed_unachievable_raises(device_mesh_2d): + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + prop = ShardingPropagator(opt) + fqn = {d.target: n for d, (n, _) in get_param_and_grad_nodes(opt.graph).items()} + wq = fqn["wq.weight"] + # wq.weight is 2D; sharding a non-existent tensor dim 5 is impossible. + with pytest.raises(ValueError): + prop.seed(wq, (None, Shard(5))) + finally: + autop.__exit__(None, None, None) + + +@apply_cuda_patches +def test_propagation_determines_matmul_outputs(device_mesh_2d): + """Seeding the column-parallel weights determines the tp axis of the matmul + outputs (sharded on the output feature) with no resharding.""" + autop = _enter_autop(device_mesh_2d) + try: + opt = autop.sharding_optimizer + prop = ShardingPropagator(opt) + annotations = [] + fqn = {d.target: n for d, (n, _) in get_param_and_grad_nodes(opt.graph).items()} + for proj in ["wq", "wk", "wv", "w1", "w3"]: + annotations.append( + (fqn[f"{proj}.weight"], ShardingAnnotation((None, Shard(0)), 1)) + ) + for proj in ["wo", "w2"]: + annotations.append( + (fqn[f"{proj}.weight"], ShardingAnnotation((None, Shard(1)), 1)) + ) + determined = prop.run(annotations) + + # Every column-parallel matmul output should be tp-sharded (not replicated). + einsum_nodes = opt.graph.find_nodes( + op="call_function", target=torch.ops.aten.einsum.default + ) + if not einsum_nodes: + einsum_nodes = opt.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + n_tp_pinned = 0 + for n in einsum_nodes: + if n in determined: + tp = dict(determined[n]).get(1) + if isinstance(tp, Shard): + n_tp_pinned += 1 + assert n_tp_pinned > 0 + finally: + autop.__exit__(None, None, None)