Skip to content

[Cute,Sm100] allow for zero length sequences in hdim 256 kernels#2568

Merged
Johnsonms merged 1 commit into
mainfrom
jshah/hdim256-varlen-zero-lengths
May 16, 2026
Merged

[Cute,Sm100] allow for zero length sequences in hdim 256 kernels#2568
Johnsonms merged 1 commit into
mainfrom
jshah/hdim256-varlen-zero-lengths

Conversation

@jayhshah
Copy link
Copy Markdown
Collaborator

@jayhshah jayhshah commented May 15, 2026

We add support for zero-length Q and KV sequences for varlen mode in the sm100 hdim 256 kernels. Changes are as follows:

Forward: have softmax execute one dummy iteration to not hang for zero-length K. We also add row sum check for zero or NaN to not output NaN in this case.

Backward dQ: for guarding work tile per-iteration, check also that trip count is non-zero. If it is zero, write zero as the output.

Backward dKV: zero-length Q and K was nominally supported but the write zero logic was broken and yielded IMA; now fixed in the PR.

@Johnsonms
Copy link
Copy Markdown
Collaborator

PR fixes the bug.

Before patch (b11 baseline):

  • forward OK → out.backward(g) → cudaErrorIllegalAddress in _flash_attn_bwd (sm100 hd=256 kernel)

After patch (jshah/hdim256-varlen-zero-lengths, commit 75db52f):

  • forward OK
  • backward OK
  • grads finite, non-zero, and structurally correct — dk[0:2538] (the K rows paired with the zero-length Q segment) is
    exactly zero as expected, while dk[2538:] carries signal.

@Johnsonms Johnsonms merged commit 8a8b2f1 into main May 16, 2026
2 of 3 checks passed
ussoewwin added a commit to ussoewwin/flash-attention that referenced this pull request May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants