Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 50 additions & 33 deletions src/compute/attention_cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,41 @@ static void ensure_attn_ptr_arrays(int n_heads) {
s_attn_d_ptrs_capacity = needed;
}

// Device-side construction of the batched-GEMM pointer table so the GQA
// cuBLAS path stays CUDA-graph-capturable. (Host-built stack arrays pushed
// with cudaMemcpyAsync are not capturable: host pointers have no stable
// identity across graph replays and the H2D copies abort capture.)
//
// Launch layout: one thread per head; any 1-D grid covering n_heads works.
// Per head h the three operand pointers are:
//   d_A[h] = base_A + (h / gqa_ratio) * stride_A_bytes  // GQA-shared operand
//   d_B[h] = base_B + h * stride_B_bytes                // per-head operand
//   d_C[h] = base_C + h * stride_C_bytes                // per-head output
__global__ void build_attn_ptr_arrays_kernel(const void** d_A, const void** d_B, void** d_C,
                                             const char* base_A, int64_t stride_A_bytes,
                                             const char* base_B, int64_t stride_B_bytes,
                                             char* base_C, int64_t stride_C_bytes, int gqa_ratio,
                                             int n_heads) {
    const int head = blockIdx.x * blockDim.x + threadIdx.x;
    if (head < n_heads) {
        // Several query heads share one KV head under GQA.
        const int kv_group = head / gqa_ratio;
        d_A[head] = base_A + kv_group * stride_A_bytes;
        d_B[head] = base_B + head * stride_B_bytes;
        d_C[head] = base_C + head * stride_C_bytes;
    }
}

// Host wrapper: fills the packed device pointer table
//   d_ptrs = [ A[0..n_heads) | B[0..n_heads) | C[0..n_heads) ]
// on `stream` via build_attn_ptr_arrays_kernel, so a following
// cublasGemmBatchedEx on the same stream sees the fresh pointers.
// A is the GQA-shared operand (indexed by h / gqa_ratio); B and C are
// per-head. Graph-capture-safe: no host memory, no synchronization.
//
// Preconditions: d_ptrs has capacity for 3 * n_heads entries
// (ensure_attn_ptr_arrays); gqa_ratio >= 1.
static inline void launch_build_attn_ptrs(void** d_ptrs, int n_heads, const void* base_A,
                                          int64_t stride_A_bytes, const void* base_B,
                                          int64_t stride_B_bytes, void* base_C,
                                          int64_t stride_C_bytes, int gqa_ratio,
                                          cudaStream_t stream) {
    // Guard the degenerate case: a 0-block grid is an invalid launch
    // configuration whose sticky error would poison later CUDA checks.
    if (n_heads <= 0)
        return;
    const int block = 64;
    const int grid = (n_heads + block - 1) / block;  // ceil-div
    build_attn_ptr_arrays_kernel<<<grid, block, 0, stream>>>(
        const_cast<const void**>(d_ptrs), const_cast<const void**>(d_ptrs + n_heads),
        d_ptrs + 2 * n_heads, reinterpret_cast<const char*>(base_A), stride_A_bytes,
        reinterpret_cast<const char*>(base_B), stride_B_bytes, reinterpret_cast<char*>(base_C),
        stride_C_bytes, gqa_ratio, n_heads);
    // NOTE(review): consider IMP_CUDA_CHECK_LOG(cudaGetLastError()) here to
    // surface launch-config errors at the call site — confirm the macro is
    // safe inside graph capture before adding it.
}

