Skip to content

[pull] main from NVIDIA:main#613

Merged
pull[bot] merged 4 commits into
phu0ngng:mainfrom
NVIDIA:main
May 22, 2026
Merged

[pull] main from NVIDIA:main#613
pull[bot] merged 4 commits into
phu0ngng:mainfrom
NVIDIA:main

Conversation

@pull
Copy link
Copy Markdown

@pull pull Bot commented May 22, 2026

See Commits and Changes for more details.


Created by pull[bot] (v2.0.0-alpha.4)

Can you help keep this open source service alive? 💖 Please sponsor : )

ptrendx and others added 4 commits May 21, 2026 15:51
* Debugging

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More debugging

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Distinguish the users based on the write permissions rather than relying
on the member field, which could be set to private

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
* Add MXFP8 grouped MLP SReLU fusion

Signed-off-by: sraman-rgb <sraman@nvidia.com>

* Address grouped MLP fused op review comments

Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>

* Avoid quantizing ScaledSReLU backward in basic op

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* Wire ScaledSReLU recompute in grouped MLP

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address grouped MLP ScaledSReLU review comments

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* Gate ScaledSReLU recompute support

Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>

* Use version check for dSReLU reuse arg

Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>

* Reuse forward dSReLU recompute decision

Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>

* Reject activation recompute without grouped MLP fusion

Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>

* Rename grouped MLP activation recompute flag

Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>

---------

Signed-off-by: sraman-rgb <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Co-authored-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
* Batch CP attention tests via a persistent NCCL pool

The existing test path spawns one torchrun per parametrized case, paying
NCCL init + CUDA context + Python startup on every call. With ~hundreds
of cases the launch overhead dominates wall time and was a primary driver
of the L3 timeout that prior batching PRs worked around.

This change replaces the per-case subprocess with one long-lived
torchrun per (world_size). NCCL is initialized once at session start and
reused across cases. Pytest sends one JSON request per case over rank-0
stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers
(ok, error) from every rank, and writes one JSON response on rank-0
stdout.

run_attention_with_cp.py is left almost untouched; a new
NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and
dist.destroy_process_group() calls so the function reuses the pool's
main PG instead of creating its own. The per-case cp_comm_group (and
a2a+p2p sub-groups) are explicitly destroyed at function exit to
prevent communicator leakage across cases.

The PoolWorker class adds two pieces of error recovery that the prior
subprocess-per-case design got for free: a select-based per-call
timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on
worker death or timeout. A test-level exception is reported as an
AssertionError and the pool keeps running for the next case.

Two pool sizes are needed because cp_comm_type='a2a+p2p' requires
world_size=4 and the others use world_size=2; you can't resize an
active PG. Pools are spawned lazily so a 2-GPU-only run never pays the
4-GPU init.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reset FP8 state and barrier between pool cases

Two resilience fixes carried over from the existing batching PR
(sudhakars/cp_test_batching_pr) without which the pool will
cascade-fail FP8 tests and silently propagate NCCL desync.

1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state
   (recipe handles, autocast counters) lives in module-level globals.
   Reusing one Python process across cases otherwise carries that state
   forward. The prior batching PR landed an explicit fix for the same
   issue ("Fix FP8 cascade failures") after observing real test
   failures from this.

2. dist.barrier() after each case. If one rank's case errored before
   its last collective, the others can be stuck waiting on a comm that
   will never complete. The barrier here surfaces that immediately as
   a timeout in this case rather than letting the corruption leak into
   the next case's collectives.

Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the
top of each call. run_dpa_with_cp already sets them unconditionally so
this is defensive, but cheap insurance against future variants that
might not.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Deep-copy ModelConfig in run_dpa_with_cp

The model_configs_{flash,fused}_attn dicts are module-level and shared
across pool cases. The THD branch below rewrites config.attn_mask_type
in place (causal -> padding_causal, no_mask -> padding). With the
persistent-pool runner, the next case looking up the same model key
gets the mutated config and fails the "causal or no_mask only" assert.

Caught at benchmark time on cp_2_0 + thd, identical to the cascade the
existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the
same way in commit 6355f62.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Skip deterministic configs incompatible with FusedAttention

Mirrors the two pre-emptive skips on the PR-batching branch:

* non-vanilla softmax with FusedAttention is not deterministic
* post_scale_bias with requires_grad is not deterministic

Without these skips, the corresponding configs propagate into the pool
worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside
run_dpa_with_cp instead of being marked SKIPPED.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reseed RNG between pool cases; reset before, not after

