diff --git a/src/attention/decode/smallm_dim128.cu b/src/attention/decode/smallm_dim128.cu index 7b6fdbd..ab4a2e1 100644 --- a/src/attention/decode/smallm_dim128.cu +++ b/src/attention/decode/smallm_dim128.cu @@ -30,7 +30,7 @@ void launch_attention_decode_bf16_dim128_smallm( auto Q = make_tensor(make_gmem_ptr(reinterpret_cast(q_ptr)), make_shape(num_head_q, num_dim_qk, num_batch), make_stride(num_dim_qk, Int<1>{}, ldQ)); - + auto K = make_tensor(make_gmem_ptr(reinterpret_cast(kcache_ptr)), make_shape(kBlockSize, num_dim_qk, num_head_k, num_kvcache_blocks), make_stride(num_dim_qk * num_head_k, Int<1>{}, num_dim_qk, ldK)); @@ -38,8 +38,8 @@ void launch_attention_decode_bf16_dim128_smallm( auto V = make_tensor(make_gmem_ptr(reinterpret_cast(vcache_ptr)), make_shape(num_dim_v, kBlockSize, num_head_v, num_kvcache_blocks), make_stride(Int<1>{}, num_head_v * num_dim_v, num_dim_v, ldV)); - - auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), + + auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), make_shape(num_dim_v, num_head_q, num_batch), make_stride(Int<1>{}, num_dim_v, ldY)); @@ -67,12 +67,12 @@ void launch_attention_decode_bf16_dim128_smallm( make_shape(Int{}, Int{})); auto tma_copy_layout_y = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(Int{}, Int{})); - + auto tma_q = make_tma_copy(SM90_TMA_LOAD{}, Q, tma_copy_layout_q); auto tma_k = make_tma_copy(SM90_TMA_LOAD{}, K, tma_copy_layout_k); auto tma_v = make_tma_copy(SM90_TMA_LOAD{}, V, tma_copy_layout_v); auto tma_y = make_tma_copy(SM90_TMA_STORE{}, Y, tma_copy_layout_y); - + using TiledMmaQK = decltype(make_tiled_mma(SM90_64x8x16_F32BF16BF16_SS{})); using TiledMmaSV = @@ -94,15 +94,16 @@ void launch_attention_decode_bf16_dim128_smallm( auto kernel = kernels::attention_decode_bf16_multistage_ws_smallm_kernel< Tout, Tin, kTileM, kTileN, kTileK, kTileV, TiledMmaQK, TiledMmaSV, decltype(tma_q), - decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(slayout_q), decltype(slayout_k), + decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(Q), decltype(Y), + decltype(slayout_q), decltype(slayout_k), decltype(slayout_p), decltype(slayout_s), decltype(slayout_v), decltype(slayout_y), kBlockSize, kStage>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); kernel<<>>( - tma_q, tma_k, tma_v, tma_y, block_ids_ptr, num_seq_kvcache_ptr, new_kv_included, num_batch, - num_dim_qk, num_dim_v, num_head_q, num_head_k, num_head_v, heads_per_group, - num_kvcache_blocks, num_seq_max_blocks, one_over_dk_log2e); + tma_q, tma_k, tma_v, tma_y, Q, Y, block_ids_ptr, num_seq_kvcache_ptr, + new_kv_included, num_batch, num_dim_qk, num_dim_v, num_head_q, num_head_k, num_head_v, + heads_per_group, num_kvcache_blocks, num_seq_max_blocks, one_over_dk_log2e); } bool smallm_dim128_async(void *y_ptr, const void *q_ptr, void *kcache_ptr, void *vcache_ptr, @@ -112,7 +113,7 @@ bool smallm_dim128_async(void *y_ptr, const void *q_ptr, void *kcache_ptr, void int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream) { using namespace cute; // NOLINT - + constexpr int kTileM = 64; constexpr int kTileN = 8; constexpr int kTileK = 128; @@ -126,7 +127,14 @@ bool smallm_dim128_async(void *y_ptr, const void *q_ptr, void *kcache_ptr, void } int heads_per_group = num_head_q / num_head_k; - if (heads_per_group == 8 || heads_per_group == 4) { + + if (heads_per_group == 1 || heads_per_group > 8) { + std::cout << "launch launch_attention_decode_bf16_dim128_smallm failed with " + << " heads_per_group: " << heads_per_group << std::endl; + return false; + } + + if (heads_per_group <= 8) { constexpr int kHeadsPerGroup = 8; if (block_size == 32) { constexpr int kBlockSize = 32; diff --git a/src/attention/decode/smallm_kernels.cuh b/src/attention/decode/smallm_kernels.cuh index 28b80b0..3f523bc 100644 --- a/src/attention/decode/smallm_kernels.cuh +++ b/src/attention/decode/smallm_kernels.cuh @@ -193,21 +193,87 @@ __device__ __forceinline__ void load_paged_kv(TmaK &tma_k, TmaV &tma_v, uint64_t sizeof(Tin) * load_blocks * kBlockSize * num_dim_v); } +// Swizzle safety (must stay in sync with SLayoutQ / SLayoutY): +// - Q load: num_dim_qk multiple of 8; uint4 never crosses head boundaries. +// - Y store: Swizzle period 16 BF16; d aligned to 8 => uint4 safe; float->BF16 in registers. + +template +__device__ __forceinline__ void load_q_group_direct_to_smem( + TensorQG const &Q, TensorSQ &sQ, int ihead_q0, int ibatch, + int heads_per_group, int num_dim_qk, int rank_in_threads, int num_threads) { + using namespace cute; // NOLINT + + const int total_elems = heads_per_group * num_dim_qk; + constexpr int kVecSize = 8; // uint4 = 128 bits = 8 BF16 elements + const int kVecStride = num_threads * kVecSize; + + const Tin *q_base = Q(ihead_q0, _, ibatch).data().get(); + for (int base = rank_in_threads * kVecSize; base + kVecSize <= total_elems; + base += kVecStride) { + int lh = base / num_dim_qk; + int k = base % num_dim_qk; + store(&sQ(lh, k), load(q_base + base)); + } + const int vec_covered = (total_elems / kVecStride) * kVecStride; + for (int elem = vec_covered + rank_in_threads; elem < total_elems; elem += num_threads) { + int lh = elem / num_dim_qk; + int k = elem % num_dim_qk; + sQ(lh, k) = Q(ihead_q0 + lh, k, ibatch); + } +} + +template +__device__ __forceinline__ void store_sY_to_gmem_bf16( + TensorSY &sY, TensorGY &Y, int ihead_q0, int ibatch, int heads_per_group, int num_dim_v, + int idx, int kMathThreads) { + using namespace cute; // NOLINT + + const int vec_size = 8; + const int num_dim_v_vec = num_dim_v / vec_size; + const int num_dim_v_rem = num_dim_v % vec_size; + const int total_vec = heads_per_group * num_dim_v_vec; + for (int lin = idx; lin < total_vec; lin += kMathThreads) { + int lh = lin / num_dim_v_vec; + int d_idx = lin % num_dim_v_vec; + int d = d_idx * vec_size; + uint4 val; + Tout *v16 = reinterpret_cast(&val); +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + v16[i] = static_cast(static_cast(sY(d + i, lh))); + } + + store(&Y(d, ihead_q0 + lh, ibatch), load(v16)); + } + if (num_dim_v_rem > 0) { + const int d_base = num_dim_v_vec * vec_size; + const int total_rem = heads_per_group * num_dim_v_rem; + for (int lin = idx; lin < total_rem; lin += kMathThreads) { + int lh = lin / num_dim_v_rem; + int d = d_base + lin % num_dim_v_rem; + Y(d, ihead_q0 + lh, ibatch) = static_cast(static_cast(sY(d, lh))); + } + } +} + template + typename TmaY, typename TensorQ, typename TensorY, typename SLayoutQ, typename SLayoutK, + typename SLayoutP, typename SLayoutS, typename SLayoutV, typename SLayoutY, + int kBlockSize, int kStage> __global__ void attention_decode_bf16_multistage_ws_smallm_kernel( const __grid_constant__ TmaQ tma_q, const __grid_constant__ TmaK tma_k, const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaY tma_y, - const int *block_ids_ptr, const int *num_seq_kvcache_ptr, bool new_kv_included, int num_batch, - int num_dim_qk, int num_dim_v, int num_head_q, int num_head_k, int num_head_v, - int heads_per_group, int num_kvcache_blocks, int num_seq_max_blocks, float one_over_dk_log2e) { + TensorQ Q, TensorY Y, const int *block_ids_ptr, const int *num_seq_kvcache_ptr, + bool new_kv_included, int num_batch, int num_dim_qk, int num_dim_v, + int num_head_q, int num_head_k, int num_head_v, int heads_per_group, + int num_kvcache_blocks, int num_seq_max_blocks, float one_over_dk_log2e) { using namespace cute; // NOLINT int idx = threadIdx.x; int ihead_kv = blockIdx.x; int ibatch = blockIdx.y; + const int ihead_q0 = ihead_kv * heads_per_group; constexpr int kMathThreads = size(TiledMmaQK{}); constexpr int kWarpsPerWrapGroup = 4; @@ -310,11 +376,13 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel( // cutlass::arch::warpgroup_reg_dealloc<24>(); bool is_leader_in_load = ((iwarp == kMathThreads / 32) && elected); - if (is_leader_in_load) { - // Load Q - cute::copy(tma_q.with(q_readable), tQg(_, ihead_kv, _, ibatch), tQs(_, 0, _)); - set_barrier_transaction_bytes( - q_readable, sizeof(Tin) * max(heads_per_group, size<0, 0, 1>(tQg)) * num_dim_qk); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + if (is_leader_in_load) { + // Load Q + cute::copy(tma_q.with(q_readable), tQg(_, ihead_kv, _, ibatch), tQs(_, 0, _)); + set_barrier_transaction_bytes( + q_readable, sizeof(Tin) * max(heads_per_group, size<0, 0, 1>(tQg)) * num_dim_qk); + } } } @@ -413,7 +481,14 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel( tiled_mma_sv.accumulate_ = GMMA::ScaleOut::One; - wait_barrier(q_readable, 0); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + wait_barrier(q_readable, 0); + } else { + // if not using TMA, math warpgroup loads Q using kMathThreads threads. + load_q_group_direct_to_smem(Q, sQ, ihead_q0, ibatch, heads_per_group, num_dim_qk, + idx, kMathThreads); + syncwarpgroup(iwarpgroup); + } int phase = 0; int istage_read = 0; @@ -449,7 +524,7 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel( int iposq = num_seq_kvcache + get<1>(tI_mn(im, in)) / heads_per_group; int iposk = itile_seq_kv * kTileM + get<0>(tI_mn(im, in)); - if ((iposk > iposq) || (iposk >= num_seq_kv)) { + if ((iposk > iposq) || (iposk >= num_seq_kv) || (in >= heads_per_group)) { tAttr_mn(im, in) = -std::numeric_limits::infinity(); } } @@ -515,9 +590,20 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel( arrive_barrier(k_writable[istage_read]); } + auto tAttr_mn_full = retile_fragment(tAttr); +#pragma unroll + for (int im = 0; im < kM; ++im) { +#pragma unroll + for (int in = 0; in < kN; ++in) { + if (in >= heads_per_group) { + tAttr_mn_full(im, in) = -std::numeric_limits::infinity(); + } + } + } + auto tYr_mn = retile_fragment(tYr); // online softmax - online_softmax(tAttr_mn, gMax, gSum, tYr_mn, one_over_dk_log2e, + online_softmax(tAttr_mn_full, gMax, gSum, tYr_mn, one_over_dk_log2e, shm_max, iwarpgroup, iwarp_in_warpgroup, ilane_in_warpgroup); @@ -578,12 +664,15 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel( cute::copy(r2s_tiled_copy, tYr4s, tYs4r); syncwarpgroup(iwarpgroup); tma_store_fence(); - // using TMA to store - if (is_leader_in_warpgroup) { - auto tYss = btma_y.partition_S(sY); // (TMA, TMA_M, TMA_N) - auto tYgg = btma_y.partition_D(gY); // (TMA, TMA_M, TMA_N, b) - - cute::copy(tma_y, tYss(_, _, 0), tYgg(_, _, ihead_kv, ibatch)); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + if (is_leader_in_warpgroup) { + auto tYss = btma_y.partition_S(sY); // (TMA, TMA_M, TMA_N) + auto tYgg = btma_y.partition_D(gY); // (TMA, TMA_M, TMA_N, b) + cute::copy(tma_y, tYss(_, _, 0), tYgg(_, _, ihead_kv, ibatch)); + } + } else { + store_sY_to_gmem_bf16(sY, Y, ihead_q0, ibatch, heads_per_group, num_dim_v, + idx, kMathThreads); } } } diff --git a/src/attention/decode/smallm_splitk_combine_kernels.cuh b/src/attention/decode/smallm_splitk_combine_kernels.cuh index 65a0041..44eaa74 100644 --- a/src/attention/decode/smallm_splitk_combine_kernels.cuh +++ b/src/attention/decode/smallm_splitk_combine_kernels.cuh @@ -23,9 +23,16 @@ __global__ void attention_decode_bf16_smallm_splitk_combine_kernel( T *y_ptr, const float *split_input_ptr, const float *lse_ptr, const int *num_seq_kvcache_ptr, bool new_kv_included, int num_head_q) { int ibatch = blockIdx.x; - int ihead = threadIdx.x / 32; + // Each warp handles one Q-head. If num_head_q * 32 > 1024, tile along + // blockIdx.y so that each block contains at most 1024 threads. + int heads_per_block = blockDim.x / 32; + int ihead = blockIdx.y * heads_per_block + threadIdx.x / 32; int ilane = threadIdx.x % 32; + if (ihead >= num_head_q) { + return; + } + constexpr int kItemsPerThread = 4; constexpr int kSeqlenQ = 1; diff --git a/src/attention/decode/smallm_splitk_dim128.cu b/src/attention/decode/smallm_splitk_dim128.cu index 35c9d26..0250873 100644 --- a/src/attention/decode/smallm_splitk_dim128.cu +++ b/src/attention/decode/smallm_splitk_dim128.cu @@ -41,7 +41,7 @@ void launch_attention_decode_bf16_dim128_smallm_splitk( make_shape(num_dim_v, kBlockSize, num_head_v, num_kvcache_blocks), make_stride(Int<1>{}, num_head_v * num_dim_v, num_dim_v, ldV)); - auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), + auto Y = make_tensor(make_gmem_ptr(reinterpret_cast(y_ptr)), make_shape(num_dim_v, num_head_q, num_batch), make_stride(Int<1>{}, num_dim_v, ldY)); @@ -107,14 +107,16 @@ void launch_attention_decode_bf16_dim128_smallm_splitk( auto kernel = kernels::attention_decode_bf16_multistage_ws_smallm_splitk_kernel< Tout, Tin, kTileM, kTileN, kTileK, kTileV, TiledMmaQK, TiledMmaSV, decltype(tma_q), - decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(tma_splity), decltype(slayout_q), + decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(tma_splity), + decltype(Q), decltype(Y), decltype(splitY), decltype(slayout_q), decltype(slayout_k), decltype(slayout_p), decltype(slayout_s), decltype(slayout_v), decltype(slayout_y), decltype(slayout_splity), kBlockSize, kStage, kSplitK, kSplitMinLen>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); kernel<<>>( - tma_q, tma_k, tma_v, tma_y, tma_splity, reinterpret_cast(lse_ptr), block_ids_ptr, - num_seq_kvcache_ptr, new_kv_included, num_batch, num_dim_qk, num_dim_v, num_head_q, + tma_q, tma_k, tma_v, tma_y, tma_splity, Q, Y, splitY, + reinterpret_cast(lse_ptr), block_ids_ptr, num_seq_kvcache_ptr, + new_kv_included, num_batch, num_dim_qk, num_dim_v, num_head_q, num_head_k, num_head_v, heads_per_group, num_kvcache_blocks, num_seq_max_blocks, one_over_dk_log2e); } @@ -144,10 +146,16 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr int heads_per_group = num_head_q / num_head_k; + if (heads_per_group == 1 || heads_per_group > 8) { + std::cout << "launch launch_attention_decode_bf16_dim128_smallm failed with " + << " heads_per_group: " << heads_per_group << std::endl; + return false; + } + if (splitk == 4) { constexpr int kSplitK = 4; constexpr int kSplitMinLen = 4096; - if (heads_per_group == 8 || heads_per_group == 4) { + if (heads_per_group <= 8) { constexpr int kHeadsPerGroup = 8; if (block_size == 32) { constexpr int kBlockSize = 32; @@ -167,8 +175,11 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream); } using Tout = __nv_bfloat16; - dim3 grid(num_batch); - dim3 block(32 * num_head_q); + + // Cap heads_per_block at 32 (= 1024 / 32) and tile extra heads in blockIdx.y. + int heads_per_block = std::min(num_head_q, 32); + dim3 grid(num_batch, (num_head_q + heads_per_block - 1) / heads_per_block); + dim3 block(32 * heads_per_block); kernels::attention_decode_bf16_smallm_splitk_combine_kernel <<>>(reinterpret_cast(y_ptr), @@ -179,7 +190,7 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr } else if (splitk == 16) { constexpr int kSplitK = 16; constexpr int kSplitMinLen = 512; - if (heads_per_group == 8 || heads_per_group == 4) { + if (heads_per_group <= 8) { constexpr int kHeadsPerGroup = 8; if (block_size == 32) { constexpr int kBlockSize = 32; @@ -199,8 +210,11 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream); } using Tout = __nv_bfloat16; - dim3 grid(num_batch); - dim3 block(32 * num_head_q); + + // Cap heads_per_block at 32 (= 1024 / 32) and tile extra heads in blockIdx.y. + int heads_per_block = std::min(num_head_q, 32); + dim3 grid(num_batch, (num_head_q + heads_per_block - 1) / heads_per_block); + dim3 block(32 * heads_per_block); kernels::attention_decode_bf16_smallm_splitk_combine_kernel <<>>(reinterpret_cast(y_ptr), diff --git a/src/attention/decode/smallm_splitk_kernels.cuh b/src/attention/decode/smallm_splitk_kernels.cuh index 5f7cfa6..00d51f7 100644 --- a/src/attention/decode/smallm_splitk_kernels.cuh +++ b/src/attention/decode/smallm_splitk_kernels.cuh @@ -194,17 +194,115 @@ __device__ __forceinline__ void load_paged_kv(TmaK &tma_k, TmaV &tma_v, uint64_t sizeof(Tin) * load_blocks * kBlockSize * num_dim_v); } +// Swizzle safety (must stay in sync with SLayoutQ / SLayoutY / SLayoutSplitY): +// - Q load: num_dim_qk multiple of 8; uint4 never crosses head boundaries. +// - Y store: Swizzle period 16 BF16; d aligned to 8 => uint4 safe; float->BF16 in registers. +// - splitY store: Swizzle period 8 floats; d aligned to 4 => float4 safe. + +template +__device__ __forceinline__ void load_q_group_direct_to_smem( + TensorQG const &Q, TensorSQ &sQ, int ihead_q0, int ibatch, + int heads_per_group, int num_dim_qk, int rank_in_threads, int num_threads) { + using namespace cute; // NOLINT + + const int total_elems = heads_per_group * num_dim_qk; + constexpr int kVecSize = 8; // uint4 = 128 bits = 8 BF16 elements + const int kVecStride = num_threads * kVecSize; + + const Tin *q_base = Q(ihead_q0, _, ibatch).data().get(); + for (int base = rank_in_threads * kVecSize; base + kVecSize <= total_elems; + base += kVecStride) { + int lh = base / num_dim_qk; + int k = base % num_dim_qk; + store(&sQ(lh, k), load(q_base + base)); + } + const int vec_covered = (total_elems / kVecStride) * kVecStride; + for (int elem = vec_covered + rank_in_threads; elem < total_elems; elem += num_threads) { + int lh = elem / num_dim_qk; + int k = elem % num_dim_qk; + sQ(lh, k) = Q(ihead_q0 + lh, k, ibatch); + } +} + +template +__device__ __forceinline__ void store_sY_to_gmem_bf16( + TensorSY &sY, TensorGY &Y, int ihead_q0, int ibatch, int heads_per_group, int num_dim_v, + int idx, int kMathThreads) { + using namespace cute; // NOLINT + + const int vec_size = 8; + const int num_dim_v_vec = num_dim_v / vec_size; + const int num_dim_v_rem = num_dim_v % vec_size; + const int total_vec = heads_per_group * num_dim_v_vec; + for (int lin = idx; lin < total_vec; lin += kMathThreads) { + int lh = lin / num_dim_v_vec; + int d_idx = lin % num_dim_v_vec; + int d = d_idx * vec_size; + uint4 val; + Tout *v16 = reinterpret_cast(&val); +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + v16[i] = static_cast(static_cast(sY(d + i, lh))); + } + + store(&Y(d, ihead_q0 + lh, ibatch), load(v16)); + } + if (num_dim_v_rem > 0) { + const int d_base = num_dim_v_vec * vec_size; + const int total_rem = heads_per_group * num_dim_v_rem; + for (int lin = idx; lin < total_rem; lin += kMathThreads) { + int lh = lin / num_dim_v_rem; + int d = d_base + lin % num_dim_v_rem; + Y(d, ihead_q0 + lh, ibatch) = static_cast(static_cast(sY(d, lh))); + } + } +} + +template +__device__ __forceinline__ void store_sSplitY_to_gmem_float( + TensorSSplitY &sSplitY, TensorGSplitY &splitY, int ihead_q0, int ichunk, int ibatch, + int heads_per_group, int num_dim_v, int idx, int kMathThreads) { + using namespace cute; // NOLINT + + const int vec_size = 4; + const int num_dim_v_vec = num_dim_v / vec_size; + const int num_dim_v_rem = num_dim_v % vec_size; + const int total_vec = heads_per_group * num_dim_v_vec; + for (int lin = idx; lin < total_vec; lin += kMathThreads) { + int lh = lin / num_dim_v_vec; + int d_idx = lin % num_dim_v_vec; + int d = d_idx * vec_size; + float4 val; + val.x = sSplitY(d, lh); + val.y = sSplitY(d + 1, lh); + val.z = sSplitY(d + 2, lh); + val.w = sSplitY(d + 3, lh); + + store(&splitY(d, ihead_q0 + lh, ichunk, ibatch), load(&val)); + } + if (num_dim_v_rem > 0) { + const int d_base = num_dim_v_vec * vec_size; + const int total_rem = heads_per_group * num_dim_v_rem; + for (int lin = idx; lin < total_rem; lin += kMathThreads) { + int lh = lin / num_dim_v_rem; + int d = d_base + lin % num_dim_v_rem; + splitY(d, ihead_q0 + lh, ichunk, ibatch) = sSplitY(d, lh); + } + } +} + template __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( const __grid_constant__ TmaQ tma_q, const __grid_constant__ TmaK tma_k, const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaY tma_y, - const __grid_constant__ TmaSplitY tma_splity, float *lse_ptr, const int *block_ids_ptr, - const int *num_seq_kvcache_ptr, bool new_kv_included, int num_batch, int num_dim_qk, - int num_dim_v, int num_head_q, int num_head_k, int num_head_v, int heads_per_group, + const __grid_constant__ TmaSplitY tma_splity, TensorQ Q, TensorY Y, TensorSplitY splitY, + float *lse_ptr, const int *block_ids_ptr, const int *num_seq_kvcache_ptr, + bool new_kv_included, int num_batch, int num_dim_qk, int num_dim_v, + int num_head_q, int num_head_k, int num_head_v, int heads_per_group, int num_kvcache_blocks, int num_seq_max_blocks, float one_over_dk_log2e) { using namespace cute; // NOLINT @@ -212,6 +310,7 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( int ihead_kv = blockIdx.x; int ibatch = blockIdx.y; int ichunk = blockIdx.z; + const int ihead_q0 = ihead_kv * heads_per_group; constexpr int kMathThreads = size(TiledMmaQK{}); constexpr int kWarpsPerWrapGroup = 4; @@ -344,11 +443,13 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( // cutlass::arch::warpgroup_reg_dealloc<24>(); bool is_leader_in_load = ((iwarp == kMathThreads / 32) && elected); - if (is_leader_in_load) { - // Load Q - cute::copy(tma_q.with(q_readable), tQg(_, ihead_kv, _, ibatch), tQs(_, 0, _)); - set_barrier_transaction_bytes( - q_readable, sizeof(Tin) * max(heads_per_group, size<0, 0, 1>(tQg)) * num_dim_qk); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + if (is_leader_in_load) { + // Load Q + cute::copy(tma_q.with(q_readable), tQg(_, ihead_kv, _, ibatch), tQs(_, 0, _)); + set_barrier_transaction_bytes( + q_readable, sizeof(Tin) * max(heads_per_group, size<0, 0, 1>(tQg)) * num_dim_qk); + } } } @@ -429,9 +530,9 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( auto gI = make_identity_tensor(gAtt.shape()); auto tI = thr_mma_qk.partition_C(gI); - auto tAttr_mn = retile_fragment(tAttr); - constexpr int kM = size<0>(tAttr_mn); - constexpr int kN = size<1>(tAttr_mn); + auto tAttr_mn_shape = retile_fragment(tAttr); + constexpr int kM = size<0>(tAttr_mn_shape); + constexpr int kN = size<1>(tAttr_mn_shape); Tensor gMax = make_tensor(Int{}); Tensor gSum = make_tensor(Int{}); @@ -447,7 +548,14 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( tiled_mma_sv.accumulate_ = GMMA::ScaleOut::One; - wait_barrier(q_readable, 0); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + wait_barrier(q_readable, 0); + } else { + // if not using TMA, math warpgroup loads Q using kMathThreads threads. + load_q_group_direct_to_smem(Q, sQ, ihead_q0, ibatch, heads_per_group, num_dim_qk, + idx, kMathThreads); + syncwarpgroup(iwarpgroup); + } int phase = 0; int istage_read = 0; @@ -483,7 +591,7 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( int iposq = num_seq_kvcache + get<1>(tI_mn(im, in)) / heads_per_group; int iposk = itile_seq_kv * kTileM + get<0>(tI_mn(im, in)); - if ((iposk > iposq) || (iposk >= num_seq_kv)) { + if ((in >= heads_per_group) || (iposk > iposq) || (iposk >= num_seq_kv)) { tAttr_mn(im, in) = -std::numeric_limits::infinity(); } } @@ -549,9 +657,20 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( arrive_barrier(k_writable[istage_read]); } + auto tAttr_mn_full = retile_fragment(tAttr); +#pragma unroll + for (int im = 0; im < kM; ++im) { +#pragma unroll + for (int in = 0; in < kN; ++in) { + if (in >= heads_per_group) { + tAttr_mn_full(im, in) = -std::numeric_limits::infinity(); + } + } + } + auto tYr_mn = retile_fragment(tYr); // online softmax - online_softmax(tAttr_mn, gMax, gSum, tYr_mn, one_over_dk_log2e, + online_softmax(tAttr_mn_full, gMax, gSum, tYr_mn, one_over_dk_log2e, shm_max, iwarpgroup, iwarp_in_warpgroup, ilane_in_warpgroup); @@ -613,12 +732,17 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( cute::copy(r2s_tiled_copy, tYr4s, tYs4r); syncwarpgroup(iwarpgroup); tma_store_fence(); - // using TMA to store - if (is_leader_in_warpgroup) { - auto tYss = btma_y.partition_S(sY); // (TMA, TMA_M, TMA_N) - auto tYgg = btma_y.partition_D(gY); // (TMA, TMA_M, TMA_N, b) - cute::copy(tma_y, tYss(_, _, 0), tYgg(_, _, ihead_kv, ibatch)); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + // using TMA to store + if (is_leader_in_warpgroup) { + auto tYss = btma_y.partition_S(sY); // (TMA, TMA_M, TMA_N) + auto tYgg = btma_y.partition_D(gY); // (TMA, TMA_M, TMA_N, b) + cute::copy(tma_y, tYss(_, _, 0), tYgg(_, _, ihead_kv, ibatch)); + } + } else { + store_sY_to_gmem_bf16(sY, Y, ihead_q0, ibatch, heads_per_group, num_dim_v, + idx, kMathThreads); } } else { // Epilogue: write register-C to global memory @@ -633,26 +757,42 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( syncwarpgroup(iwarpgroup); tma_store_fence(); - // using TMA to store - if (is_leader_in_warpgroup) { - auto tYss = btma_splity.partition_S(sSplitY); // (TMA, TMA_M, TMA_N) - auto tYgg = btma_splity.partition_D(gSplitY); // (TMA, TMA_M, TMA_N, b) - - cute::copy(tma_splity, tYss(_, _, 0), tYgg(_, _, ihead_kv, ichunk, ibatch)); + if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) { + // using TMA to store + if (is_leader_in_warpgroup) { + auto tYss = btma_splity.partition_S(sSplitY); // (TMA, TMA_M, TMA_N) + auto tYgg = btma_splity.partition_D(gSplitY); // (TMA, TMA_M, TMA_N, b) + cute::copy(tma_splity, tYss(_, _, 0), tYgg(_, _, ihead_kv, ichunk, ibatch)); + } + } else { + store_sSplitY_to_gmem_float(sSplitY, splitY, ihead_q0, ichunk, ibatch, + heads_per_group, num_dim_v, idx, kMathThreads); } int ilane = idx % 32; // write lse - if (iwarp == 0 && ilane < heads_per_group / kN) { - vec_t lse; + // Vector store for full kN-element groups, scalar for the remainder. + if (iwarp == 0 && ilane * kN < heads_per_group) { + bool base_aligned = (reinterpret_cast(lse_batch) % (sizeof(float) * kN)) == 0; + if (base_aligned && ilane * kN + kN <= heads_per_group) { + vec_t lse; #pragma unroll - for (int in = 0; in < kN; ++in) { - lse[in] = gMax(in) + log2f_ftz(gSum(in)); + for (int in = 0; in < kN; ++in) { + lse[in] = gMax(in) + log2f_ftz(gSum(in)); + } + store(lse_batch + ilane * kN, lse); + } else { +#pragma unroll + for (int in = 0; in < kN; ++in) { + if (ilane * kN + in < heads_per_group) { + lse_batch[ilane * kN + in] = gMax(in) + log2f_ftz(gSum(in)); + } + } } - store(lse_batch + ilane * kN, lse); } } } + } } // namespace kernels diff --git a/tests/test_attention_decode_bf16.py b/tests/test_attention_decode_bf16.py index ecee194..2d8ed67 100644 --- a/tests/test_attention_decode_bf16.py +++ b/tests/test_attention_decode_bf16.py @@ -70,7 +70,7 @@ def ref_attn_with_paged_kvcache_func( @pytest.mark.parametrize("num_seq_q", [1]) @pytest.mark.parametrize("max_seq_kv", [1024, 4096]) @pytest.mark.parametrize("block_size", [64]) -@pytest.mark.parametrize("kv_head_q_head", [(1, 4), (1, 8), (2, 16), (4, 32)]) +@pytest.mark.parametrize("kv_head_q_head", [(1, 2), (1, 4), (1, 5), (1, 7), (1, 8), (2, 14), (2, 16), (4, 32)]) @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("new_kv_included", [True, False]) @pytest.mark.parametrize("splitk", [True, False])