Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions hpc/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def attention_with_kvcache_prefill_fp8(
q: Tensor,
kcache: Tensor,
vcache: Tensor,
qkscale: Tensor,
qscale: Tensor,
kscale: Tensor,
vscale: Tensor,
cu_seqlens_q: Tensor,
block_ids: Tensor,
Expand All @@ -156,9 +157,12 @@ def attention_with_kvcache_prefill_fp8(
Constraint: the unused slots in the last block of vcache for each request must be set to zeros.
Shape: [num_blocks, block_size, num_head_kv, num_dim_v]
Dtype: fp8
qkscale: QK fp8 quant scale. Per Token Per Head Fp8 Quant.
qscale: Q fp8 quant scale. Per Token Per Head Fp8 Quant.
Shape: [num_batch, num_head_q, max_seqlens_q_pad]
Dtype: float32
kscale: K fp8 quant scale. Per Tensor Fp8 Quant.
Shape: [1]
Dtype: float32
vscale: V fp8 quant scale. Per Tensor Fp8 Quant.
Shape: [1]
Dtype: float32
Expand Down Expand Up @@ -193,7 +197,8 @@ def attention_with_kvcache_prefill_fp8(
q,
kcache,
vcache,
qkscale,
qscale,
kscale,
vscale,
cu_seqlens_q,
block_ids,
Expand Down Expand Up @@ -365,7 +370,8 @@ def attention_with_kvcache_prefill_fp8_fake(
q,
kcache,
vcache,
qkscale,
qscale,
kscale,
vscale,
cu_seqlens_q,
block_ids,
Expand Down
17 changes: 10 additions & 7 deletions src/attention/entry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ torch::Tensor attention_with_kvcache_prefill_bf16_entry(

torch::Tensor attention_with_kvcache_prefill_fp8_entry(
const torch::Tensor &q, const torch::Tensor &kcache, const torch::Tensor &vcache,
const torch::Tensor &qkscale, const torch::Tensor &vscale, const torch::Tensor &cu_seqlens_q,
const torch::Tensor block_ids, const torch::Tensor seqlens_kvcache, int64_t max_seqlens_q,
const torch::Tensor &qscale, const torch::Tensor &kscale, const torch::Tensor &vscale,
const torch::Tensor &cu_seqlens_q, const torch::Tensor block_ids,
const torch::Tensor seqlens_kvcache, int64_t max_seqlens_q,
std::optional<torch::Tensor> output) {
auto stream = at::cuda::getCurrentCUDAStream(q.get_device());
TORCH_CHECK(q.device().is_cuda(), "q tensor must be cuda");
TORCH_CHECK(kcache.device().is_cuda(), "kcache tensor must be cuda");
TORCH_CHECK(vcache.device().is_cuda(), "vcache tensor must be cuda");
TORCH_CHECK(qkscale.device().is_cuda(), "qkscale tensor must be cuda");
TORCH_CHECK(qscale.device().is_cuda(), "qscale tensor must be cuda");
TORCH_CHECK(kscale.device().is_cuda(), "kscale tensor must be cuda");
TORCH_CHECK(vscale.device().is_cuda(), "vscale tensor must be cuda");
TORCH_CHECK(cu_seqlens_q.device().is_cuda(), "cu_seqlens_q tensor must be cuda");
TORCH_CHECK(block_ids.device().is_cuda(), "block_ids tensor must be cuda");
Expand All @@ -155,7 +157,7 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry(

int num_seq_max_blocks = block_ids.size(1);

int max_seqlens_q_pad = qkscale.size(2);
int max_seqlens_q_pad = qscale.size(2);

auto options = q.options().dtype(torch::kBFloat16);
torch::Tensor y;
Expand All @@ -171,7 +173,8 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry(
const auto *q_ptr = q.const_data_ptr();
const auto *kcache_ptr = kcache.const_data_ptr();
const auto *vcache_ptr = vcache.const_data_ptr();
const auto *qkscale_ptr = qkscale.const_data_ptr();
const auto *qscale_ptr = qscale.const_data_ptr();
const auto *kscale_ptr = kscale.const_data_ptr();
const auto *vscale_ptr = vscale.const_data_ptr();
const auto *cu_seqlens_q_ptr = cu_seqlens_q.const_data_ptr();
const auto *block_ids_ptr = block_ids.const_data_ptr();
Expand All @@ -187,7 +190,7 @@ torch::Tensor attention_with_kvcache_prefill_fp8_entry(
int ldY = y.stride(0); // num_head_q * num_dim_v;

attention_with_kvcache_prefill_fp8_async(
y_ptr, q_ptr, kcache_ptr, vcache_ptr, qkscale_ptr, vscale_ptr, cu_seqlens_q_ptr,
y_ptr, q_ptr, kcache_ptr, vcache_ptr, qscale_ptr, kscale_ptr, vscale_ptr, cu_seqlens_q_ptr,
block_ids_ptr, seqlens_kvcache_ptr, tmas_ptr, num_batch, total_seq_q, max_seqlens_q,
max_seqlens_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, num_kvcache_blocks,
block_size, num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream);
Expand Down Expand Up @@ -419,7 +422,7 @@ TORCH_LIBRARY_FRAGMENT(hpc, m) {

m.def(
"attention_with_kvcache_prefill_fp8(Tensor q, Tensor kcache, Tensor vcache,"
"Tensor qkscale, Tensor vscale, Tensor cu_seqlens_q,"
"Tensor qscale, Tensor kscale, Tensor vscale, Tensor cu_seqlens_q,"
"Tensor block_ids, Tensor num_seq_kvcache, int max_seqlens_q, Tensor? output) -> (Tensor)");
m.impl("attention_with_kvcache_prefill_fp8", torch::kCUDA,
&hpc::attention::attention_with_kvcache_prefill_fp8_entry);
Expand Down
14 changes: 7 additions & 7 deletions src/attention/prefill/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ struct AttentionKVCachePrefillFp8Config {
SLayoutVTAtom{}, make_shape(Int<kTileV>{}, Int<kTileN>{}, Int<kStage * kWarpgroupM>{})));
using SLayoutY =
decltype(tile_to_shape(SLayoutYAtom{}, make_shape(Int<kTileM>{}, Int<kTileV>{})));
using SLayoutQKS = decltype(make_layout(make_shape(Int<kTileM>{}), make_stride(Int<1>{})));
using SLayoutQS = decltype(make_layout(make_shape(Int<kTileM>{}), make_stride(Int<1>{})));

using CopyBoxK =
decltype(tile_to_shape(SLayoutKAtom{}, make_shape(Int<kBlockSize>{}, Int<kTileK>{})));
Expand All @@ -251,14 +251,14 @@ struct AttentionKVCachePrefillFp8Config {
using CopyBoxY = decltype(tile_to_shape(SLayoutYAtom{},
make_shape(Int<kTileM / kWarpgroupM>{}, Int<kTileV>{})));

template <typename TQ, typename TK, typename TV, typename TY, typename TQKS>
auto get_tma(TQ q, TK k, TV v, TY y, TQKS qks) {
template <typename TQ, typename TK, typename TV, typename TY, typename TQS>
auto get_tma(TQ q, TK k, TV v, TY y, TQS qs) {
auto tma_q = make_tma_copy(SM90_TMA_LOAD{}, q, SLayoutQ{});
auto tma_k = make_tma_copy(SM90_TMA_LOAD{}, k, CopyBoxK{});
auto tma_v = make_tma_copy(SM90_TMA_LOAD{}, v, CopyBoxV{});
auto tma_y = make_tma_copy(SM90_TMA_STORE{}, y, CopyBoxY{});
auto tma_qks = make_tma_copy(SM90_TMA_LOAD{}, qks, SLayoutQKS{});
return std::make_tuple(tma_q, tma_k, tma_v, tma_y, tma_qks);
auto tma_qs = make_tma_copy(SM90_TMA_LOAD{}, qs, SLayoutQS{});
return std::make_tuple(tma_q, tma_k, tma_v, tma_y, tma_qs);
}

using WarpgroupLayout =
Expand All @@ -270,8 +270,8 @@ struct AttentionKVCachePrefillFp8Config {
(cosize(SLayoutQ{}) + cosize(SLayoutK{}) + cosize(SLayoutV{}) + cosize(SLayoutVT{})) *
sizeof(Tin);
static constexpr int shm_y = cosize(SLayoutY{}) * sizeof(Tout);
static constexpr int shm_qks = cosize(SLayoutQKS{}) * sizeof(float);
static constexpr int shm_size = shm_qkv + shm_y + shm_qks;
static constexpr int shm_qs = cosize(SLayoutQS{}) * sizeof(float);
static constexpr int shm_size = shm_qkv + shm_y + shm_qs;

auto get_shm_size() { return shm_size; }
};
Expand Down
66 changes: 33 additions & 33 deletions src/attention/prefill/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,21 @@ __device__ __forceinline__ void online_softmax(ATensor &&tAttr_mn, MTensor &&gMa
}
}

template <typename ATensor, typename MTensor, typename STensor, typename YTensor,
typename QKSTensor>
template <typename ATensor, typename MTensor, typename STensor, typename YTensor, typename QSTensor>
__device__ __forceinline__ void online_softmax_with_scale(ATensor &&tAttr_mn, MTensor &&gMax,
STensor &&gSum, YTensor &&tYr_mn,
QKSTensor &&tQKS, int kM, int kN,
QSTensor &&tQS, int kM, int kN,
float one_over_dk_log2e) {
#pragma unroll
for (int im = 0; im < kM; ++im) {
float qks_im = tQKS[im];
tAttr_mn(im, 0) = tAttr_mn(im, 0) * qks_im;
float qs_im = tQS[im];
tAttr_mn(im, 0) = tAttr_mn(im, 0) * qs_im;
float row_max = tAttr_mn(im, 0);
float row_sum = 0.f;

#pragma unroll
for (int in = 1; in < kN; ++in) {
float local_max = tAttr_mn(im, in) * qks_im;
float local_max = tAttr_mn(im, in) * qs_im;
tAttr_mn(im, in) = local_max;
row_max = fmaxf(row_max, local_max);
}
Expand Down Expand Up @@ -992,17 +991,17 @@ __global__ void __launch_bounds__(384, 1)
}

template <typename Config, typename TmaQ, typename TmaK, typename TmaV, typename TmaY,
typename TmaQKS>
typename TmaQS>
__global__ void __launch_bounds__(384, 1)
attention_with_kvcache_prefill_fp8_warp_specialization_kernel(
cute::TmaDescriptor *td_qy, const __grid_constant__ TmaK tma_k,
const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaQKS tma_qks,
const float *qkscale_ptr, const float *vscale_ptr, const int *cu_seqlens_q_ptr,
const int *seqlens_kvcache_ptr, const int *block_ids_ptr, int num_batch, int max_seq_q,
int max_seq_q_pad, int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv,
int num_kvcache_blocks, int block_size, int num_seq_max_blocks, float one_over_dk_log2e,
cutlass::FastDivmod head_kv_divmod, cutlass::FastDivmod head_q_divmod,
cutlass::FastDivmod tile_m_divmod) {
const __grid_constant__ TmaV tma_v, const __grid_constant__ TmaQS tma_qs,
const float *qscale_ptr, const float *kscale_ptr, const float *vscale_ptr,
const int *cu_seqlens_q_ptr, const int *seqlens_kvcache_ptr, const int *block_ids_ptr,
int num_batch, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v,
int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size,
int num_seq_max_blocks, float one_over_dk_log2e, cutlass::FastDivmod head_kv_divmod,
cutlass::FastDivmod head_q_divmod, cutlass::FastDivmod tile_m_divmod) {
using namespace cute; // NOLINT

using Tin = typename Config::Tin;
Expand All @@ -1014,7 +1013,7 @@ __global__ void __launch_bounds__(384, 1)
using SLayoutV = typename Config::SLayoutV;
using SLayoutVT = typename Config::SLayoutVT;
using SLayoutY = typename Config::SLayoutY;
using SLayoutQKS = typename Config::SLayoutQKS;
using SLayoutQS = typename Config::SLayoutQS;

constexpr int kTileM = Config::kTileM;
constexpr int kTileN = Config::kTileN;
Expand Down Expand Up @@ -1044,8 +1043,8 @@ __global__ void __launch_bounds__(384, 1)
auto *shm_v = shm_k + cosize(SLayoutK{});
auto *shm_vt = shm_v + cosize(SLayoutV{});
auto *shm_y = reinterpret_cast<Tout *>(shm_vt + cosize(SLayoutVT{}));
auto *shm_qks = reinterpret_cast<float *>(shm_y + cosize(SLayoutY{}));
auto *shm_seqlens_q = reinterpret_cast<int *>(shm_qks + cosize(SLayoutQKS{}));
auto *shm_qs = reinterpret_cast<float *>(shm_y + cosize(SLayoutY{}));
auto *shm_seqlens_q = reinterpret_cast<int *>(shm_qs + cosize(SLayoutQS{}));
auto *shm_seqlens_kv = shm_seqlens_q + num_batch;
auto *shm_seqlens_qstart = shm_seqlens_kv + num_batch;

Expand All @@ -1059,7 +1058,7 @@ __global__ void __launch_bounds__(384, 1)
auto gV =
tma_v.get_tma_tensor(make_shape(num_dim_v, kBlockSize, num_head_kv, num_kvcache_blocks));
auto gY = tma_y.get_tma_tensor(make_shape(max_seq_q, num_dim_v, num_head_q));
auto gQKS = tma_qks.get_tma_tensor(make_shape(max_seq_q_pad, num_head_q, num_batch));
auto gQS = tma_qs.get_tma_tensor(make_shape(max_seq_q_pad, num_head_q, num_batch));

auto gAtt =
make_tensor(make_gmem_ptr(static_cast<float *>(nullptr)),
Expand All @@ -1077,25 +1076,25 @@ __global__ void __launch_bounds__(384, 1)
auto sV = make_tensor(make_smem_ptr(shm_v), SLayoutV{});
auto sVT = make_tensor(make_smem_ptr(shm_vt), SLayoutVT{});
auto sY = make_tensor(make_smem_ptr(shm_y), SLayoutY{});
auto sQKS = make_tensor(make_smem_ptr(shm_qks), SLayoutQKS{});
auto sQS = make_tensor(make_smem_ptr(shm_qs), SLayoutQS{});

// Block Level tma
auto btma_q = tma_q.get_slice(0);
auto btma_k = tma_k.get_slice(0);
auto btma_v = tma_v.get_slice(0);
auto btma_y = tma_y.get_slice(0);
auto btma_qks = tma_qks.get_slice(0);
auto btma_qs = tma_qs.get_slice(0);

// Thread Level Tensor
auto tQg = btma_q.partition_S(gQ); // (TMA, TMA_M, TMA_K, head, batch)
auto tKg = btma_k.partition_S(gK); // (TMA, TMA_N, TMA_K, head, batch)
auto tVg = btma_v.partition_S(gV); // (TMA, TMA_V, TMA_N, head, batch)
auto tQKSg = btma_qks.partition_S(gQKS); // (TMA, TMA_M, head, batch)
auto tQg = btma_q.partition_S(gQ); // (TMA, TMA_M, TMA_K, head, batch)
auto tKg = btma_k.partition_S(gK); // (TMA, TMA_N, TMA_K, head, batch)
auto tVg = btma_v.partition_S(gV); // (TMA, TMA_V, TMA_N, head, batch)
auto tQSg = btma_qs.partition_S(gQS); // (TMA, TMA_M, head, batch)

auto tQs = btma_q.partition_D(sQ); // (TMA, _1, _1)
auto tKs = btma_k.partition_D(sK); // (TMA, _1, _1, kStage)
auto tVs = btma_v.partition_D(sV); // (TMA, _1, _1, kStage)
auto tQKSs = btma_qks.partition_D(sQKS); // (TMA, _1)
auto tQs = btma_q.partition_D(sQ); // (TMA, _1, _1)
auto tKs = btma_k.partition_D(sK); // (TMA, _1, _1, kStage)
auto tVs = btma_v.partition_D(sV); // (TMA, _1, _1, kStage)
auto tQSs = btma_qs.partition_D(sQS); // (TMA, _1)

TiledMmaQK tiled_mma_qk;
TiledMmaPV tiled_mma_pv;
Expand Down Expand Up @@ -1180,9 +1179,9 @@ __global__ void __launch_bounds__(384, 1)
// Load Q
wait_barrier(writable_q, phase_q);
cute::copy(tma_q.with(td_q, readable_q), tQg(_, itile_m, _, ihead_q), tQs(_, 0, _));
cute::copy(tma_qks.with(readable_q), tQKSg(_, itile_m, ihead_q, ibatch), tQKSs(_, 0));
cute::copy(tma_qs.with(readable_q), tQSg(_, itile_m, ihead_q, ibatch), tQSs(_, 0));
set_barrier_transaction_bytes(
readable_q, sizeof(Tin) * cosize(SLayoutQ{}) + sizeof(float) * cosize(SLayoutQKS{}));
readable_q, sizeof(Tin) * cosize(SLayoutQ{}) + sizeof(float) * cosize(SLayoutQS{}));
phase_q ^= 1;

int num_tile_kv = (start_seq_q + (itile_m + 1) * kTileM + kTileN - 1) / kTileN;
Expand Down Expand Up @@ -1270,7 +1269,8 @@ __global__ void __launch_bounds__(384, 1)
int phase = 0;
int phase_q = 0;

float tQKS[kM];
float tQS[kM];
float kscale = kscale_ptr[0];
float vscale = vscale_ptr[0];

while (true) {
Expand Down Expand Up @@ -1303,7 +1303,7 @@ __global__ void __launch_bounds__(384, 1)
auto tI_mn = retile_fragment(tI);
#pragma unroll
for (int im = 0; im < kM; im++) {
tQKS[im] = sQKS(get<0>(tI_mn(im, 0)));
tQS[im] = sQS(get<0>(tI_mn(im, 0))) * kscale;
}

#pragma unroll 1
Expand Down Expand Up @@ -1355,7 +1355,7 @@ __global__ void __launch_bounds__(384, 1)

auto tYr_mn = retile_fragment(tYr);
// online softmax
online_softmax_with_scale(tAttr_mn, gMax, gSum, tYr_mn, tQKS, kM, kN, one_over_dk_log2e);
online_softmax_with_scale(tAttr_mn, gMax, gSum, tYr_mn, tQS, kM, kN, one_over_dk_log2e);

// convert P to fp8 and permute for pv gemm
auto tAttr_float32x4 = recast<float4>(tAttr);
Expand Down
13 changes: 7 additions & 6 deletions src/attention/prefill/prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ void attention_with_kvcache_prefill_bf16_async(

void attention_with_kvcache_prefill_fp8_async(
void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr,
const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr,
const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch,
int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v,
int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks,
int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream) {
const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr,
const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr,
void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad,
int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks,
int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV,
cudaStream_t stream) {
prefill::warp_spec_with_kvcache_fp8_dim128_async(
y_ptr, q_ptr, kcache_ptr, vcache_ptr, qkscale_ptr, vscale_ptr, cu_seqlens_q_ptr,
y_ptr, q_ptr, kcache_ptr, vcache_ptr, qscale_ptr, kscale_ptr, vscale_ptr, cu_seqlens_q_ptr,
block_ids_ptr, seqlens_kvcache_ptr, tmas_ptr, num_batch, total_seq_q, max_seq_q,
max_seq_q_pad, num_dim_qk, num_dim_v, num_head_q, num_head_kv, num_kvcache_blocks, block_size,
num_seq_max_blocks, ldY, ldQ, ldK, ldV, stream);
Expand Down
11 changes: 6 additions & 5 deletions src/attention/prefill/prefill.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ void attention_with_kvcache_prefill_bf16_async(

void attention_with_kvcache_prefill_fp8_async(
void *y_ptr, const void *q_ptr, const void *kcache_ptr, const void *vcache_ptr,
const void *qkscale_ptr, const void *vscale_ptr, const void *cu_seqlens_q_ptr,
const void *block_ids_ptr, const void *seqlens_kvcache_ptr, void *tmas_ptr, int num_batch,
int total_seq_q, int max_seq_q, int max_seq_q_pad, int num_dim_qk, int num_dim_v,
int num_head_q, int num_head_kv, int num_kvcache_blocks, int block_size, int num_seq_max_blocks,
int ldY, int ldQ, int ldK, int ldV, cudaStream_t stream);
const void *qscale_ptr, const void *kscale_ptr, const void *vscale_ptr,
const void *cu_seqlens_q_ptr, const void *block_ids_ptr, const void *seqlens_kvcache_ptr,
void *tmas_ptr, int num_batch, int total_seq_q, int max_seq_q, int max_seq_q_pad,
int num_dim_qk, int num_dim_v, int num_head_q, int num_head_kv, int num_kvcache_blocks,
int block_size, int num_seq_max_blocks, int ldY, int ldQ, int ldK, int ldV,
cudaStream_t stream);

} // namespace attention
} // namespace hpc
Expand Down
Loading