Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/attention/decode/smallm_dim128.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ void launch_attention_decode_bf16_dim128_smallm(
auto Q = make_tensor(make_gmem_ptr(reinterpret_cast<const Tin *>(q_ptr)),
make_shape(num_head_q, num_dim_qk, num_batch),
make_stride(num_dim_qk, Int<1>{}, ldQ));

auto K = make_tensor(make_gmem_ptr(reinterpret_cast<const Tin *>(kcache_ptr)),
make_shape(kBlockSize, num_dim_qk, num_head_k, num_kvcache_blocks),
make_stride(num_dim_qk * num_head_k, Int<1>{}, num_dim_qk, ldK));

auto V = make_tensor(make_gmem_ptr(reinterpret_cast<const Tin *>(vcache_ptr)),
make_shape(num_dim_v, kBlockSize, num_head_v, num_kvcache_blocks),
make_stride(Int<1>{}, num_head_v * num_dim_v, num_dim_v, ldV));

auto Y = make_tensor(make_gmem_ptr(reinterpret_cast<const Tout *>(y_ptr)),
auto Y = make_tensor(make_gmem_ptr(reinterpret_cast<Tout *>(y_ptr)),
make_shape(num_dim_v, num_head_q, num_batch),
make_stride(Int<1>{}, num_dim_v, ldY));

Expand Down Expand Up @@ -67,12 +67,12 @@ void launch_attention_decode_bf16_dim128_smallm(
make_shape(Int<kTileV>{}, Int<kBlockSize>{}));
auto tma_copy_layout_y = tile_to_shape(GMMA::Layout_MN_SW128_Atom<Tin>{},
make_shape(Int<kTileV>{}, Int<kHeadsPerGroup>{}));

auto tma_q = make_tma_copy(SM90_TMA_LOAD{}, Q, tma_copy_layout_q);
auto tma_k = make_tma_copy(SM90_TMA_LOAD{}, K, tma_copy_layout_k);
auto tma_v = make_tma_copy(SM90_TMA_LOAD{}, V, tma_copy_layout_v);
auto tma_y = make_tma_copy(SM90_TMA_STORE{}, Y, tma_copy_layout_y);

using TiledMmaQK =
decltype(make_tiled_mma(SM90_64x8x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{}));
using TiledMmaSV =
Expand All @@ -94,15 +94,16 @@ void launch_attention_decode_bf16_dim128_smallm(

auto kernel = kernels::attention_decode_bf16_multistage_ws_smallm_kernel<
Tout, Tin, kTileM, kTileN, kTileK, kTileV, TiledMmaQK, TiledMmaSV, decltype(tma_q),
decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(slayout_q), decltype(slayout_k),
decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(Q), decltype(Y),
decltype(slayout_q), decltype(slayout_k),
decltype(slayout_p), decltype(slayout_s), decltype(slayout_v), decltype(slayout_y),
kBlockSize, kStage>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);

kernel<<<grid, block, shm_size, stream>>>(
tma_q, tma_k, tma_v, tma_y, block_ids_ptr, num_seq_kvcache_ptr, new_kv_included, num_batch,
num_dim_qk, num_dim_v, num_head_q, num_head_k, num_head_v, heads_per_group,
num_kvcache_blocks, num_seq_max_blocks, one_over_dk_log2e);
tma_q, tma_k, tma_v, tma_y, Q, Y, block_ids_ptr, num_seq_kvcache_ptr,
new_kv_included, num_batch, num_dim_qk, num_dim_v, num_head_q, num_head_k, num_head_v,
heads_per_group, num_kvcache_blocks, num_seq_max_blocks, one_over_dk_log2e);
}

