perf(prefill): tighten WMMA dispatch guard — graceful fallback for small-N #180
Merged
## Summary
Follow-up to #179. Cross-model validation (Qwen3-Coder-NVFP4, Qwen3.6-NVFP4, Qwen3-30B-Modelopt, Gemma-4-NVFP4) on a warm container revealed the v1 WMMA kernel is significantly slower than cuBLASLt on small-N shapes. This PR tightens the dispatch guard so those shapes fall back to eager execution (no regression) instead of capturing a slow kernel.
## Cross-model A/B (same-session warm container, pp=1024, reps=3)

Measured post-warmup with the permissive guard from #179:

| Model | Baseline | `IMP_PREFILL_GRAPH=1` | Δ |
|---|---:|---:|---:|
| Qwen3-Coder-30B-NVFP4 | 15336 | 14593 | -4.8% |
| Qwen3.6-35B-NVFP4 | 10219 | 8614 | -15.7% |
| Gemma-4-26B-NVFP4 | 26524 | 31277 | +17.9% |

The earlier +27% Qwen3-Coder measurement at #179 came from a cold-container session where cuBLASLt was at its slowest; the warm A/B above shows the WMMA kernel ~5% behind cuBLASLt on Qwen3-Coder shapes, not ahead. Qwen3.6 regresses 15.7% because its Q/K/V projection at M=512, N=32, K=2048 launches only ⌈32/BN⌉ × ⌈512/BM⌉ = 1 × 4 = 4 blocks across 128 SMs (3% SM saturation) and wastes 75% of MMA cycles on zero-padded B fragments.
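To make the arithmetic concrete, here is a small host-side illustration. BM = BN = 128 is an assumption inferred from ⌈512/BM⌉ = 4 and the 75% B-fragment waste at N = 32 ((128 − 32)/128); the actual tile constants live in `gemm_capture_fp16_sm120.cu`.

```cpp
#include <cstdio>

// Assumed tile geometry -- inferred from the numbers above, not read
// from the kernel source.
constexpr int BM = 128;       // block tile rows
constexpr int BN = 128;       // block tile cols
constexpr int NUM_SMS = 128;  // SM count on the test GPU

int main() {
    // Qwen3.6 Q/K/V projection shape from the table above.
    const int M = 512, N = 32;
    const int blocks = ((N + BN - 1) / BN) * ((M + BM - 1) / BM);  // 1 * 4 = 4
    const double saturation = 100.0 * blocks / NUM_SMS;            // ~3%
    const double b_waste = 100.0 * (BN - N) / BN;                  // 75%
    printf("blocks=%d, SM saturation=%.1f%%, zero-padded B fraction=%.1f%%\n",
           blocks, saturation, b_waste);
}
```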
## Fix
**Tight dispatch guard** (`N >= BN && M >= BM`): rejects shapes where the WMMA kernel is uncompetitive. The caller (`gemm.cu`) falls through to cuBLASLt for declined shapes; under stream capture cuBLASLt fails with status 14, so the wrapper aborts the capture and falls back to eager — same as baseline, no regression. Models whose GEMMs all have N ≥ BN (Qwen3-Coder, Modelopt) are unaffected by the guard change. Gemma-4 (SWA, no chunked prefill, so the wrapper doesn't fire) still gets the +17.9% from #179's engine-init prewarms — that path is independent of the WMMA kernel. The sketch below shows the shape of the dispatch chain.
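A minimal sketch of that chain, with hypothetical names (`try_wmma_gemm`, `run_gemm_under_capture`, `gemm_cublaslt`) — only the guard condition reflects the actual one-line change:

```cpp
#include <cublasLt.h>      // cublasStatus_t; 14 == CUBLAS_STATUS_INTERNAL_ERROR
#include <cuda_runtime.h>

constexpr int BM = 128, BN = 128;  // assumed tile geometry

// Declines shapes that would underfill the tile grid so the caller can
// fall through to cuBLASLt instead of capturing an uncompetitive kernel.
bool try_wmma_gemm(int M, int N, int K, cudaStream_t stream) {
    if (!(N >= BN && M >= BM))
        return false;               // <- the tightened guard
    // launch_wmma_fp16<<<grid, block, 0, stream>>>(...);  // unchanged path
    (void)K; (void)stream;
    return true;
}

// Stand-in for the real cublasLtMatmul(...) call in gemm.cu.
cublasStatus_t gemm_cublaslt(int M, int N, int K, cudaStream_t stream);

void run_gemm_under_capture(int M, int N, int K, cudaStream_t stream) {
    if (try_wmma_gemm(M, N, K, stream))
        return;                               // capturable WMMA path
    if (gemm_cublaslt(M, N, K, stream) != CUBLAS_STATUS_SUCCESS) {
        // cuBLASLt returns status 14 while a stream capture is active:
        // abort the capture and let the engine re-run the prefill
        // eagerly, identical to baseline, so no regression.
        cudaGraph_t graph = nullptr;
        cudaStreamEndCapture(stream, &graph);
        if (graph) cudaGraphDestroy(graph);
    }
}
```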
## Default flip status
Deferred. What ships here is a single-line guard tightening in `gemm_capture_fp16_sm120.cu`; `IMP_PREFILL_GRAPH=1` remains opt-in and default behavior is unchanged. A real default-on requires the WMMA kernel to achieve cuBLASLt-warm-state parity across all production NVFP4 MoE shapes — multi-day kernel work (cp.async pipelining, a BN=32 small-N specialization, possibly larger tile geometry). A rough sketch of the small-N direction follows.
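Purely illustrative and not part of this PR: if the kernel were templated on tile geometry, a small-N variant could be picked at dispatch so an N=32 shape stops zero-padding 96 of 128 B-fragment columns. All names below are invented; the real specialization is the deferred multi-day work.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Hypothetical kernel templated on tile geometry (not the shipped kernel);
// the WMMA fragment loads and mma_sync loop over K tiles are elided.
template <int BM_, int BN_>
__global__ void wmma_gemm_fp16(const __half* A, const __half* B, __half* C,
                               int M, int N, int K) { /* ... */ }

// Invented dispatch: a BN=32 tile for small-N shapes fills its B
// fragments instead of padding them with zeros.
void dispatch_wmma(const __half* A, const __half* B, __half* C,
                   int M, int N, int K, cudaStream_t stream) {
    const dim3 block(256);
    if (N < 128) {
        const dim3 grid((N + 31) / 32, (M + 127) / 128);
        wmma_gemm_fp16<128, 32><<<grid, block, 0, stream>>>(A, B, C, M, N, K);
    } else {
        const dim3 grid((N + 127) / 128, (M + 127) / 128);
        wmma_gemm_fp16<128, 128><<<grid, block, 0, stream>>>(A, B, C, M, N, K);
    }
}
```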
🤖 Generated with Claude Code

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>