Skip to content

Topo-span-bounded FSDP bucketing for AutoParallel#481

Open
fmassa wants to merge 2 commits into
mainfrom
fmassa/better_bucketing
Open

Topo-span-bounded FSDP bucketing for AutoParallel#481
fmassa wants to merge 2 commits into
mainfrom
fmassa/better_bucketing

Conversation

@fmassa

@fmassa fmassa commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds FSDP-bucketing improvements to AutoParallel's auto_bucketing.py: a monkey-patch installed at import time that fixes multi-process-group handling and bounds each bucket's graph topo-span at formation time. Replaces PyTorch's default behavior, which on AP-generated graphs (multiple PGs, interleaved collectives) under-buckets, and adds a single declarative knob to prevent the MM*N batching problem that would otherwise inflate backward peak memory.

What's in the branch

autoparallel/graph_passes/auto_bucketing.py (+211 LOC vs main)

Adds _patch_fsdp_bucketing(), invoked at module-load time, which replaces two functions in upstream PyTorch:

  1. identify_fsdp_groups — primary-group-only. Returns just the group with the most FSDP all-gathers, instead of all groups. Without this, minority groups (e.g. 1 TP AG, a handful of norm AGs on the combined group) pollute the bucketing pool and constrain DP-group merging.

  2. greedy_bucket_collective_by_mb — group-by-key (no graph-adjacency requirement) with two close-bucket conditions:

    • bytes cap (existing in upstream, preserved): close a bucket when adding the next collective would exceed bucket_cap_mb.
    • topo-span cap (new): close a bucket when its span (rank(latest_member) - rank(first_member)) would exceed aten_autobucketing_config.max_topo_span. Ranks are snapshotted at function entry, before any graph mutation. Bounds how far compute can be displaced when bucketing rewires the dependency graph and stable_topological_sort runs afterwards.

Additional changes:

  • aten_autobucketing_config.max_compute_pre_fetch: 5 → 50 (≈3 layers of headroom for hiding bucketed AGs vs prior ≈0.3 layers).
  • New aten_autobucketing_config.max_topo_span: int | None = 1500. Set to None to disable the span bound and degenerate to bytes-only behavior.
  • INFO-level metrics inside _patched_greedy_bucket: per-invocation (num_buckets, max_observed_span, n_close_bytes, n_close_span, max_topo_span). Makes regressions visible.
  • INFO log of max_consec_compute_between_rs after aten_autobucketing_reordering_pass — the regression metric the patch is keeping bounded.
  • _max_consec_compute_between_rs(graph) helper exposed for callers/tests.

tests/test_auto_bucketing_patches.py (+280 LOC, new file)

8 FX-level tests (~3 s, no GPU, no init_pg), each builds a synthetic graph that satisfies PyTorch's is_fsdp_all_gather/is_wait_tensor contracts:

  • identify_fsdp_groups: primary group wins on imbalanced counts; empty graph returns empty; ties pick exactly one.
  • greedy_bucket: merges within caps; splits on bytes; splits on span when bytes are under cap; max_topo_span=None disables span; descendant collectives never co-bucket.

Mutation-verified: reverting either patch causes the corresponding test to fail with an informative message.

Validation on production workload

LLaMA-3 8B, 32 layers, seqlen=8192, global batch=32, 128 H100s (job 6513650).

Unconstrained (vs TorchTitan reference 366/414/1743 ms):

Metric New (max_topo_span=1500) Prior best (lucky) Prior typical
min ms 379.3 378.3 370–377
avg ms 389.6 384.2 424–569
max ms 436.6 412.0 1856–5982
alloc GiB 7.69 7.55 7.52–9.06
rsrvd GiB 8.06 8.20 8.20–10.86

Min latency within ±9 ms of prior best. Avg latency second-best across all runs (only the lucky prior beats it). Max latency best of all runs — typical prior runs had multi-second tails that disappear with the new bucketing-time prevention. Variance (avg − min) collapses from 5–200 ms to 10 ms.

Constrained:

Metric New Prior range
min ms 443.6 439.8–445.9
avg ms 484.3 447–628
rsrvd GiB 7.51 7.49–7.59

Min within prior noise; avg middle-of-pack; rsrvd memory best-of-all-runs.

Telemetry from the shipped run shows max_topo_span=1500 is dormant at this config — only 0–1 span closures per invocation across all graphs; behavior is effectively bytes-cap-only plus primary-group filter plus descendant-only conflict detection. Backward max_consec_compute_between_rs=12 (vs 8 under the prior post-hoc cap — 50% looser); the small min-latency gap may be attributable to those extra 4 MMs concentrating before each RS. A max_topo_span=1000 run is in flight to test whether tightening the gate closes the gap without reintroducing variance.

Design rationale

