tests: fix SM100 varlen backward failures on B200#2534
Draft
Johnsonms wants to merge 5 commits into
Draft
Conversation
hdim=192 on SM100 requires 2CTA instructions, but softcap injects a score_mod that disables 2CTA, triggering the assertion in FlashAttentionBackwardSm100.__init__. The non-varlen test already gates its backward on softcap==0.0; add the equivalent skip to the varlen backward block.
torch.AcceleratorError is the async variant of OOM — the allocation fails in a prior CUDA op and the error surfaces on the next API call. The existing retry_on_oom only caught torch.OutOfMemoryError, so async OOMs caused by concurrent kernel compilation across 64 xdist workers were not retried.
SM100 varlen kernel hangs when deterministic=True and softcap > 0.0. Skip until the kernel-side bug is fixed.
27d692d to
fc577ea
Compare
…and local+softcap
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Skip SM100 hd192 bwd + softcap:
d=192on SM100 requires 2CTA instructions, butsoftcap > 0.0injects ascore_modthat forcesuse_2cta_instrs=False, hitting the assertion inFlashAttentionBackwardSm100.__init__. Addedpytest.skipin the varlen backward block, matching the existing pattern ford=256. Fixes 49 CI failures.Retry on
AcceleratorErrorOOM:retry_on_oomonly caughttorch.OutOfMemoryError. Async CUDA OOM raisestorch.AcceleratorErrorinstead (allocation fails in a prior op, surfaces on next API call). Extended the catch to include both, still guarded by the"out of memory"message check.Repro
AssertionError: Must use 2CTA for hdim 192 flash_attn/cute/flash_bwd_sm100.py:93
Triggered by any
test_flash_attn_varlen_outputcase withd=192, softcap=15.0on SM100 (B200). Root cause:FlashAttentionBackwardSm100setsuse_2cta_instrs = use_2cta_instrs and ... and score_mod is None, so softcap's score_mod silently disables 2CTA, then the assertion fires. The non-varlen test (test_flash_attn_output) was already guarded byand softcap == 0.0in its backward condition; the varlen test was missing the equivalent guard.For the OOM:
torch.AcceleratorError: CUDA error: out of memorysurfaces at an innocent call (lengths[i] = 0) because the actual allocation failure happened asynchronously in a prior CUDA op during concurrent kernel compilation across 64 xdist workers.Test plan
Ran on B200 (SM100) locally:
pytest tests/cute/test_flash_attn.py -k "test_flash_attn_varlen_output and 192 and 15.0"Result: 48384 skipped, 0 failed (1:27:33) — all previously failing cases now skip correctly via the new guard.
Full suite result with both fixes applied:
168605 passed, 249112 skipped, 0 failed (0:32:59)