The pool worker reused RNG state across cases, which produced
small numerical drift on some non-FP8 fused-attention configs
(cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the
single-shot worker. Matches the per-case startup of the single-shot
worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at
the start of every case, alongside the existing FP8 / env / cache
resets.

Moved the reset call from the post-case finally block to the start
of _run_one so the first case is also seeded consistently with
subsequent cases. Otherwise the first case would inherit the
process-default RNG and only the second-and-later cases would be
deterministic.

Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Robustify pool: capture worker stderr, tighten timeout, add timing knob

Three changes that bring the pool's failure semantics on par with the
per-batch torchrun approach in PR #2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to
   crash-path AssertionErrors. Equivalent in spirit to PR #2965's
   run_distributed() — CI JUnit XML now shows the actual cause (NCCL
   error, Python traceback, OOM) inline with the failing test, instead
   of just "pool worker died mid-request" / "timed out". A daemon
   drainer thread reads stderr line-by-line into a deque(maxlen=200)
   and also echoes to sys.stderr so pytest's per-test capture still
   gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the
   slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s).
   90 s gives ~6x headroom over the worst observed case while still
   detecting a genuine hang within ~1.5 min instead of ~10 min. Env
   var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints
   "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr
   on rank 0 only. Grep-friendly; lets future tuning recalibrate the
   timeout against the observed distribution. Off by default so normal
   runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True,
with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address PR review: NCCL leak, stdout protocol, Windows note

Three fixes responding to #2993
review comments:

P1: NCCL communicator leak on exception (run_attention_with_cp.py)

run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups)
near the top, but the destroy_process_group() calls ran only on the
success path at the end of the function. Any exception in between
(tensor assertion, OOM, NCCL error) skipped the cleanup, leaking
communicators in pool mode. Long sessions with repeated failures
could exhaust NCCL internal tracking.

Wrap the test work in try/finally so the destroy logic always runs.
Initialise cp_comm_sub_groups = [] unconditionally so the finally
block is safe even when cp_comm_type != "a2a+p2p" (or when an assert
fires before the populate loop). Each destroy is itself try/except so
a destroy failure on one group doesn't leak the others.

P2: stdout protocol can be corrupted by interleaved chatter

torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0
print, NCCL debug line, or torchrun status output interleaves with
the JSON response and breaks json.loads, killing the pool with a
misleading "json decode error".

Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py
and have PoolWorker.submit() scan stdout for sentinel-prefixed lines,
echoing non-protocol lines to stderr for visibility. Bounded scan
(MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.

P2 (doc): select.select on a pipe fd is Linux/macOS only

Added a short comment noting Windows portability. CP attention tests
run on Linux GPU hosts; this is a documentation issue, not a real bug.

Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True
(was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line
overhead at ~600 ms/case, within noise).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward

In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is
written inside `with torch.cuda.stream(flash_attn_streams[i])`. For
i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default
stream. Later, at loop iteration i=2, the code reads
max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])`
which runs on the default stream. Without an explicit wait_stream,
this is a read-after-write race across streams. The post-loop
`current_stream().wait_stream(cp_stream)` is too late — the race has
already fired.

The race is latent: outcome depends on stream scheduling. In a
fresh-process subprocess (one-torchrun-per-test path), streams are
cleanly initialised and timing happens to put the write before the
read. In a long-running persistent-worker process — exposed by
PR #2993's pool design — prior workloads shape stream state
differently, the read can fire before the write completes, and
max_logit ends up with stale values in some heads (~0.3 abs diff,
3/12 elements wrong on the H100 matrix).

Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])`
before the torch.maximum read. No-op when the streams are identical
(i=1 case, where flash_attn_streams[0] is current_stream), only
fires when reading from cp_stream (i=2 case).

Validated: 8xH100, test_essential=False, 348 passed / 0 failed in
27m 10s (was 323 passed + 5 failed at this commit's parent, all 5
failing on cp_comm_type=all_gather with mismatched max_logit).
The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now
pass under the pool — confirming the race was the sole root cause.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R2): drop dead code in pool worker and PoolWorker

Line-level cleanups from the second reviewer pass on PR #2993. Each item
is dead/redundant; none changes behaviour. Full-matrix test_essential=False
on 8xH100 still passes 348/0 in 26m 23s after these.

run_attention_with_cp_pool.py:
- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already
  re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top
  and pops the FP8 ones itself. The pop loop was defensive against a
  hypothetical "future caller that doesn't re-set them" that doesn't
  exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create
  no Python reference cycles between iterations and empty_cache only
  frees CUDA blocks PyTorch already considers free; the combination
  was no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is
  itself a collective synchronization point — if every rank reaches
  it, none is ahead. The "surface a wedged communicator here" comment
  was wishful: a wedged communicator would already wedge the gather.

test_attention_with_cp.py (PoolWorker):
- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the
  unreachable post-loop "1000+ lines" branch. select()'s deadline
  already bounds the loop; the line-count cap was redundant and
  the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and
  self-terminates when the pipe closes; we never read the field
  anywhere, so initialising and nulling it was bookkeeping for no
  reason.
- Drop the dead assert in submit() — _ensure_alive() on the prior
  line already guarantees proc/stdin/stdout exist.

Deferred to a follow-up:
- L8 (drop try/except around dist.destroy_process_group). Real
  semantic change: hides errors that occur when a previous test
  wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var),
  M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit
  CUDA_VISIBLE_DEVICES per pool). Same reasoning — separate PRs.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (items 2+3): reuse CP groups across pool cases

world_size and the rank set don't change for the lifetime of one pool, so
recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms
of NCCL setup each. Pre-create them once in the pool worker (new helper
_create_cp_comm_groups), stash on the run_attention_with_cp module via
module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and
reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them
once at shutdown.

