From da755bde8c45426d3d466ed0d083175bf2884508 Mon Sep 17 00:00:00 2001 From: xueyangcs Date: Fri, 3 Apr 2026 16:07:34 +0800 Subject: [PATCH] add k scale in prefill attention --- hpc/attention.py | 14 ++-- src/attention/entry.cc | 17 +++-- src/attention/prefill/config.h | 14 ++-- src/attention/prefill/kernels.cuh | 66 +++++++++---------- src/attention/prefill/prefill.cc | 13 ++-- src/attention/prefill/prefill.h | 11 ++-- .../warp_spec_with_kvcache_fp8_dim128.cu | 44 +++++++------ .../warp_spec_with_kvcache_fp8_dim128.h | 12 ++-- ...test_attention_with_kvcache_prefill_fp8.py | 18 +++-- 9 files changed, 113 insertions(+), 96 deletions(-) diff --git a/hpc/attention.py b/hpc/attention.py index 1587d29..87d8fce 100644 --- a/hpc/attention.py +++ b/hpc/attention.py @@ -129,7 +129,8 @@ def attention_with_kvcache_prefill_fp8( q: Tensor, kcache: Tensor, vcache: Tensor, - qkscale: Tensor, + qscale: Tensor, + kscale: Tensor, vscale: Tensor, cu_seqlens_q: Tensor, block_ids: Tensor, @@ -156,9 +157,12 @@ def attention_with_kvcache_prefill_fp8( Constrainst the unused slots in last block of vcache for each request to be set zeros. Shape: [num_blocks, block_size, num_head_kv, num_dim_v] Dtype: fp8 - qkscale: QK fp8 quant scale. Per Token Per Head Fp8 Quant. + qscale: Q fp8 quant scale. Per Token Per Head Fp8 Quant. Shape: [num_batch, num_head_q, max_seqlens_q_pad] Dtype: float32 + kscale: K fp8 quant scale. Per Tensor Fp8 Quant. + Shape: [1] + Dtype: float32 vscale: V fp8 quant scale. Per Tensor Fp8 Quant. 
Shape: [1] Dtype: float32 @@ -193,7 +197,8 @@ def attention_with_kvcache_prefill_fp8( q, kcache, vcache, - qkscale, + qscale, + kscale, vscale, cu_seqlens_q, block_ids, @@ -365,7 +370,8 @@ def attention_with_kvcache_prefill_fp8_fake( q, kcache, vcache, - qkscale, + qscale, + kscale, vscale, cu_seqlens_q, block_ids, diff --git a/src/attention/entry.cc b/src/attention/entry.cc index d6cff7c..3a2399b 100644 --- a/src/attention/entry.cc +++ b/src/attention/entry.cc @@ -128,14 +128,16 @@ torch::Tensor attention_with_kvcache_prefill_bf16_entry( torch::Tensor attention_with_kvcache_prefill_fp8_entry( const torch::Tensor &q, const torch::Tensor &kcache, const torch::Tensor &vcache, - const torch::Tensor &qkscale, const torch::Tensor &vscale, const torch::Tensor &cu_seqlens_q, - const torch::Tensor block_ids, const torch::Tensor seqlens_kvcache, int64_t max_seqlens_q, + const torch::Tensor &qscale, const torch::Tensor &kscale, const torch::Tensor &vscale, + const torch::Tensor &cu_seqlens_q, const torch::Tensor block_ids, + const torch::Tensor seqlens_kvcache, int64_t max_seqlens_q, std::optional output) { auto stream = at::cuda::getCurrentCUDAStream(q.get_device()); TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda"); TORCH_CHECK(kcache.device().is_cuda(), "kcache tensor must be cuda"); TORCH_CHECK(vcache.device().is_cuda(), "vcache tensor must be cuda"); - TORCH_CHECK(qkscale.device().is_cuda(), "qkscale tensor must be cuda"); + TORCH_CHECK(qscale.device().is_cuda(), "qscale tensor must be cuda"); + TORCH_CHECK(kscale.device().is_cuda(), "kscale tensor must be cuda"); TORCH_CHECK(vscale.device().is_cuda(), "vscale tensor must be cuda"); TORCH_CHECK(cu_seqlens_q.device().is_cuda(), "cu_seqlens_q tensor must be cuda"); TORCH_CHECK(block_ids.device().is_cuda(), "block_ids tensor must be cuda"); @@ -155,7 +157,7 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry( int num_seq_max_blocks = block_ids.size(1); - int max_seqlens_q_pad = qkscale.size(2); + int 
max_seqlens_q_pad = qscale.size(2); auto options = q.options().dtype(torch::kBFloat16); torch::Tensor y; @@ -171,7 +173,8 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry( const auto *q_ptr = q.const_data_ptr(); const auto *kcache_ptr = kcache.const_data_ptr(); const auto *vcache_ptr = vcache.const_data_ptr(); - const auto *qkscale_ptr = qkscale.const_data_ptr(); + const auto *qscale_ptr = qscale.const_data_ptr(); + const auto *kscale_ptr = kscale.const_data_ptr(); const auto *vscale_ptr = vscale.const_data_ptr(); const auto *cu_seqlens_q_ptr = cu_seqlens_q.const_data_ptr(); const auto *block_ids_ptr = block_ids.const_data_ptr(); @@ -187,7 +190,7 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry( int ldY = y.stride(0); // num_head_q * num_dim_v; attention_with_kvcache_prefill_fp8_async( - y_ptr, q_ptr, kcache_ptr, vcache_ptr, qkscale_ptr, vscale_ptr, cu_seqlens_q_ptr, + y_ptr, q_ptr, kcache_ptr, vcache_ptr, qscale_ptr, kscale_ptr, vscale_ptr, cu_seqlens_q_ptr, block_ids_ptr, seqlens_kvcache_ptr, tmas_ptr, num_batch, total_seq_q, max_seqlens_q, max_seqlens_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, num_kvcache_blocks, block_size, num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream); @@ -419,7 +422,7 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) { m.def( "attention_with_kvcache_prefill_fp8(Tensor q, Tensor kcache, Tensor vcache," - "Tensor qkscale, Tensor vscale, Tensor cu_seqlens_q," + "Tensor qscale, Tensor kscale, Tensor vscale, Tensor cu_seqlens_q," "Tensor block_ids, Tensor num_seq_kvcache, int max_seqlens_q, Tensor? 
output) -> (Tensor)"); m.impl("attention_with_kvcache_prefill_fp8", torch::kCUDA, &hpc::attention::attention_with_kvcache_prefill_fp8_entry); diff --git a/src/attention/prefill/config.h b/src/attention/prefill/config.h index fb44a0c..d928e10 100644 --- a/src/attention/prefill/config.h +++ b/src/attention/prefill/config.h @@ -242,7 +242,7 @@ struct AttentionKVCachePrefillFp8Config { SLayoutVTAtom{}, make_shape(Int{}, Int{}, Int{}))); using SLayoutY = decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int{}, Int{}))); - using SLayoutQKS = decltype(make_layout(make_shape(Int{}), make_stride(Int<1>{}))); + using SLayoutQS = decltype(make_layout(make_shape(Int{}), make_stride(Int<1>{}))); using CopyBoxK = decltype(tile_to_shape(SLayoutKAtom{}, make_shape(Int{}, Int{}))); @@ -251,14 +251,14 @@ struct AttentionKVCachePrefillFp8Config { using CopyBoxY = decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int{}, Int{}))); - template - auto get_tma(TQ q, TK k, TV v, TY y, TQKS qks) { + template + auto get_tma(TQ q, TK k, TV v, TY y, TQS qs) { auto tma_q = make_tma_copy(SM90_TMA_LOAD{}, q, SLayoutQ{}); auto tma_k = make_tma_copy(SM90_TMA_LOAD{}, k, CopyBoxK{}); auto tma_v = make_tma_copy(SM90_TMA_LOAD{}, v, CopyBoxV{}); auto tma_y = make_tma_copy(SM90_TMA_STORE{}, y, CopyBoxY{}); - auto tma_qks = make_tma_copy(SM90_TMA_LOAD{}, qks, SLayoutQKS{}); - return std::make_tuple(tma_q, tma_k, tma_v, tma_y, tma_qks); + auto tma_qs = make_tma_copy(SM90_TMA_LOAD{}, qs, SLayoutQS{}); + return std::make_tuple(tma_q, tma_k, tma_v, tma_y, tma_qs); } using WarpgroupLayout = @@ -270,8 +270,8 @@ struct AttentionKVCachePrefillFp8Config { (cosize(SLayoutQ{}) + cosize(SLayoutK{}) + cosize(SLayoutV{}) + cosize(SLayoutVT{})) * sizeof(Tin); static constexpr int shm_y = cosize(SLayoutY{}) * sizeof(Tout); - static constexpr int shm_qks = cosize(SLayoutQKS{}) * sizeof(float); - static constexpr int shm_size = shm_qkv + shm_y + shm_qks; + static constexpr int shm_qs = cosize(SLayoutQS{}) * sizeof(float); + 
static constexpr int shm_size = shm_qkv + shm_y + shm_qs; auto get_shm_size() { return shm_size; } }; diff --git a/src/attention/prefill/kernels.cuh b/src/attention/prefill/kernels.cuh index 08b7111..b0cbf03 100644 --- a/src/attention/prefill/kernels.cuh +++ b/src/attention/prefill/kernels.cuh @@ -89,22 +89,21 @@ __device__ __forceinline__ void online_softmax(ATensor &&tAttr_mn, MTensor &&gMa } } -template +template __device__ __forceinline__ void online_softmax_with_scale(ATensor &&tAttr_mn, MTensor &&gMax, STensor &&gSum, YTensor &&tYr_mn, - QKSTensor &&tQKS, int kM, int kN, + QSTensor &&tQS, int kM, int kN, float one_over_dk_log2e) { #pragma unroll for (int im = 0; im < kM; ++im) { - float qks_im = tQKS[im]; - tAttr_mn(im, 0) = tAttr_mn(im, 0) * qks_im; + float qs_im = tQS[im]; + tAttr_mn(im, 0) = tAttr_mn(im, 0) * qs_im; float row_max = tAttr_mn(im, 0); float row_sum = 0.f; #pragma unroll for (int in = 1; in < kN; ++in) { - float local_max = tAttr_mn(im, in) * qks_im; + float local_max = tAttr_mn(im, in) * qs_im; tAttr_mn(im, in) = local_max; row_max = fmaxf(row_max, local_max); } @@ -992,17 +991,17 @@ __global__ void __launch_bounds__(384, 1) } template + typename TmaQS> __global__ void __launch_bounds__(384, 1) attention_with_kvcache_prefill_fp8_warp_specialization_kernel( cute::TmaDescriptor *td_qy, const __grid_constant__ TmaK tma_k, - const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaQKS tma_qks, - const float *qkscale_ptr, const float *vscale_ptr, const int *cu_seqlens_q_ptr, - const int *seqlens_kvcache_ptr, const int *block_ids_ptr, int num_batch, int max_seq_q, - int max_seq_q_pad, int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, - int num_kvcache_blocks, int block_size, int num_seq_max_blocks, float one_over_dk_log2e, - cutlass::FastDivmod head_kv_divmod, cutlass::FastDivmod head_q_divmod, - cutlass::FastDivmod tile_m_divmod) { + const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaQS tma_qs, + const float 
*qscale_ptr, const float *kscale_ptr, const float *vscale_ptr, + const int *cu_seqlens_q_ptr, const int *seqlens_kvcache_ptr, const int *block_ids_ptr, + int num_batch, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v, + int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, + int num_seq_max_blocks, float one_over_dk_log2e, cutlass::FastDivmod head_kv_divmod, + cutlass::FastDivmod head_q_divmod, cutlass::FastDivmod tile_m_divmod) { using namespace cute; // NOLINT using Tin = typename Config::Tin; @@ -1014,7 +1013,7 @@ __global__ void __launch_bounds__(384, 1) using SLayoutV = typename Config::SLayoutV; using SLayoutVT = typename Config::SLayoutVT; using SLayoutY = typename Config::SLayoutY; - using SLayoutQKS = typename Config::SLayoutQKS; + using SLayoutQS = typename Config::SLayoutQS; constexpr int kTileM = Config::kTileM; constexpr int kTileN = Config::kTileN; @@ -1044,8 +1043,8 @@ __global__ void __launch_bounds__(384, 1) auto *shm_v = shm_k + cosize(SLayoutK{}); auto *shm_vt = shm_v + cosize(SLayoutV{}); auto *shm_y = reinterpret_cast(shm_vt + cosize(SLayoutVT{})); - auto *shm_qks = reinterpret_cast(shm_y + cosize(SLayoutY{})); - auto *shm_seqlens_q = reinterpret_cast(shm_qks + cosize(SLayoutQKS{})); + auto *shm_qs = reinterpret_cast(shm_y + cosize(SLayoutY{})); + auto *shm_seqlens_q = reinterpret_cast(shm_qs + cosize(SLayoutQS{})); auto *shm_seqlens_kv = shm_seqlens_q + num_batch; auto *shm_seqlens_qstart = shm_seqlens_kv + num_batch; @@ -1059,7 +1058,7 @@ __global__ void __launch_bounds__(384, 1) auto gV = tma_v.get_tma_tensor(make_shape(num_dim_v, kBlockSize, num_head_kv, num_kvcache_blocks)); auto gY = tma_y.get_tma_tensor(make_shape(max_seq_q, num_dim_v, num_head_q)); - auto gQKS = tma_qks.get_tma_tensor(make_shape(max_seq_q_pad, num_head_q, num_batch)); + auto gQS = tma_qs.get_tma_tensor(make_shape(max_seq_q_pad, num_head_q, num_batch)); auto gAtt = make_tensor(make_gmem_ptr(static_cast(nullptr)), @@ -1077,25 +1076,25 @@ 
__global__ void __launch_bounds__(384, 1) auto sV = make_tensor(make_smem_ptr(shm_v), SLayoutV{}); auto sVT = make_tensor(make_smem_ptr(shm_vt), SLayoutVT{}); auto sY = make_tensor(make_smem_ptr(shm_y), SLayoutY{}); - auto sQKS = make_tensor(make_smem_ptr(shm_qks), SLayoutQKS{}); + auto sQS = make_tensor(make_smem_ptr(shm_qs), SLayoutQS{}); // Block Level tma auto btma_q = tma_q.get_slice(0); auto btma_k = tma_k.get_slice(0); auto btma_v = tma_v.get_slice(0); auto btma_y = tma_y.get_slice(0); - auto btma_qks = tma_qks.get_slice(0); + auto btma_qs = tma_qs.get_slice(0); // Thread Level Tensor - auto tQg = btma_q.partition_S(gQ); // (TMA, TMA_M, TMA_K, head, batch) - auto tKg = btma_k.partition_S(gK); // (TMA, TMA_N, TMA_K, head, batch) - auto tVg = btma_v.partition_S(gV); // (TMA, TMA_V, TMA_N, head, batch) - auto tQKSg = btma_qks.partition_S(gQKS); // (TMA, TMA_M, head, batch) + auto tQg = btma_q.partition_S(gQ); // (TMA, TMA_M, TMA_K, head, batch) + auto tKg = btma_k.partition_S(gK); // (TMA, TMA_N, TMA_K, head, batch) + auto tVg = btma_v.partition_S(gV); // (TMA, TMA_V, TMA_N, head, batch) + auto tQSg = btma_qs.partition_S(gQS); // (TMA, TMA_M, head, batch) - auto tQs = btma_q.partition_D(sQ); // (TMA, _1, _1) - auto tKs = btma_k.partition_D(sK); // (TMA, _1, _1, kStage) - auto tVs = btma_v.partition_D(sV); // (TMA, _1, _1, kStage) - auto tQKSs = btma_qks.partition_D(sQKS); // (TMA, _1) + auto tQs = btma_q.partition_D(sQ); // (TMA, _1, _1) + auto tKs = btma_k.partition_D(sK); // (TMA, _1, _1, kStage) + auto tVs = btma_v.partition_D(sV); // (TMA, _1, _1, kStage) + auto tQSs = btma_qs.partition_D(sQS); // (TMA, _1) TiledMmaQK tiled_mma_qk; TiledMmaPV tiled_mma_pv; @@ -1180,9 +1179,9 @@ __global__ void __launch_bounds__(384, 1) // Load Q wait_barrier(writable_q, phase_q); cute::copy(tma_q.with(td_q, readable_q), tQg(_, itile_m, _, ihead_q), tQs(_, 0, _)); - cute::copy(tma_qks.with(readable_q), tQKSg(_, itile_m, ihead_q, ibatch), tQKSs(_, 0)); + 
cute::copy(tma_qs.with(readable_q), tQSg(_, itile_m, ihead_q, ibatch), tQSs(_, 0)); set_barrier_transaction_bytes( - readable_q, sizeof(Tin) * cosize(SLayoutQ{}) + sizeof(float) * cosize(SLayoutQKS{})); + readable_q, sizeof(Tin) * cosize(SLayoutQ{}) + sizeof(float) * cosize(SLayoutQS{})); phase_q ^= 1; int num_tile_kv = (start_seq_q + (itile_m + 1) * kTileM + kTileN - 1) / kTileN; @@ -1270,7 +1269,8 @@ __global__ void __launch_bounds__(384, 1) int phase = 0; int phase_q = 0; - float tQKS[kM]; + float tQS[kM]; + float kscale = kscale_ptr[0]; float vscale = vscale_ptr[0]; while (true) { @@ -1303,7 +1303,7 @@ __global__ void __launch_bounds__(384, 1) auto tI_mn = retile_fragment(tI); #pragma unroll for (int im = 0; im < kM; im++) { - tQKS[im] = sQKS(get<0>(tI_mn(im, 0))); + tQS[im] = sQS(get<0>(tI_mn(im, 0))) * kscale; } #pragma unroll 1 @@ -1355,7 +1355,7 @@ __global__ void __launch_bounds__(384, 1) auto tYr_mn = retile_fragment(tYr); // online softmax - online_softmax_with_scale(tAttr_mn, gMax, gSum, tYr_mn, tQKS, kM, kN, one_over_dk_log2e); + online_softmax_with_scale(tAttr_mn, gMax, gSum, tYr_mn, tQS, kM, kN, one_over_dk_log2e); // convert P to fp8 and permute for pv gemm auto tAttr_float32x4 = recast(tAttr); diff --git a/src/attention/prefill/prefill.cc b/src/attention/prefill/prefill.cc index ef95347..1978325 100644 --- a/src/attention/prefill/prefill.cc +++ b/src/attention/prefill/prefill.cc @@ -70,13 +70,14 @@ void attention_with_kvcache_prefill_bf16_async( void attention_with_kvcache_prefill_fp8_async( void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr, - const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr, - const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch, - int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v, - int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks, - int ldY, int ldQ, int ldK, 
int ldV, cudaStream_t stream) { + const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr, + const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr, + void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad, + int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks, + int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV, + cudaStream_t stream) { prefill::warp_spec_with_kvcache_fp8_dim128_async( - y_ptr, q_ptr, kcache_ptr, vcache_ptr, qkscale_ptr, vscale_ptr, cu_seqlens_q_ptr, + y_ptr, q_ptr, kcache_ptr, vcache_ptr, qscale_ptr, kscale_ptr, vscale_ptr, cu_seqlens_q_ptr, block_ids_ptr, seqlens_kvcache_ptr, tmas_ptr, num_batch, total_seq_q, max_seq_q, max_seq_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, num_kvcache_blocks, block_size, num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream); diff --git a/src/attention/prefill/prefill.h b/src/attention/prefill/prefill.h index e5af002..f4396be 100644 --- a/src/attention/prefill/prefill.h +++ b/src/attention/prefill/prefill.h @@ -25,11 +25,12 @@ void attention_with_kvcache_prefill_bf16_async( void attention_with_kvcache_prefill_fp8_async( void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr, - const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr, - const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch, - int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v, - int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks, - int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream); + const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr, + const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr, + void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad, + int num_dim_qk, int 
num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks, + int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV, + cudaStream_t stream); } // namespace attention } // namespace hpc diff --git a/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.cu b/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.cu index 273bfde..07a847a 100644 --- a/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.cu +++ b/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.cu @@ -19,11 +19,12 @@ namespace prefill { template void launch_warp_spec_with_kvcache_fp8_dim128( void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr, - const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr, - const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch, - int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v, - int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks, - int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream) { + const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr, + const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr, + void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad, + int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks, + int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV, + cudaStream_t stream) { using namespace cute; // NOLINT using Tin = cute::float_e4m3_t; @@ -41,9 +42,9 @@ void launch_warp_spec_with_kvcache_fp8_dim128( auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), make_shape(max_seq_q, num_dim_v, num_head_q), make_stride(ldY, Int<1>{}, num_dim_v)); - auto QKS = make_tensor(make_gmem_ptr(reinterpret_cast(qkscale_ptr)), - make_shape(max_seq_q_pad, num_head_q, num_batch), - make_stride(Int<1>{}, max_seq_q_pad, num_head_q * 
max_seq_q_pad)); + auto QS = make_tensor(make_gmem_ptr(reinterpret_cast(qscale_ptr)), + make_shape(max_seq_q_pad, num_head_q, num_batch), + make_stride(Int<1>{}, max_seq_q_pad, num_head_q * max_seq_q_pad)); auto *tma_qy = static_cast(tmas_ptr); constexpr float kLog2e = 1.4426950408889634f; @@ -55,7 +56,7 @@ void launch_warp_spec_with_kvcache_fp8_dim128( 128, kBlockSize, 2, 2, 1, 128, 128, 128, 128>; Config config; - auto [tma_q, tma_k, tma_v, tma_y, tma_qks] = config.get_tma(Q, K, V, Y, QKS); + auto [tma_q, tma_k, tma_v, tma_y, tma_qks] = config.get_tma(Q, K, V, Y, QS); // 0. update tma { @@ -86,32 +87,33 @@ void launch_warp_spec_with_kvcache_fp8_dim128( decltype(tma_qks)>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); kernel<<>>( - tma_qy, tma_k, tma_v, tma_qks, (const float *)qkscale_ptr, (const float *)vscale_ptr, - (const int *)cu_seqlens_q_ptr, (const int *)seqlens_kvcache_ptr, (const int *)block_ids_ptr, - num_batch, max_seq_q, max_seq_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, - num_kvcache_blocks, block_size, num_seq_max_blocks, one_over_dk_log2e, head_kv_divmod, - head_q_divmod, tile_m_divmod); + tma_qy, tma_k, tma_v, tma_qks, (const float *)qscale_ptr, (const float *)kscale_ptr, + (const float *)vscale_ptr, (const int *)cu_seqlens_q_ptr, (const int *)seqlens_kvcache_ptr, + (const int *)block_ids_ptr, num_batch, max_seq_q, max_seq_q_pad, num_dim_qk, num_dim_v, + num_head_q, num_head_kv, num_kvcache_blocks, block_size, num_seq_max_blocks, + one_over_dk_log2e, head_kv_divmod, head_q_divmod, tile_m_divmod); } } void warp_spec_with_kvcache_fp8_dim128_async( void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr, - const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr, - const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch, - int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v, - int num_head_q, int 
num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks, - int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream) { + const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr, + const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr, + void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad, + int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks, + int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV, + cudaStream_t stream) { if (block_size == 32) { constexpr int kBlockSize = 32; launch_warp_spec_with_kvcache_fp8_dim128( - y_ptr, q_ptr, kcache_ptr, vcache_ptr, qkscale_ptr, vscale_ptr, cu_seqlens_q_ptr, + y_ptr, q_ptr, kcache_ptr, vcache_ptr, qscale_ptr, kscale_ptr, vscale_ptr, cu_seqlens_q_ptr, block_ids_ptr, seqlens_kvcache_ptr, tmas_ptr, num_batch, total_seq_q, max_seq_q, max_seq_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, num_kvcache_blocks, block_size, num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream); } else if (block_size == 64) { constexpr int kBlockSize = 64; launch_warp_spec_with_kvcache_fp8_dim128( - y_ptr, q_ptr, kcache_ptr, vcache_ptr, qkscale_ptr, vscale_ptr, cu_seqlens_q_ptr, + y_ptr, q_ptr, kcache_ptr, vcache_ptr, qscale_ptr, kscale_ptr, vscale_ptr, cu_seqlens_q_ptr, block_ids_ptr, seqlens_kvcache_ptr, tmas_ptr, num_batch, total_seq_q, max_seq_q, max_seq_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, num_kvcache_blocks, block_size, num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream); diff --git a/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.h b/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.h index 4859886..3d8572e 100644 --- a/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.h +++ b/src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.h @@ -9,14 +9,14 @@ namespace hpc { namespace attention { namespace prefill { - void warp_spec_with_kvcache_fp8_dim128_async( 
void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr, - const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr, - const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch, - int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v, - int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks, - int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream); + const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr, + const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr, + void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad, + int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks, + int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV, + cudaStream_t stream); } // namespace prefill } // namespace attention diff --git a/tests/test_attention_with_kvcache_prefill_fp8.py b/tests/test_attention_with_kvcache_prefill_fp8.py index 9454208..70813ef 100644 --- a/tests/test_attention_with_kvcache_prefill_fp8.py +++ b/tests/test_attention_with_kvcache_prefill_fp8.py @@ -23,7 +23,8 @@ def naive_attn_with_kvcache_func( q, k_cache, v_cache, - qkscale, + qscale, + kscale, vscale, cache_seqlens, page_table, @@ -33,7 +34,7 @@ def naive_attn_with_kvcache_func( num_batch, num_seq_q, num_head_q, num_dim_qk = q.shape num_blocks, block_size, num_head_kv, _ = k_cache.shape _, _, _, num_dim_v = v_cache.shape - _, _, max_seq_q_pad = qkscale.shape + _, _, max_seq_q_pad = qscale.shape num_group = num_head_q // num_head_kv output = torch.empty_like(q).to(torch.bfloat16) @@ -57,9 +58,9 @@ def naive_attn_with_kvcache_func( .repeat_interleave(num_group, dim=0) ).float() - scale = qkscale[i, :, :].unsqueeze(-1)[:, :num_seq_q, :] + scale = qscale[i, :, :].unsqueeze(-1)[:, :num_seq_q, :] - scores = torch.matmul(BQ, BK.transpose(-2, 
-1)) * scale / math.sqrt(num_dim_qk) + scores = torch.matmul(BQ, BK.transpose(-2, -1)) * scale * kscale[0] / math.sqrt(num_dim_qk) if causal: causal_mask = ( torch.tril(torch.ones(num_seq_kv, num_seq_kv, device=q.device, dtype=torch.bool))[ @@ -128,12 +129,13 @@ def test_attention_with_kvcache_prefill_fp8( v = torch.randn( (num_batch * num_seq_q, num_head_kv, num_dim_v), dtype=torch.bfloat16, device="cuda" ).to(torch.float8_e4m3fn) - qkscale = ( + qscale = ( torch.abs( torch.randn((num_batch, num_head_q, num_seq_q_pad), dtype=torch.float32, device="cuda") ) / 10 ) + kscale = torch.randn((1), dtype=torch.float32, device="cuda").abs() * 10 vscale = torch.randn((1), dtype=torch.float32, device="cuda") seqlens_q = torch.full((num_batch,), num_seq_q, dtype=torch.int32, device="cuda") @@ -169,7 +171,8 @@ def test_attention_with_kvcache_prefill_fp8( q=q, k_cache=kvcache[:, 0, :, :], v_cache=kvcache[:, 1, :, :], - qkscale=qkscale, + qscale=qscale, + kscale=kscale, vscale=vscale, cache_seqlens=seqlens_kvcache, page_table=block_ids, @@ -180,7 +183,8 @@ def test_attention_with_kvcache_prefill_fp8( q.reshape(-1, num_head_q, num_dim_qk), kvcache[:, 0, :, :, :], kvcache[:, 1, :, :, :], - qkscale, + qscale, + kscale, vscale, cu_seqlens_q, block_ids,