bool smallm_dim128_async(void *y_ptr, const void *q_ptr, void *kcache_ptr, void *vcache_ptr,
Expand All @@ -112,7 +113,7 @@ bool smallm_dim128_async(void *y_ptr, const void *q_ptr, void *kcache_ptr, void
int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV,
cudaStream_t stream) {
using namespace cute; // NOLINT

constexpr int kTileM = 64;
constexpr int kTileN = 8;
constexpr int kTileK = 128;
Expand All @@ -126,7 +127,14 @@ bool smallm_dim128_async(void *y_ptr, const void *q_ptr, void *kcache_ptr, void
}

int heads_per_group = num_head_q / num_head_k;
if (heads_per_group == 8 || heads_per_group == 4) {

if (heads_per_group == 1 || heads_per_group > 8) {
std::cout << "launch launch_attention_decode_bf16_dim128_smallm failed with "
<< " heads_per_group: " << heads_per_group << std::endl;
return false;
}

if (heads_per_group <= 8) {
constexpr int kHeadsPerGroup = 8;
if (block_size == 32) {
constexpr int kBlockSize = 32;
Expand Down
127 changes: 108 additions & 19 deletions src/attention/decode/smallm_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -193,21 +193,87 @@ __device__ __forceinline__ void load_paged_kv(TmaK &tma_k, TmaV &tma_v, uint64_t
sizeof(Tin) * load_blocks * kBlockSize * num_dim_v);
}

// Swizzle safety (must stay in sync with SLayoutQ / SLayoutY):
// - Q load: num_dim_qk multiple of 8; uint4 never crosses head boundaries.
// - Y store: Swizzle period 16 BF16; d aligned to 8 => uint4 safe; float->BF16 in registers.

// Cooperatively copies the Q rows of one KV-head group — heads
// [ihead_q0, ihead_q0 + heads_per_group) of batch `ibatch` — from global
// memory into the shared-memory tile sQ, bypassing TMA.
// Heads of a group are contiguous in gmem for a fixed batch (Q is built with
// stride (num_dim_qk, 1, ldQ) in the launcher), so the whole group is
// addressed as one flat range of total_elems elements from q_base.
// Fast path: 128-bit (uint4 = 8 BF16) vectorized loads; a scalar tail loop
// covers the sub-vector remainder. Per the swizzle-safety note above,
// num_dim_qk is a multiple of 8, so a vector never straddles a head boundary.
// rank_in_threads / num_threads describe this thread's position in the
// cooperating thread group.
template <typename Tin, typename TensorQG, typename TensorSQ>
__device__ __forceinline__ void load_q_group_direct_to_smem(
    TensorQG const &Q, TensorSQ &sQ, int ihead_q0, int ibatch,
    int heads_per_group, int num_dim_qk, int rank_in_threads, int num_threads) {
  using namespace cute;  // NOLINT

  const int total_elems = heads_per_group * num_dim_qk;
  constexpr int kVecSize = 8;  // uint4 = 128 bits = 8 BF16 elements
  const int kVecStride = num_threads * kVecSize;

  const Tin *q_base = Q(ihead_q0, _, ibatch).data().get();
  for (int base = rank_in_threads * kVecSize; base + kVecSize <= total_elems;
       base += kVecStride) {
    int lh = base / num_dim_qk;
    int k = base % num_dim_qk;
    store(&sQ(lh, k), load<Tin, kVecSize>(q_base + base));
  }
  // The vectorized loop above (union over all threads) covers every full
  // kVecSize chunk, i.e. the first (total_elems / kVecSize) * kVecSize
  // elements; only the sub-vector remainder needs the scalar tail.
  // (Rounding down to a multiple of kVecStride here, as before, made the
  // tail redundantly re-store elements the vector loop had already written.)
  const int vec_covered = (total_elems / kVecSize) * kVecSize;
  for (int elem = vec_covered + rank_in_threads; elem < total_elems; elem += num_threads) {
    int lh = elem / num_dim_qk;
    int k = elem % num_dim_qk;
    sQ(lh, k) = Q(ihead_q0 + lh, k, ibatch);
  }
}

// Fallback epilogue (non-TMA path): converts the output tile held in shared
// memory — sY, indexed (d, local_head) with float-convertible elements — to
// Tout and writes it to the global output tensor Y at
// (d, ihead_q0 + local_head, ibatch).
// The kMathThreads cooperating threads (this thread is `idx`) split the work:
// full 8-element chunks along d are emitted as single 128-bit stores after a
// register-side float -> Tout conversion; a scalar pass handles the last
// (num_dim_v % 8) elements of each head when num_dim_v is not a multiple of 8.
template <typename Tout, typename TensorSY, typename TensorGY>
__device__ __forceinline__ void store_sY_to_gmem_bf16(
    TensorSY &sY, TensorGY &Y, int ihead_q0, int ibatch, int heads_per_group, int num_dim_v,
    int idx, int kMathThreads) {
  using namespace cute;  // NOLINT

  constexpr int kVec = 8;  // 8 x 16-bit lanes = one uint4 (128-bit) store
  const int vec_per_head = num_dim_v / kVec;
  const int tail_per_head = num_dim_v - vec_per_head * kVec;

  // Vectorized pass: one work item = one (head, 8-wide d-chunk) pair.
  const int num_vec_items = heads_per_group * vec_per_head;
  for (int item = idx; item < num_vec_items; item += kMathThreads) {
    const int local_head = item / vec_per_head;
    const int d0 = (item % vec_per_head) * kVec;

    // Pack the converted lanes into a uint4 in registers, then emit a
    // single 128-bit store to gmem.
    uint4 packed;
    Tout *lanes = reinterpret_cast<Tout *>(&packed);
#pragma unroll
    for (int i = 0; i < kVec; ++i) {
      lanes[i] = static_cast<Tout>(static_cast<float>(sY(d0 + i, local_head)));
    }

    store(&Y(d0, ihead_q0 + local_head, ibatch), load<Tout, kVec>(lanes));
  }

  // Scalar tail pass for the remainder columns of each head.
  if (tail_per_head > 0) {
    const int tail_d0 = vec_per_head * kVec;
    const int num_tail_items = heads_per_group * tail_per_head;
    for (int item = idx; item < num_tail_items; item += kMathThreads) {
      const int local_head = item / tail_per_head;
      const int d = tail_d0 + item % tail_per_head;
      Y(d, ihead_q0 + local_head, ibatch) =
          static_cast<Tout>(static_cast<float>(sY(d, local_head)));
    }
  }
}

template <typename Tout, typename Tin, int kTileM, int kTileN, int kTileK, int kTileV,
typename TiledMmaQK, typename TiledMmaSV, typename TmaQ, typename TmaK, typename TmaV,
typename TmaY, typename SLayoutQ, typename SLayoutK, typename SLayoutP, typename SLayoutS,
typename SLayoutV, typename SLayoutY, int kBlockSize, int kStage>
typename TmaY, typename TensorQ, typename TensorY, typename SLayoutQ, typename SLayoutK,
typename SLayoutP, typename SLayoutS, typename SLayoutV, typename SLayoutY,
int kBlockSize, int kStage>
__global__ void attention_decode_bf16_multistage_ws_smallm_kernel(
const __grid_constant__ TmaQ tma_q, const __grid_constant__ TmaK tma_k,
const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaY tma_y,
const int *block_ids_ptr, const int *num_seq_kvcache_ptr, bool new_kv_included, int num_batch,
int num_dim_qk, int num_dim_v, int num_head_q, int num_head_k, int num_head_v,
int heads_per_group, int num_kvcache_blocks, int num_seq_max_blocks, float one_over_dk_log2e) {
TensorQ Q, TensorY Y, const int *block_ids_ptr, const int *num_seq_kvcache_ptr,
bool new_kv_included, int num_batch, int num_dim_qk, int num_dim_v,
int num_head_q, int num_head_k, int num_head_v, int heads_per_group,
int num_kvcache_blocks, int num_seq_max_blocks, float one_over_dk_log2e) {
using namespace cute; // NOLINT

int idx = threadIdx.x;
int ihead_kv = blockIdx.x;
int ibatch = blockIdx.y;
const int ihead_q0 = ihead_kv * heads_per_group;

constexpr int kMathThreads = size(TiledMmaQK{});
constexpr int kWarpsPerWrapGroup = 4;
Expand Down Expand Up @@ -310,11 +376,13 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel(
// cutlass::arch::warpgroup_reg_dealloc<24>();
bool is_leader_in_load = ((iwarp == kMathThreads / 32) && elected);

if (is_leader_in_load) {
// Load Q
cute::copy(tma_q.with(q_readable), tQg(_, ihead_kv, _, ibatch), tQs(_, 0, _));
set_barrier_transaction_bytes(
q_readable, sizeof(Tin) * max(heads_per_group, size<0, 0, 1>(tQg)) * num_dim_qk);
if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) {
if (is_leader_in_load) {
// Load Q
cute::copy(tma_q.with(q_readable), tQg(_, ihead_kv, _, ibatch), tQs(_, 0, _));
set_barrier_transaction_bytes(
q_readable, sizeof(Tin) * max(heads_per_group, size<0, 0, 1>(tQg)) * num_dim_qk);
}
}
}

Expand Down Expand Up @@ -413,7 +481,14 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel(

tiled_mma_sv.accumulate_ = GMMA::ScaleOut::One;

wait_barrier(q_readable, 0);
if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) {
wait_barrier(q_readable, 0);
} else {
// if not using TMA, math warpgroup loads Q using kMathThreads threads.
load_q_group_direct_to_smem<Tin>(Q, sQ, ihead_q0, ibatch, heads_per_group, num_dim_qk,
idx, kMathThreads);
syncwarpgroup(iwarpgroup);
}

int phase = 0;
int istage_read = 0;
Expand Down Expand Up @@ -449,7 +524,7 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel(
int iposq = num_seq_kvcache + get<1>(tI_mn(im, in)) / heads_per_group;
int iposk = itile_seq_kv * kTileM + get<0>(tI_mn(im, in));

if ((iposk > iposq) || (iposk >= num_seq_kv)) {
if ((iposk > iposq) || (iposk >= num_seq_kv) || (in >= heads_per_group)) {
tAttr_mn(im, in) = -std::numeric_limits<float>::infinity();
}
}
Expand Down Expand Up @@ -515,9 +590,20 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel(
arrive_barrier(k_writable[istage_read]);
}

auto tAttr_mn_full = retile_fragment(tAttr);
#pragma unroll
for (int im = 0; im < kM; ++im) {
#pragma unroll
for (int in = 0; in < kN; ++in) {
if (in >= heads_per_group) {
tAttr_mn_full(im, in) = -std::numeric_limits<float>::infinity();
}
}
}

auto tYr_mn = retile_fragment(tYr);
// online softmax
online_softmax<false, kTileN, kM, kN>(tAttr_mn, gMax, gSum, tYr_mn, one_over_dk_log2e,
online_softmax<false, kTileN, kM, kN>(tAttr_mn_full, gMax, gSum, tYr_mn, one_over_dk_log2e,
shm_max, iwarpgroup, iwarp_in_warpgroup,
ilane_in_warpgroup);

Expand Down Expand Up @@ -578,12 +664,15 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_kernel(
cute::copy(r2s_tiled_copy, tYr4s, tYs4r);
syncwarpgroup(iwarpgroup);
tma_store_fence();
// using TMA to store
if (is_leader_in_warpgroup) {
auto tYss = btma_y.partition_S(sY); // (TMA, TMA_M, TMA_N)
auto tYgg = btma_y.partition_D(gY); // (TMA, TMA_M, TMA_N, b)

cute::copy(tma_y, tYss(_, _, 0), tYgg(_, _, ihead_kv, ibatch));
if ((heads_per_group == kTileN) || (num_head_q == 4 && num_head_k == 1)) {
if (is_leader_in_warpgroup) {
auto tYss = btma_y.partition_S(sY); // (TMA, TMA_M, TMA_N)
auto tYgg = btma_y.partition_D(gY); // (TMA, TMA_M, TMA_N, b)
cute::copy(tma_y, tYss(_, _, 0), tYgg(_, _, ihead_kv, ibatch));
}
} else {
store_sY_to_gmem_bf16<Tout>(sY, Y, ihead_q0, ibatch, heads_per_group, num_dim_v,
idx, kMathThreads);
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion src/attention/decode/smallm_splitk_combine_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ __global__ void attention_decode_bf16_smallm_splitk_combine_kernel(
T *y_ptr, const float *split_input_ptr, const float *lse_ptr, const int *num_seq_kvcache_ptr,
bool new_kv_included, int num_head_q) {
int ibatch = blockIdx.x;
int ihead = threadIdx.x / 32;
// Each warp handles one Q-head. If num_head_q * 32 > 1024, tile along
// blockIdx.y so that each block contains at most 1024 threads.
int heads_per_block = blockDim.x / 32;
int ihead = blockIdx.y * heads_per_block + threadIdx.x / 32;
int ilane = threadIdx.x % 32;

if (ihead >= num_head_q) {
return;
}

constexpr int kItemsPerThread = 4;
constexpr int kSeqlenQ = 1;

Expand Down
34 changes: 24 additions & 10 deletions src/attention/decode/smallm_splitk_dim128.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void launch_attention_decode_bf16_dim128_smallm_splitk(
make_shape(num_dim_v, kBlockSize, num_head_v, num_kvcache_blocks),
make_stride(Int<1>{}, num_head_v * num_dim_v, num_dim_v, ldV));

auto Y = make_tensor(make_gmem_ptr(reinterpret_cast<const Tout *>(y_ptr)),
auto Y = make_tensor(make_gmem_ptr(reinterpret_cast<Tout *>(y_ptr)),
make_shape(num_dim_v, num_head_q, num_batch),
make_stride(Int<1>{}, num_dim_v, ldY));

Expand Down Expand Up @@ -107,14 +107,16 @@ void launch_attention_decode_bf16_dim128_smallm_splitk(

auto kernel = kernels::attention_decode_bf16_multistage_ws_smallm_splitk_kernel<
Tout, Tin, kTileM, kTileN, kTileK, kTileV, TiledMmaQK, TiledMmaSV, decltype(tma_q),
decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(tma_splity), decltype(slayout_q),
decltype(tma_k), decltype(tma_v), decltype(tma_y), decltype(tma_splity),
decltype(Q), decltype(Y), decltype(splitY), decltype(slayout_q),
decltype(slayout_k), decltype(slayout_p), decltype(slayout_s), decltype(slayout_v),
decltype(slayout_y), decltype(slayout_splity), kBlockSize, kStage, kSplitK, kSplitMinLen>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);

kernel<<<grid, block, shm_size, stream>>>(
tma_q, tma_k, tma_v, tma_y, tma_splity, reinterpret_cast<float *>(lse_ptr), block_ids_ptr,
num_seq_kvcache_ptr, new_kv_included, num_batch, num_dim_qk, num_dim_v, num_head_q,
tma_q, tma_k, tma_v, tma_y, tma_splity, Q, Y, splitY,
reinterpret_cast<float *>(lse_ptr), block_ids_ptr, num_seq_kvcache_ptr,
new_kv_included, num_batch, num_dim_qk, num_dim_v, num_head_q,
num_head_k, num_head_v, heads_per_group, num_kvcache_blocks, num_seq_max_blocks,
one_over_dk_log2e);
}
Expand Down Expand Up @@ -144,10 +146,16 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr

int heads_per_group = num_head_q / num_head_k;

if (heads_per_group == 1 || heads_per_group > 8) {
std::cout << "launch launch_attention_decode_bf16_dim128_smallm failed with "
<< " heads_per_group: " << heads_per_group << std::endl;
return false;
}

if (splitk == 4) {
constexpr int kSplitK = 4;
constexpr int kSplitMinLen = 4096;
if (heads_per_group == 8 || heads_per_group == 4) {
if (heads_per_group <= 8) {
constexpr int kHeadsPerGroup = 8;
if (block_size == 32) {
constexpr int kBlockSize = 32;
Expand All @@ -167,8 +175,11 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr
num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream);
}
using Tout = __nv_bfloat16;
dim3 grid(num_batch);
dim3 block(32 * num_head_q);

// Cap heads_per_block at 32 (= 1024 / 32) and tile extra heads in blockIdx.y.
int heads_per_block = std::min(num_head_q, 32);
dim3 grid(num_batch, (num_head_q + heads_per_block - 1) / heads_per_block);
dim3 block(32 * heads_per_block);
kernels::attention_decode_bf16_smallm_splitk_combine_kernel<Tout, kTileM, kTileV, kSplitK,
kSplitMinLen, kConsumers>
<<<grid, block, 0, stream>>>(reinterpret_cast<Tout *>(y_ptr),
Expand All @@ -179,7 +190,7 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr
} else if (splitk == 16) {
constexpr int kSplitK = 16;
constexpr int kSplitMinLen = 512;
if (heads_per_group == 8 || heads_per_group == 4) {
if (heads_per_group <= 8) {
constexpr int kHeadsPerGroup = 8;
if (block_size == 32) {
constexpr int kBlockSize = 32;
Expand All @@ -199,8 +210,11 @@ bool smallm_splitk_dim128_async(void *y_ptr, void *lse_ptr, void *splitk_out_ptr
num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream);
}
using Tout = __nv_bfloat16;
dim3 grid(num_batch);
dim3 block(32 * num_head_q);

// Cap heads_per_block at 32 (= 1024 / 32) and tile extra heads in blockIdx.y.
int heads_per_block = std::min(num_head_q, 32);
dim3 grid(num_batch, (num_head_q + heads_per_block - 1) / heads_per_block);
dim3 block(32 * heads_per_block);
kernels::attention_decode_bf16_smallm_splitk_combine_kernel<Tout, kTileM, kTileV, kSplitK,
kSplitMinLen, kConsumers>
<<<grid, block, 0, stream>>>(reinterpret_cast<Tout *>(y_ptr),
Expand Down
Loading