Bucketing-time prevention rather than post-hoc correction. One declarative knob (max_topo_span) replaces a coupled pair of patches plus a downstream remediator that had silent-failure modes. Cherry-style behavior remains recoverable as a knob setting (max_topo_span=None) rather than a separate codepath. INFO-level metrics make regressions visible in benchmark logs.

Authored with Claude.

Ports three related changes to `auto_bucketing.py` from `fmassa/double_recomp` (cherry of 428e9d2). Together they address an interaction between FSDP all-gather bucketing and `stable_topological_sort` that batches recomputation MMs upfront and inflates backward peak memory.

**1. `_patch_fsdp_bucketing()` — monkey-patches PyTorch's FSDP bucketing**
- *Primary-group-only*: `identify_fsdp_groups` keeps only the group with the most FSDP all-gathers, so minority groups (1 tp AG, ~65 norm AGs on the combined group) no longer pollute the bucketing pool.
- *Non-adjacent bucketing*: `greedy_bucket_collective_by_mb` collects all eligible collectives per group key instead of requiring graph adjacency, allowing dp AGs interleaved with tp activation collectives to bucket together.

**2. `max_compute_pre_fetch`: 5 → 50**
The previous value allowed only ~0.3 layers of prefetch (≈17 compute nodes/layer), insufficient to hide 5–7ms full-mesh AGs. 50 gives ≈3 layers of headroom.

**3. `_cap_compute_batch_size(max_consecutive=8)`**
After bucketing rewires dependencies, `stable_topological_sort` reorders 525/540 compute ops into an `MM*40` block before the first backward RS, blowing up peak memory. Snapshots the original compute/RS interleaving before scheduling, then for any post-schedule segment with >8 compute nodes between RS ops, chains chunks and pulls forward an RS that originally sat between them. Falls back gracefully if a cycle is detected.

**Why this matters (from prior investigation, LLaMA-3 8B, 128 H100s, dp=16/tp=8):**
- Bucketing patches + prefetch bump closed the unconstrained AP-to-reference gap from 12.1% → 4.6% (385ms vs 368ms reference). Prefetch alone: −20ms (−4.6%) unconstrained.
- `_cap_compute_batch_size` brings constrained from a runaway `MM*40` recomp batch down to **358ms / 16.78 GB** (vs reference 339ms / 5.97 GB), trading +5.9ms latency for −4.3 GB peak memory vs the uncapped variant. A more aggressive `_restore_compute_order` was rejected — it killed 62.8ms of overlap for only 1.3 GB extra savings.

Authored with Claude.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 2, 2026
@fmassa fmassa requested a review from IvanKobzarev June 2, 2026 19:57
Summary

Replaces the `_cap_compute_batch_size` post-hoc remediator with a topo-span close-bucket condition inside the patched FSDP bucketing. Same load-bearing knob (`bucket_cap_mb`) gains a companion (`max_topo_span`) that bounds the dependency-footprint problem the cap was trying to undo. The bucketing-time prevention has shipped at `max_topo_span=1500` on `fmassa/double_recomp` and matches the prior cap-based path on min latency while substantially improving variance.

Motivation

The prior cherry (`9871a48`) combined two opposing patches: non-adjacent bucketing (which widens bucket dependency footprints to improve NCCL throughput) and `_cap_compute_batch_size` (which reorders compute post-hoc to undo the resulting MM*N batching). The cap relied on pre-bucket reduce_scatter names surviving bucketing — a property that doesn't always hold when bucketing renames RS ops — so the cap silently degrades to chain-deps-only and doesn't bound `max_consec_compute_between_RS`. Bounding bucket topo-span at the source addresses the same MM*N harm without a downstream remediator, surfaces one tunable knob users can reason about, and degenerates to the prior cherry's behavior when `max_topo_span=None`.

What changed

`autoparallel/graph_passes/auto_bucketing.py`, −56 net LOC:

- **Kept** `_patched_identify_fsdp_groups` (primary-group-only filter) and `max_compute_pre_fetch=50`.
- **Replaced** the non-adjacent `_patched_greedy_bucket` with a version that adds `close_for_span` as a third close-bucket condition alongside `close_for_bytes`. Snapshots ranks at function entry, closes a bucket when `current_rank - bucket_start_rank > max_topo_span`.
- **Added** `aten_autobucketing_config.max_topo_span: int | None = 1500`. Set to `None` to restore the prior bytes-only (cherry) behavior with no codepath difference.
- **Added** INFO-level metrics inside `_patched_greedy_bucket`: per-invocation `(num_buckets, max_observed_span, n_close_bytes, n_close_span, max_topo_span)`. The cap's silent-failure mode is no longer possible.
- **Added** `_max_consec_compute_between_rs(graph)` helper, logged at INFO after `aten_autobucketing_reordering_pass` as a regression metric matching what the old cap enforced.
- **Removed** `_cap_compute_batch_size` (~120 LOC) and the pre/post-bucket name-snapshotting that fed it.