// ---------------------------------------------------------------------------
// cuBLAS batched attention for prefill
//
Expand Down Expand Up @@ -394,25 +429,14 @@ void attention_cublas_prefill(const Tensor& Q, const Tensor& K, const Tensor& V,
// ---------------------------------------------------------------
ensure_attn_ptr_arrays(n_heads);

// Build host pointer arrays (stack-allocated, max 256 heads)
const void* h_A[256];
const void* h_B[256];
void* h_C[256];

// Step 1: S = scale * Q × K^T (FP32 S for Gemma-4)
for (int h = 0; h < n_heads; h++) {
int g = h / gqa_ratio;
h_A[h] = K_base + g * head_dim;
h_B[h] = Q_base + h * head_dim;
h_C[h] = use_fp32_s ? static_cast<void*>(S_f32 + h * strideS)
: static_cast<void*>(S_base + h * strideS);
}
IMP_CUDA_CHECK_LOG(
cudaMemcpyAsync(s_attn_d_ptrs, h_A, n_heads * sizeof(void*), cudaMemcpyHostToDevice, stream));
IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + n_heads, h_B, n_heads * sizeof(void*),
cudaMemcpyHostToDevice, stream));
IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + 2 * n_heads, h_C, n_heads * sizeof(void*),
cudaMemcpyHostToDevice, stream));
// Step 1: S = scale * Q × K^T — fill pointer arrays device-side.
launch_build_attn_ptrs(s_attn_d_ptrs, n_heads, K_base,
static_cast<int64_t>(head_dim) * sizeof(half), Q_base,
static_cast<int64_t>(head_dim) * sizeof(half),
use_fp32_s ? static_cast<void*>(S_f32) : static_cast<void*>(S_base),
static_cast<int64_t>(strideS) *
(use_fp32_s ? sizeof(float) : sizeof(half)),
gqa_ratio, stream);

cublasGemmBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, kv_len, q_len, head_dim, &alpha_f,
(const void**)s_attn_d_ptrs, CUDA_R_16F, ld_k,
Expand Down Expand Up @@ -443,20 +467,13 @@ void attention_cublas_prefill(const Tensor& Q, const Tensor& K, const Tensor& V,
}
}

// Step 3: O = P × V
// cuBLAS: C = alpha * A * B where A=V (OP_N), B=P (OP_N)
for (int h = 0; h < n_heads; h++) {
int g = h / gqa_ratio;
h_A[h] = V_base + g * head_dim; // V head (GQA: shared)
h_B[h] = S_base + h * strideS; // P head
h_C[h] = O_base + h * head_dim; // O head
}
IMP_CUDA_CHECK_LOG(
cudaMemcpyAsync(s_attn_d_ptrs, h_A, n_heads * sizeof(void*), cudaMemcpyHostToDevice, stream));
IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + n_heads, h_B, n_heads * sizeof(void*),
cudaMemcpyHostToDevice, stream));
IMP_CUDA_CHECK_LOG(cudaMemcpyAsync(s_attn_d_ptrs + 2 * n_heads, h_C, n_heads * sizeof(void*),
cudaMemcpyHostToDevice, stream));
// Step 3: O = P × V — re-fill pointer arrays device-side for the
// second cuBLAS call. cuBLAS: C = alpha * A * B with A=V (OP_N),
// B=P (OP_N). A is GQA-shared, B and C are per-head.
launch_build_attn_ptrs(s_attn_d_ptrs, n_heads, V_base,
static_cast<int64_t>(head_dim) * sizeof(half), S_base,
static_cast<int64_t>(strideS) * sizeof(half), O_base,
static_cast<int64_t>(head_dim) * sizeof(half), gqa_ratio, stream);

cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, head_dim, q_len, kv_len, &one_f,
(const void**)s_attn_d_ptrs, CUDA_R_16F, ld_k,
Expand Down
8 changes: 4 additions & 4 deletions tests/perf_baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
"gpu": "NVIDIA GeForce RTX 5090",
"cuda": "13.2",
"vram_total_mb": 32607,
"timestamp": "2026-05-14T21:59:55Z",
"timestamp": "2026-05-14T22:45:11Z",
"reps": 5,
"metrics": {
"prefill_tps": {
"pp128": 6139.37,
"pp512": 14419.79
"pp128": 6280.49,
"pp512": 13446.47
},
"decode_tps": {
"tg128": 152.01
"tg128": 149.57
},
"memory_mb": {
"model_weights": 8608
Expand Down
Loading