Also move per-case dist.new_group() calls inside the try/finally in
run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population
otherwise leaks every communicator created before the failure. The finally
now only destroys groups we created locally (cp_comm_group / sub_groups
populated in the else-branch), leaving pool-owned groups alone for reuse.

cyanguwa's review feedback on PR #2993.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Flatten try/finally wrap in run_dpa_with_cp

The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the
~540-line body of run_dpa_with_cp in try/finally. The wrap itself
was tiny but it re-indented every line of the body by one level,
inflating the PR diff of run_attention_with_cp.py to ~1000 lines
against origin/main.

Items 2+3 (d15bfce) since made the wrap unnecessary:
  - In pool mode, cp_comm_group and cp_comm_sub_groups are owned by
    the pool worker (which destroys them once at pool shutdown).
    run_dpa_with_cp neither creates nor destroys them, so an
    in-body exception can't leak communicators.
  - In single-shot mode, groups are still created locally, but the
    subprocess exits at function return; NCCL releases everything
    at process teardown, so a stray exception leaks communicators
    only for the milliseconds before the process dies — a bounded
    one-off cost, not the unbounded accumulation that Round-1
    flagged for pool mode.

Removing the wrap drops the run_attention_with_cp.py diff against
origin/main from ~1000 lines to ~120 lines without changing
observable behaviour. Smoke-tested: 4 representative cases pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Set test_essential=True to match shipping default

Round-3 review (greptile, discussion_r3250016711) flagged that the
working tree had test_essential=False — i.e. the full ~328-config
matrix instead of the ~38-config essential subset that the rest of the
CI matrix expects. Flipping back to True so CI doesn't regress baseline
on the known H1-style cascade configs that only appear in the full
matrix.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Retry once on pool-infrastructure failures with stderr-logged flake trace

The pool worker subprocess can die mid-case due to async NCCL aborts or
flaky 4-GPU collective state that doesn't reproduce on a fresh pool.
Without retry, these manifest as one-off CI failures attributable to
infrastructure, not the PR's content.

Add a single-attempt retry around PoolWorker.submit() that fires only
on infrastructure failure modes (pool-worker-died, timeout,
broken-pipe-pre-send). Test-assertion failures from the worker
(resp["error"]) carry full per-rank tracebacks and propagate without
retry — so a real bug still surfaces as FAILED.

Visibility: every retry attempt writes a [POOL-RETRY] line to stderr.
pytest captures per-test stderr and writes it into JUnit
<testcase>/<system-err>. A flaky test will appear as PASSED in the
case row but with a [POOL-RETRY] line in <system-err> — visible to
the reviewer, and queryable by CI dashboards looking for flake
patterns (e.g. "same test_id retries across multiple CI runs").

If both attempts die, a [POOL-RETRY-FAIL] line is also logged with
the first error's headline, then the second attempt's full traceback
propagates as the test failure.

Smoke-tested: 3 representative cases (p2p, a2a flash; p2p fused)
still PASS in 19 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Pool: redirect non-rank-0 stdout to /dev/null; drop sentinel

Replaces the [CP_POOL_RESP] sentinel-prefix protocol with a stronger
fix at the source: on rank>0, close stdout at the fd level via dup2
to /dev/null at worker startup. Catches both Python `print` writes
and C-level (NCCL, libc, etc.) writes that the sentinel could only
mitigate by scanning + skipping non-protocol lines.

With non-rank-0 stdout silenced, rank 0's JSON line is the only
thing that reaches the parent's pipe, so PoolWorker._submit_once
collapses from a sentinel-scanning while loop to a single
select + readline + json.loads.

Closes follow-up M2 from the PR description; addresses greptile's
review comment on stdout pollution. Validated on 8xH100 with the
test_essential=True flash-attn pool path (9 passed / 55 skipped /
0 failed in 56s; no JSONDecodeError, no protocol corruption).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R3): backend-cache, pool isolation, group-kill, decode-safety

- Invalidate DotProductAttention._attention_backends between pool cases so
  per-case NVTE_FLASH_ATTN/NVTE_FUSED_ATTN toggles take effect instead of
  reusing the previous case's resolved backend.
- torch.cuda.empty_cache() after each case so a 2-GPU pool doesn't squat on
  GPUs that an overlapping 4-GPU pool needs.
- PoolWorker subprocess uses start_new_session=True; _kill() uses killpg on
  the whole process group so torchrun's rank workers don't survive as
  orphans holding CUDA/NCCL state.
- On a failed worker response, kill the pool before raising so half-aborted
  CUDA/NCCL/FP8 state from a failed case doesn't leak into the next.
- Guard json.loads with a try/except + diagnostic so any rank-0 stdout
  pollution surfaces as a clear test failure rather than a silent protocol
  desync.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
update cudnn-fe 1.24

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@pull pull Bot locked and limited conversation to collaborators May 22, 2026
@pull pull Bot added the ⤵️ pull label May 22, 2026
@pull pull Bot merged commit 856d075 into phu0ngng:main May 22, 2026
10 of 11 checks passed
@pull pull Bot had a problem deploying to github-pages May 22, 2026 04:33 Failure
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants