From aafc46a596b9a18fa623ec065f65d12ee6adf5a1 Mon Sep 17 00:00:00 2001 From: Raphael Friedmann Date: Fri, 15 May 2026 01:08:05 +0200 Subject: [PATCH] perf(prefill): env-gated graph wrapper for forward_logits (Phase 4 Step 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires CudaGraphRunner around the non-last-chunk forward_logits in the chunked-prefill path. Captures (a) the device-args MoE prefill path (default-on since PR #164) and (b) the cuBLAS GQA attention path (now device-ptr-array based since PR #177). Opt-in via IMP_PREFILL_GRAPH=1; default behavior unchanged. Also pre-creates the attention_cublas static cuBLAS handle at engine init via a new attention_cublas_prewarm() entry point. Without this, the first attention_cublas_prefill call inside a captured stream would trigger cublasCreate, whose internal cudaMalloc for workspace is illegal under capture (CUBLAS_STATUS_NOT_INITIALIZED → abort()). gemm_init() already follows the same pattern for the dense GEMM handle. ## Capture status (empirical, Qwen3-Coder-30B-NVFP4, pp=1024 reps=3) - Build: 0 warnings, 0 errors - cuBLAS handle init: clean (no more cublasCreate-under-capture abort) - Warmup forward_logits: runs eager, primes caches and handles - Capture step: graph captured successfully - **Replay: IMA (illegal memory access)** — exactly the failure mode documented in `prefill_graph_blockers_2026_05_14` memo for Blocker B ("captured graph references memory whose addresses differ across replays"). Confirms the residual structural blockers post-PR-#177 are state-lifecycle issues, not API-discovery issues. ## What ships - Scaffolding (env-gated, default off): production behavior unchanged - Foundation for incremental Blocker-B fixes (each replay-IMA source can be isolated and fixed under IMP_PREFILL_GRAPH=1 without affecting default decode/prefill) ## What remains Per memo step 4 (audit 95 H2D/sync sites), step 5 (per-shape graph pool), step 6 (4-model validation). 
The IMA root cause is the next debugging target — likely chunked-prefill's per-call cudaMallocAsync for `k_full`/`v_full` at executor_attention.cu:762-763 when `q_offset > 0`. Captured graph might also be re-reading from a freed pf_pool slot, or the KV cache block_table content has shifted. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/compute/attention_cublas.cu | 7 +++++++ src/compute/attention_cublas.h | 6 ++++++ src/runtime/engine.cpp | 29 ++++++++++++++++++++++++++++- src/runtime/engine.h | 8 ++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/compute/attention_cublas.cu b/src/compute/attention_cublas.cu index 22fa572b..d37081c7 100644 --- a/src/compute/attention_cublas.cu +++ b/src/compute/attention_cublas.cu @@ -34,6 +34,13 @@ static cublasHandle_t get_attn_cublas_handle() { return handle; } +void attention_cublas_prewarm() { + // Force lazy-init of the static cuBLAS handle. The first + // cublasGemmBatchedEx call after this is guaranteed not to trigger + // cublasCreate's internal cudaMalloc (illegal under stream capture). + (void)get_attn_cublas_handle(); +} + // --------------------------------------------------------------------------- // Fused causal softmax FP32 → FP16: reads FP32 S matrix, writes FP16 probs // to a separate output buffer. Replaces causal_softmax_fp32_inplace_kernel diff --git a/src/compute/attention_cublas.h b/src/compute/attention_cublas.h index 938c24e5..7c5bf12b 100644 --- a/src/compute/attention_cublas.h +++ b/src/compute/attention_cublas.h @@ -20,4 +20,10 @@ void attention_cublas_prefill(const Tensor& Q, const Tensor& K, const Tensor& V, int n_heads, int n_kv_heads, int head_dim, float scale, bool causal, float softcap = 0.0f, int q_offset = 0, cudaStream_t stream = nullptr); +// Force-create the static cuBLAS handle. Safe to call multiple times. 
+// Engine init calls this so the first attention_cublas_prefill invocation +// inside a captured stream can reuse the handle without cublasCreate +// (which does internal cudaMalloc for workspace and is illegal under capture). +void attention_cublas_prewarm(); + } // namespace imp diff --git a/src/runtime/engine.cpp b/src/runtime/engine.cpp index 53da4794..42fc2841 100644 --- a/src/runtime/engine.cpp +++ b/src/runtime/engine.cpp @@ -7,6 +7,7 @@ #include "model/gguf_loader.h" #include "model/chat_template.h" #include "compute/gemm.h" +#include "compute/attention_cublas.h" #include "compute/gemm_grouped.h" #include "compute/sampling.h" #include "compute/attention.h" @@ -963,6 +964,7 @@ bool Engine::init(std::shared_ptr<Model> model, const EngineConfig& config) { return false; } gemm_init(); + attention_cublas_prewarm(); scheduler_ = std::make_unique<Scheduler>(config_.max_batch_size); (void)stream_.create(cudaStreamNonBlocking); @@ -2233,7 +2235,32 @@ void Engine::step_prefill_one(std::shared_ptr<Request>& req, int effective_chunk executor_->use_workspace(0); } Tensor logits_out; - executor_->forward_logits(state, logits_out, pf_stream); + + // Prefill graph capture (opt-in, Phase 4 of MoE-prefill-graphs work). + // Conditions: env-gated, pool path (stable device buffers), and + // chunk shape stable (in practice all non-last chunks share chunk_len + // = prefill_chunk_size). H2D upload happened above on pf_stream + // *before* this wrapper — captured region is forward_logits only, + // analogous to the decode graph pattern. 
+ static const bool prefill_graph_enabled = (std::getenv("IMP_PREFILL_GRAPH") != nullptr); + const bool can_capture = prefill_graph_enabled && pf_pool_used && config_.use_cuda_graphs; + if (can_capture) { + const int block_count = static_cast<int>(block_table.size()); + if (chunk_len != last_prefill_chunk_len_ || block_count != last_prefill_block_count_) { + prefill_graph_runner_.invalidate_for_update(); + last_prefill_chunk_len_ = chunk_len; + last_prefill_block_count_ = block_count; + } + prefill_graph_runner_.set_decode_fn([this, &state, &logits_out](cudaStream_t s) { + executor_->forward_logits(state, logits_out, s); + }); + prefill_graph_runner_.execute(pf_stream); + if (logits_out.data == nullptr) { + logits_out = executor_->get_logits_view(/*n_sequences=*/1); + } + } else { + executor_->forward_logits(state, logits_out, pf_stream); + } if (!pf_pool_used) { free_prefill_buffers(d_token_ids, d_positions, d_block_tables, d_context_lens, pf_stream); diff --git a/src/runtime/engine.h b/src/runtime/engine.h index cba15382..971e594a 100644 --- a/src/runtime/engine.h +++ b/src/runtime/engine.h @@ -223,6 +223,14 @@ class Engine { static constexpr int kMaxGraphPoolSize = 32; CudaGraphRunner decode_graph_pool_[kMaxGraphPoolSize]; // index = n_sequences - 1 int last_decode_max_blocks_per_graph_[kMaxGraphPoolSize] = {}; + + // Prefill graph runner — captures forward_logits for non-last chunks of + // chunked prefill. Single runner: in practice chunk_len == prefill_chunk_size + // for all non-last chunks, so per-shape variability collapses to one shape. + // Opt-in via IMP_PREFILL_GRAPH=1 (Phase 4 of MoE-prefill-graphs work). + CudaGraphRunner prefill_graph_runner_; + int last_prefill_chunk_len_ = -1; + int last_prefill_block_count_ = -1; int32_t* h_sample_pinned_ = nullptr; // Async conditional graph loop CudaGraphConditionalRunner async_graph_runner_;