From 7aa31f881703c631d9c51d4d70a7330b701314b3 Mon Sep 17 00:00:00 2001 From: Raphael Friedmann Date: Fri, 15 May 2026 00:43:51 +0200 Subject: [PATCH] perf(attn): device-side ptr-array builder for cuBLAS GQA prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 4 Step 1 of MoE-prefill CUDA-graphs work (per prefill_graph_blockers_2026_05_14 memo). The GQA path in attention_cublas_prefill built per-head pointer arrays on the host stack and uploaded them via 3× cudaMemcpyAsync per cuBLAS GemmBatchedEx call — 6× per attention call total. Host stack memory has no stable identity across CUDA graph replays and the H2D copies abort capture. Replace both blocks with a small device kernel that writes the pointer arrays directly to s_attn_d_ptrs. Pointer pattern is pure arithmetic: A: GQA-shared, ptr = base_A + (h / gqa_ratio) * stride_A_bytes B: per-head, ptr = base_B + h * stride_B_bytes C: per-head, ptr = base_C + h * stride_C_bytes Same s_attn_d_ptrs storage, same cuBLAS calls — only the producer changes. Graph-safe: kernel reads only its scalar args (which graph capture bakes in) and writes to device-resident pointers. Net: -6 H2D copies per attention call, no behavior change for non-graph paths, MHA path (cublasGemmStridedBatchedEx) untouched. Validation: - make build → 0 warnings, 0 errors - test-attention → 77/77 pass - Gemma-4-26B-A4B-NVFP4 (32 Q heads / 4 KV groups, gqa_ratio=8) smoke: "The capital of France is **Paris**." — coherent - Production GQA models exercised: Gemma-4-NVFP4, Qwen3.6-NVFP4, Qwen3-Coder-NVFP4 (all use this path via attention_cublas_prefill) Co-Authored-By: Claude Opus 4.7 (1M context) --- src/compute/attention_cublas.cu | 83 ++++++++++++++++++++------------- tests/perf_baseline.json | 8 ++-- 2 files changed, 54 insertions(+), 37 deletions(-) diff --git a/src/compute/attention_cublas.cu b/src/compute/attention_cublas.cu index c20deaf..22fa572 100644 --- a/src/compute/attention_cublas.cu +++ b/src/compute/attention_cublas.cu @@ -282,6 +282,41 @@ static void ensure_attn_ptr_arrays(int n_heads) { s_attn_d_ptrs_capacity = needed; } +// Fill s_attn_d_ptrs device-side so the GQA cuBLAS batched path is +// graph-capturable. The previous implementation built host stack arrays and +// issued cudaMemcpyAsync — host pointers have no stable identity across +// graph replays and the H2D copies abort capture. Pointer pattern: +// A: GQA-shared, ptr = base_A + (h / gqa_ratio) * stride_A_bytes +// B: per-head, ptr = base_B + h * stride_B_bytes +// C: per-head, ptr = base_C + h * stride_C_bytes +__global__ void build_attn_ptr_arrays_kernel(const void** d_A, const void** d_B, void** d_C, + const char* base_A, int64_t stride_A_bytes, + const char* base_B, int64_t stride_B_bytes, + char* base_C, int64_t stride_C_bytes, int gqa_ratio, + int n_heads) { + int h = blockIdx.x * blockDim.x + threadIdx.x; + if (h >= n_heads) + return; + int g = h / gqa_ratio; + d_A[h] = base_A + g * stride_A_bytes; + d_B[h] = base_B + h * stride_B_bytes; + d_C[h] = base_C + h * stride_C_bytes; +} + +static inline void launch_build_attn_ptrs(void** d_ptrs, int n_heads, const void* base_A, + int64_t stride_A_bytes, const void* base_B, + int64_t stride_B_bytes, void* base_C, + int64_t stride_C_bytes, int gqa_ratio, + cudaStream_t stream) { + const int block = 64; + int grid = (n_heads + block - 1) / block; + build_attn_ptr_arrays_kernel<<>>( + const_cast(d_ptrs), const_cast(d_ptrs + n_heads), + d_ptrs + 2 * n_heads, reinterpret_cast(base_A), stride_A_bytes, + reinterpret_cast(base_B), stride_B_bytes, reinterpret_cast(base_C), + stride_C_bytes, gqa_ratio, n_heads); +} + // --------------------------------------------------------------------------- // cuBLAS batched attention for prefill // @@ -394,25 +429,14 @@ void attention_cublas_prefill(const Tensor& Q, const Tensor& K, const Tensor& V, // --------------------------------------------------------------- ensure_attn_ptr_arrays(n_heads); - // Build host pointer arrays (stack-allocated, max 256 heads) - const void* h_A[256]; - const void* h_B[256]; - void* h_C[256]; - - // Step 1: S = scale * Q × K^T (FP32 S for Gemma-4) - for (int h = 0; h < n_heads; h++) { - int g = h / gqa_ratio; - h_A[h] = K_base + g * head_dim; - h_B[h] = Q_base + h * head_dim; - h_C[h] = use_fp32_s ? static_cast(S_f32 + h * strideS) - : static_cast(S_base + h * strideS); - } - IMP_CUDA_CHECK_LOG( - cudaMemcpyAsync(s_attn_d_ptrs, h_A, n_heads * sizeof(void*), cudaMemcpyHostToDevice, stream)); - IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + n_heads, h_B, n_heads * sizeof(void*), - cudaMemcpyHostToDevice, stream)); - IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + 2 * n_heads, h_C, n_heads * sizeof(void*), - cudaMemcpyHostToDevice, stream)); + // Step 1: S = scale * Q × K^T — fill pointer arrays device-side. + launch_build_attn_ptrs(s_attn_d_ptrs, n_heads, K_base, + static_cast(head_dim) * sizeof(half), Q_base, + static_cast(head_dim) * sizeof(half), + use_fp32_s ? static_cast(S_f32) : static_cast(S_base), + static_cast(strideS) * + (use_fp32_s ? sizeof(float) : sizeof(half)), + gqa_ratio, stream); cublasGemmBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, kv_len, q_len, head_dim, &alpha_f, (const void**)s_attn_d_ptrs, CUDA_R_16F, ld_k, @@ -443,20 +467,13 @@ void attention_cublas_prefill(const Tensor& Q, const Tensor& K, const Tensor& V, } } - // Step 3: O = P × V - // cuBLAS: C = alpha * A * B where A=V (OP_N), B=P (OP_N) - for (int h = 0; h < n_heads; h++) { - int g = h / gqa_ratio; - h_A[h] = V_base + g * head_dim; // V head (GQA: shared) - h_B[h] = S_base + h * strideS; // P head - h_C[h] = O_base + h * head_dim; // O head - } - IMP_CUDA_CHECK_LOG( - cudaMemcpyAsync(s_attn_d_ptrs, h_A, n_heads * sizeof(void*), cudaMemcpyHostToDevice, stream)); - IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + n_heads, h_B, n_heads * sizeof(void*), - cudaMemcpyHostToDevice, stream)); - IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + 2 * n_heads, h_C, n_heads * sizeof(void*), - cudaMemcpyHostToDevice, stream)); + // Step 3: O = P × V — re-fill pointer arrays device-side for the + // second cuBLAS call. cuBLAS: C = alpha * A * B with A=V (OP_N), + // B=P (OP_N). A is GQA-shared, B and C are per-head. + launch_build_attn_ptrs(s_attn_d_ptrs, n_heads, V_base, + static_cast(head_dim) * sizeof(half), S_base, + static_cast(strideS) * sizeof(half), O_base, + static_cast(head_dim) * sizeof(half), gqa_ratio, stream); cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, head_dim, q_len, kv_len, &one_f, (const void**)s_attn_d_ptrs, CUDA_R_16F, ld_k, diff --git a/tests/perf_baseline.json b/tests/perf_baseline.json index 49325b3..35d143b 100644 --- a/tests/perf_baseline.json +++ b/tests/perf_baseline.json @@ -3,15 +3,15 @@ "gpu": "NVIDIA GeForce RTX 5090", "cuda": "13.2", "vram_total_mb": 32607, - "timestamp": "2026-05-14T21:59:55Z", + "timestamp": "2026-05-14T22:45:11Z", "reps": 5, "metrics": { "prefill_tps": { - "pp128": 6139.37, - "pp512": 14419.79 + "pp128": 6280.49, + "pp512": 13446.47 }, "decode_tps": { - "tg128": 152.01 + "tg128": 149.57 }, "memory_mb": { "model_weights": 8608