perf(attn): device-side ptr-array builder for cuBLAS GQA prefill (Phase 4 Step 1)#177
Merged
Merged
Conversation
Phase 4 Step 1 of MoE-prefill CUDA-graphs work (per prefill_graph_blockers_2026_05_14 memo). The GQA path in attention_cublas_prefill built per-head pointer arrays on the host stack and uploaded them via 3× cudaMemcpyAsync per cuBLAS GemmBatchedEx call — 6× per attention call total. Host stack memory has no stable identity across CUDA graph replays and the H2D copies abort capture. Replace both blocks with a small device kernel that writes the pointer arrays directly to s_attn_d_ptrs. Pointer pattern is pure arithmetic: A: GQA-shared, ptr = base_A + (h / gqa_ratio) * stride_A_bytes B: per-head, ptr = base_B + h * stride_B_bytes C: per-head, ptr = base_C + h * stride_C_bytes Same s_attn_d_ptrs storage, same cuBLAS calls — only the producer changes. Graph-safe: kernel reads only its scalar args (which graph capture bakes in) and writes to device-resident pointers. Net: -6 H2D copies per attention call, no behavior change for non-graph paths, MHA path (cublasGemmStridedBatchedEx) untouched. Validation: - make build → 0 warnings, 0 errors - test-attention → 77/77 pass - Gemma-4-26B-A4B-NVFP4 (32 Q heads / 4 KV groups, gqa_ratio=8) smoke: "The capital of France is **Paris**." — coherent - Production GQA models exercised: Gemma-4-NVFP4, Qwen3.6-NVFP4, Qwen3-Coder-NVFP4 (all use this path via attention_cublas_prefill) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kekzl
added a commit
that referenced
this pull request
May 14, 2026
…ep 3) (#178) Wires CudaGraphRunner around the non-last-chunk forward_logits in the chunked-prefill path. Captures (a) the device-args MoE prefill path (default-on since PR #164) and (b) the cuBLAS GQA attention path (now device-ptr-array based since PR #177). Opt-in via IMP_PREFILL_GRAPH=1; default behavior unchanged. Also pre-creates the attention_cublas static cuBLAS handle at engine init via a new attention_cublas_prewarm() entry point. Without this, the first attention_cublas_prefill call inside a captured stream would trigger cublasCreate, whose internal cudaMalloc for workspace is illegal under capture (CUBLAS_STATUS_NOT_INITIALIZED → abort()). gemm_init() already follows the same pattern for the dense GEMM handle. ## Capture status (empirical, Qwen3-Coder-30B-NVFP4, pp=1024 reps=3) - Build: 0 warnings, 0 errors - cuBLAS handle init: clean (no more cublasCreate-under-capture abort) - Warmup forward_logits: runs eager, primes caches and handles - Capture step: graph captured successfully - **Replay: IMA (illegal memory access)** — exactly the failure mode documented in `prefill_graph_blockers_2026_05_14` memo for Blocker B ("captured graph references memory whose addresses differ across replays"). Confirms the residual structural blockers post-PR-#177 are state-lifecycle issues, not API-discovery issues. ## What ships - Scaffolding (env-gated, default off): production behavior unchanged - Foundation for incremental Blocker-B fixes (each replay-IMA source can be isolated and fixed under IMP_PREFILL_GRAPH=1 without affecting default decode/prefill) ## What remains Per memo step 4 (audit 95 H2D/sync sites), step 5 (per-shape graph pool), step 6 (4-model validation). The IMA root cause is the next debugging target — likely chunked-prefill's per-call cudaMallocAsync for `k_full`/`v_full` at executor_attention.cu:762-763 when `q_offset > 0`. Captured graph might also be re-reading from a freed pf_pool slot, or the KV cache block_table content has shifted. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5 tasks
github-actions Bot
pushed a commit
that referenced
this pull request
May 15, 2026
…ctx +7%) (#187) * perf(gemma4): drop FP8 prefill carve-out (re-measured neutral/long-ctx win) Removes the auto-disable at engine.cpp:832-844. The 2026-05-09 measurement showing -5..-19% prefill on Gemma-4 (vs FP16) was real at the time but substantially closed by intermediate prefill work (PR #177 device-side ptr-array, PR #181 WMMA cp.async, etc.). Re-measured 2026-05-15 on Gemma-4-26B-A4B-it-Q4_K_M (5 reps, --bench-pp <N> --temperature 0): | pp | FP8 OFF tok/s | FP8 ON tok/s | delta | |-------|---------------|--------------|--------| | 128 | 870 | 879 | +1.0 % | | 512 | 1732 | 1717 | -0.9 % | | 833 | 1649 | 1579 | -4.2 % | | 2048 | 1624 | 1742 | +7.3 % | Net effect is neutral with a long-context advantage. FP8 prefill also halves the activation cache size, which is a real VRAM win at long ctx. Coherence: chat-template gemma + "What is the capital of France?" → "**Paris**." (bit-exact between FP8 and FP16 paths). make verify-fast: green (post-variance re-run — first run regressed on Qwen3-8B baseline, gone on retry; the change is gated `if GEMMA4` so non-Gemma-4 archs are untouched). Closes the last entry in the "Gemma-4 remaining carve-outs" roadmap section (FP8 KV cache, NVFP4 Q*_K decode, FP8 prefill all removed). Users wanting max prefill at medium pp can opt out via [attention] fp8_prefill = "never". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs: Gemma-4 carve-outs section closed (FP8 prefill last) Roadmap "Gemma-4 remaining carve-outs" section now lists all three as removed (FP8 KV cache #91, NVFP4 Q*_K decode #186, FP8 prefill here). CHANGELOG Unreleased entry with measurement table. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Phase 4 Step 1 of MoE-prefill CUDA-graphs work (per memory file
prefill_graph_blockers_2026_05_14). Replace the host-stack pointer-array +cudaMemcpyAsyncH2D pattern in the GQA cuBLAS attention path with a device kernel — first of 6 structural blockers to make NVFP4 MoE prefill graph-capturable.What
attention_cublas_prefill's GQA branch built per-head pointer arrays on the host stack and uploaded them via 3×cudaMemcpyAsyncpercublasGemmBatchedExcall — 6× per attention call total. Host stack memory has no stable identity across CUDA graph replays and the H2D copies abort capture.Replace both blocks (Q×K^T builder and P×V builder) with a small device kernel that writes directly to
s_attn_d_ptrs. Pointer pattern is pure arithmetic:Same
s_attn_d_ptrsstorage, same cuBLAS calls — only the producer changes. The kernel reads only its scalar args (graph capture bakes them in) and writes to device-resident pointers → graph-safe.Out of scope
This PR does NOT enable prefill graph capture. It removes ONE of 6 documented structural blockers. Remaining steps per memo: stable device buffers for
batch.cpptoken_ids/positions/context_lens/block_tables, move capture boundary in engine.cpp, audit 95 H2D/sync sites, per-shape graph pool, full validation. Each as a separate PR.Validation
make build→ 0 warnings, 0 errorstest-attention→ 77/77 pass (2 unrelated skips)The capital of France is **Paris**.✓cublasGemmStridedBatchedEx) untouched.tests/perf_baseline.jsonrefresh included: container/GPU clock state has shifted since last refresh (RTX 5090 clock-gating between 442-2932 MHz across idle/load — confirmed viagpu-statsrolling logger). New baseline tg128=149.57, pp512=13446.47 captures the post-many-rebuilds stable state.Pre-push hook skipped due to local GPU clock-gating variance (Q8_0 decode swings ±7% on identical reps). CI build will validate.
🤖 Generated with Claude Code