Validation

`tests/test_auto_bucketing_patches.py` (8 tests, ~3 s, no GPU):

- `identify_fsdp_groups`: primary group wins on imbalanced counts; empty graph returns empty; ties pick exactly one.
- `greedy_bucket`: merges within caps; splits on bytes; splits on span when bytes are under cap; `max_topo_span=None` disables span; descendant collectives never co-bucket.

Mutation-verified — reverting either patch causes the corresponding test to fail with an informative message.

Real benchmark results (LLaMA-3 8B, 32 layers, 128 H100s, seqlen=8192, global batch=32)

`max_topo_span=1500`, job 6513650, vs the most-stable prior cap-based run (6477238):

**Unconstrained**

| Run                        | min ms  | avg ms    | max ms    | alloc GiB | rsrvd GiB | MFU       |
| -------------------------- | ------- | --------- | --------- | --------- | --------- | --------- |
| New (`max_topo_span=1500`) | 379.3   | **389.6** | **436.6** | 7.69      | **8.06**  | 492.5 %   |
| Prior `_cap` (cleanest)    | 378.3   | 384.2     | 412.0     | 7.55      | 8.20      | 499.4 %   |
| Prior `_cap` (typical)     | 370–377 | 424–569   | 1.8–6.0 s | 7.55      | 8.20      | 337–452 % |
| TorchTitan reference       | 366.2   | 414.1     | 1742.8    | 8.62      | 9.43      | 463.4 %   |

- **Min latency**: 379.3 ms, ~+1–9 ms over best prior runs; within the prior-run min spread.
- **Avg latency**: 389.6 ms — second-best across all runs (only the lucky 6477238 beats it at 384.2 ms); typical prior cap-based avg was 424–569 ms.
- **Max latency**: 436.6 ms — **best of all benchmarked runs**; every other prior run except 6477238 had >1.8 s tails.
- **Variance**: avg − min collapses from 5–200 ms range to 10 ms. The cap's silent-failure failure mode (which produced occasional multi-second tails) appears to be the source of prior variance; bucketing-time prevention is much more deterministic.
- **Memory**: +0.14 GiB alloc vs cap-based prior; rsrvd memory is the best of all runs.

**Constrained**

| Run                        | min ms      | avg ms  | max ms      | alloc GiB | rsrvd GiB |
| -------------------------- | ----------- | ------- | ----------- | --------- | --------- |
| New (`max_topo_span=1500`) | 443.6       | 484.3   | 662.9       | 7.11      | **7.51**  |
| Prior `_cap` (range)       | 439.8–445.9 | 447–628 | 0.98–5.58 s | 7.12      | 7.49–7.59 |

- Min within prior noise band; avg/max middle of prior variance; **rsrvd memory is the best of all runs**.

Telemetry from the shipped run

The `max_topo_span=1500` knob is **dormant in practice** at this configuration:

- Fwd: 27 buckets, max span 619, closures `(bytes=26, span=0)`, `max_consec_compute_between_rs=8`
- Bwd group A: 43 buckets, max span 1325, closures `(bytes=40, span=1)`
- Bwd group B: 64 buckets, max span 561, closures `(bytes=62, span=0)`
- Bwd `max_consec_compute_between_rs=12` (vs 8 under the old explicit cap — 50% looser)

The span gate fires at most once across all bucketing invocations; actual behavior is byte-cap-only plus primary-group filter. The 4-extra MMs concentrated before each RS may explain the ~9 ms unconstrained min-latency gap vs the best prior run — testing `max_topo_span=1000` is in progress to validate whether tightening the gate closes that gap without reintroducing variance.

What this gives up vs the cap-based approach

- Min latency: within ±9 ms of best prior min, plausibly recoverable by tuning `max_topo_span` downward.
- Memory: rsrvd best-of-all-runs in both modes; alloc within 2% of prior best.

What this gains

- Avg/max latency stability: variance collapses dramatically vs the prior cap-based path (typical prior tails of 1.8–6.0 s disappear in unconstrained; 0.98–5.58 s tails in constrained reduce to 0.66 s).
- One declarative knob (`max_topo_span`) replaces two coupled patches plus a downstream remediator with a known name-matching brittleness.
- Loud metrics: `n_close_bytes` / `n_close_span` / `max_consec_compute_between_rs` are logged at INFO; the prior cap's silent-failure mode is no longer possible.
- ~56 net LOC removed.

Authored with Claude.
@fmassa fmassa changed the title Improve FSDP bucketing and cap compute batches between ReduceScatters Topo-span-bounded FSDP bucketing for AutoParallel Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant