diff --git a/hpc/rope.py b/hpc/rope.py new file mode 100644 index 0000000..c4fc863 --- /dev/null +++ b/hpc/rope.py @@ -0,0 +1,327 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor + + +def rope_norm_store_kv( + key_cache: Tensor, + value_cache: Tensor, + qkv: Tensor, + cos_sin: Tensor, + num_seqlen_per_req: Tensor, + q_index: Tensor, + kvcache_indices: Tensor, + is_prefill: bool, + q_norm_weight: Optional[Tensor] = None, + k_norm_weight: Optional[Tensor] = None, + out_q: Optional[Tensor] = None, + out_k: Optional[Tensor] = None, + out_v: Optional[Tensor] = None, + qk_norm_policy: int = 0, +) -> Tensor: + """Applies RoPE to Q/K, optionally applies QK RMSNorm, and writes K/V into a paged KV cache. + + This function fuses RoPE rotation, optional QK RMSNorm, and blocked KV-cache writes + into a single CUDA kernel pass, supporting both prefill and decode modes. + + Args: + key_cache: Paged key cache to be updated in-place. + Shape: [num_blocks, block_size, num_kv_heads, qk_head_dim] + Dtype: bfloat16 + value_cache: Paged value cache to be updated in-place. + Shape: [num_blocks, block_size, num_kv_heads, v_head_dim] + Dtype: bfloat16 + qkv: Packed Q/K/V input tensor. + Shape: [num_rows, num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + num_kv_heads * v_head_dim] + Dtype: bfloat16 + cos_sin: Precomputed RoPE cosine/sine table. + Shape: [max_seq_len, qk_head_dim] + Dtype: float32 + num_seqlen_per_req: Current total sequence length (including new tokens) for each request. + Shape: [num_req] + Dtype: int32 + q_index: Prefix-sum index of Q tokens across requests. + Shape: [num_req + 1] + Dtype: int32 + kvcache_indices: Physical block index table for paged KV cache addressing. + Shape: [num_req, max_blocks] + Dtype: int32 + is_prefill: Whether to run in prefill mode (True) or decode mode (False). + Shape: scalar + Dtype: bool + q_norm_weight: RMSNorm weight for Q. Required when qk_norm_policy != 0. 
+ Shape: [qk_head_dim] + Dtype: float32 + k_norm_weight: RMSNorm weight for K. Required when qk_norm_policy != 0. + Shape: [qk_head_dim] + Dtype: float32 + out_q: Optional pre-allocated output buffer for Q. + Shape: [num_rows, num_q_heads, qk_head_dim] + Dtype: bfloat16 + out_k: Optional output buffer for K. If provided, K is written here instead of key_cache. + Shape: [num_rows, num_kv_heads, qk_head_dim] + Dtype: bfloat16 + out_v: Optional output buffer for V. If provided, V is written here instead of value_cache. + Shape: [num_rows, num_kv_heads, v_head_dim] + Dtype: bfloat16 + qk_norm_policy: Controls whether RMSNorm is applied and its order relative to RoPE. + Shape: scalar + Dtype: int + - 0: No RMSNorm. + - 1: RoPE then RMSNorm. + - 2: RMSNorm then RoPE. + + Returns: + Tensor: Rotated (and optionally normalized) Q tensor. + Shape: [num_rows, num_q_heads, qk_head_dim] + Dtype: bfloat16 + + Raises: + RuntimeError: If the shapes or dtypes do not satisfy the constraints above. + """ + return torch.ops.hpc.rope_norm_store_kv( + key_cache, + value_cache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kvcache_indices, + is_prefill, + q_norm_weight, + k_norm_weight, + out_q, + out_k, + out_v, + qk_norm_policy, + ) + + +def rope_norm_store_kv_fp8( + key_cache: Tensor, + value_cache: Tensor, + qkv: Tensor, + cos_sin: Tensor, + num_seqlen_per_req: Tensor, + q_index: Tensor, + kvcache_indices: Tensor, + is_prefill: bool, + k_scale: Tensor, + v_scale: Tensor, + quant_policy: int, + max_seqlens: int = 0, + upper_max: Optional[float] = None, + q_scale_inv: Optional[Tensor] = None, + q_norm_weight: Optional[Tensor] = None, + k_norm_weight: Optional[Tensor] = None, + out_q: Optional[Tensor] = None, + out_k: Optional[Tensor] = None, + out_v: Optional[Tensor] = None, + qk_norm_policy: int = 0, +) -> Tuple[Tensor, Tensor, Tensor]: + """Applies RoPE to Q/K with FP8 quantization, optionally applies QK RMSNorm, and writes K/V into a paged FP8 KV cache. 
+ + Extends rope_norm_store_kv with FP8 quantization for Q output and KV cache storage, + supporting dynamic per-token per-head (dqskv) and static (sqskv) quantization policies. + + Args: + key_cache: Paged key cache to be updated in-place. + Shape: [num_blocks, block_size, num_kv_heads, qk_head_dim] + Dtype: float8_e4m3fn + value_cache: Paged value cache to be updated in-place. + Shape: [num_blocks, block_size, num_kv_heads, v_head_dim] + Dtype: float8_e4m3fn + qkv: Packed Q/K/V input tensor. + Shape: [num_rows, num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + num_kv_heads * v_head_dim] + Dtype: bfloat16 + cos_sin: Precomputed RoPE cosine/sine table. + Shape: [max_seq_len, qk_head_dim] + Dtype: float32 + num_seqlen_per_req: Current total sequence length (including new tokens) for each request. + Shape: [num_req] + Dtype: int32 + q_index: Prefix-sum index of Q tokens across requests. + Shape: [num_req + 1] + Dtype: int32 + kvcache_indices: Physical block index table for paged KV cache addressing. + Shape: [num_req, max_blocks] + Dtype: int32 + is_prefill: Whether to run in prefill mode (True) or decode mode (False). + Shape: scalar + Dtype: bool + k_scale: Static quantization scale for K. Per-tensor. + Shape: [1] + Dtype: float32 + v_scale: Static quantization scale for V. Per-tensor. + Shape: [1] + Dtype: float32 + quant_policy: Q quantization mode. K/V always use static scaling. + Shape: scalar + Dtype: int + - 1: dqskv — dynamic per-token per-head quantization; scale computed by the kernel + and written to the returned q_scale tensor. + - 2: sqskv — static quantization; uses the caller-supplied q_scale_inv. + max_seqlens: Maximum sequence length in the batch. Used to size the q_scale allocation + in prefill mode (padded to a multiple of 128). + Shape: scalar + Dtype: int + upper_max: FP8 saturation upper bound. Defaults to FP8_MAX (~448.0). + Shape: scalar + Dtype: float + q_scale_inv: Static scale reciprocal for Q. Required when quant_policy=2. 
+ Shape: [1] + Dtype: float32 + q_norm_weight: RMSNorm weight for Q. Required when qk_norm_policy != 0. + Shape: [qk_head_dim] + Dtype: float32 + k_norm_weight: RMSNorm weight for K. Required when qk_norm_policy != 0. + Shape: [qk_head_dim] + Dtype: float32 + out_q: Optional pre-allocated output buffer for Q. + Shape: [num_rows, num_q_heads, qk_head_dim] + Dtype: float8_e4m3fn + out_k: Optional output buffer for K. If provided, K is written here instead of key_cache. + Shape: [num_rows, num_kv_heads, qk_head_dim] + Dtype: float8_e4m3fn + out_v: Optional output buffer for V. If provided, V is written here instead of value_cache. + Shape: [num_rows, num_kv_heads, v_head_dim] + Dtype: float8_e4m3fn + qk_norm_policy: Controls whether RMSNorm is applied and its order relative to RoPE. + Shape: scalar + Dtype: int + - 0: No RMSNorm. + - 1: RoPE then RMSNorm. + - 2: RMSNorm then RoPE. + + Returns: + Tuple of: + - out_q_fp8 (Tensor): Rotated (and optionally normalized) Q tensor quantized to FP8. + Shape: [num_rows, num_q_heads, qk_head_dim] + Dtype: float8_e4m3fn + - q_scale (Tensor): Dynamic per-token per-head Q scale (dqskv only). + Prefill shape: [num_req, num_q_heads, max_seqlens_pad128]; Decode shape: [num_rows, num_q_heads]. + Empty tensor when quant_policy=2. + Dtype: float32 + - split_k_flag (Tensor): Per-request per-KV-head flag zeroed by the kernel, used by downstream attention. + Shape: [num_req, num_kv_heads] + Dtype: int32 + + Raises: + RuntimeError: If the shapes or dtypes do not satisfy the constraints above. 
+ """ + return torch.ops.hpc.rope_norm_store_kv_fp8( + key_cache, + value_cache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kvcache_indices, + is_prefill, + k_scale, + v_scale, + quant_policy, + max_seqlens, + upper_max, + q_scale_inv, + q_norm_weight, + k_norm_weight, + out_q, + out_k, + out_v, + qk_norm_policy, + ) + + +@torch.library.register_fake("hpc::rope_norm_store_kv") +def rope_norm_store_kv_fake( + key_cache, + value_cache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kvcache_indices, + is_prefill, + q_norm_weight, + k_norm_weight, + out_q, + out_k, + out_v, + qk_norm_policy, +): + hidden_size = qkv.shape[-1] + kv_heads = key_cache.shape[-2] + qk_head_dim = key_cache.shape[-1] + v_head_dim = value_cache.shape[-1] + q_heads = (hidden_size - kv_heads * qk_head_dim - kv_heads * v_head_dim) // qk_head_dim + num_rows = qkv.shape[0] + return torch.empty(num_rows, q_heads, qk_head_dim, dtype=qkv.dtype, device=qkv.device) + + +@torch.library.register_fake("hpc::rope_norm_store_kv_fp8") +def rope_norm_store_kv_fp8_fake( + key_cache, + value_cache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kvcache_indices, + is_prefill, + k_scale, + v_scale, + quant_policy, + max_seqlens, + upper_max, + q_scale_inv, + q_norm_weight, + k_norm_weight, + out_q, + out_k, + out_v, + qk_norm_policy, +): + num_rows = qkv.shape[0] + qk_dim = key_cache.shape[-1] + kv_heads = key_cache.shape[-2] + v_dim = value_cache.shape[-1] + num_req = num_seqlen_per_req.shape[0] + q_heads = (qkv.shape[-1] - kv_heads * qk_dim - kv_heads * v_dim) // qk_dim + + out_q_fp8 = torch.empty( + num_rows, + q_heads, + qk_dim, + dtype=torch.float8_e4m3fn, + device=qkv.device, + ) + + if quant_policy == 1: # dq skv + if is_prefill: + aligned = ((max_seqlens + 127) // 128) * 128 + q_scale = torch.empty( + num_req, + q_heads, + aligned, + dtype=torch.float32, + device=qkv.device, + ) + else: + q_scale = torch.empty( + num_rows, + q_heads, + dtype=torch.float32, + device=qkv.device, + ) + 
else: + q_scale = None + + split_k_flag = torch.empty( + num_req, + kv_heads, + dtype=torch.int32, + device=qkv.device, + ) + return (out_q_fp8, q_scale, split_k_flag) diff --git a/src/rope/entry.cc b/src/rope/entry.cc new file mode 100644 index 0000000..7468e3c --- /dev/null +++ b/src/rope/entry.cc @@ -0,0 +1,240 @@ +// Copyright (C) 2026 Tencent. + +#include +#include +#include +#include + +#include +#include + +#include "src/rope/rope.h" + +namespace hpc { +namespace rope { + +torch::Tensor rope_norm_store_kv_entry( + torch::Tensor &kcache, torch::Tensor &vcache, const torch::Tensor &qkv, + const torch::Tensor &cos_sin, const torch::Tensor &num_seqlen_per_req, + const torch::Tensor &q_index, const torch::Tensor &kvcache_indices, bool is_prefill, + std::optional q_norm_weight_opt, std::optional k_norm_weight_opt, + std::optional out_q_opt, std::optional out_k_opt, + std::optional out_v_opt, int64_t qk_norm_policy) { + auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + TORCH_CHECK(qkv.is_contiguous(), "qkv tensor must be contiguous"); + TORCH_CHECK(cos_sin.is_contiguous(), "cos_sin tensor must be contiguous"); + TORCH_CHECK(num_seqlen_per_req.is_contiguous(), "num_seqlen_per_req tensor must be contiguous"); + TORCH_CHECK(kvcache_indices.is_contiguous(), "kvcache_indices tensor must be contiguous"); + + TORCH_CHECK(qk_norm_policy >= 0 && qk_norm_policy <= 2, "qk_norm_policy must be 0, 1 or 2"); + + // Get dimensions + int num_req = num_seqlen_per_req.size(0); + int num_rows = qkv.size(0); + int num_kv_heads = kcache.size(2); + int qk_head_dim = kcache.size(3); + int v_head_dim = vcache.size(3); + int hidden_size = qkv.size(1); + int num_q_heads = + (hidden_size - num_kv_heads * qk_head_dim - num_kv_heads * v_head_dim) / qk_head_dim; + int kv_block_size = kcache.size(1); + int max_num_kv_block_per_batch = kvcache_indices.size(1); + int kcache_block_offset = kcache.stride(0); + int vcache_block_offset = vcache.stride(0); + + // Create output tensors + 
using DType = __nv_bfloat16; + torch::Tensor out_q; + if (out_q_opt.has_value()) { + out_q = out_q_opt.value(); + TORCH_CHECK(out_q.is_contiguous(), "out_q tensor must be contiguous"); + } else { + out_q = torch::empty({num_rows, num_q_heads, qk_head_dim}, + torch::dtype(qkv.dtype()).device(qkv.device())); + } + + DType *out_k_ptr = nullptr; + if (out_k_opt.has_value()) { + TORCH_CHECK(out_k_opt.value().is_contiguous(), "out_k tensor must be contiguous"); + out_k_ptr = reinterpret_cast(out_k_opt.value().mutable_data_ptr()); + } + + DType *out_v_ptr = nullptr; + if (out_v_opt.has_value()) { + auto out_v = out_v_opt.value(); + TORCH_CHECK(out_v.is_contiguous(), "out_v tensor must be contiguous"); + out_v_ptr = reinterpret_cast(out_v.mutable_data_ptr()); + } + + const float *q_norm_weight_ptr = nullptr; + const float *k_norm_weight_ptr = nullptr; + if (q_norm_weight_opt.has_value()) { + TORCH_CHECK(q_norm_weight_opt.value().scalar_type() == torch::kFloat); + q_norm_weight_ptr = q_norm_weight_opt.value().const_data_ptr(); + } + if (k_norm_weight_opt.has_value()) { + TORCH_CHECK(k_norm_weight_opt.value().scalar_type() == torch::kFloat); + k_norm_weight_ptr = k_norm_weight_opt.value().const_data_ptr(); + } + + rope_norm_store_kv_async( + reinterpret_cast(out_q.mutable_data_ptr()), + reinterpret_cast(kcache.mutable_data_ptr()), + reinterpret_cast(vcache.mutable_data_ptr()), out_k_ptr, out_v_ptr, + reinterpret_cast(qkv.const_data_ptr()), cos_sin.const_data_ptr(), + num_seqlen_per_req.const_data_ptr(), q_index.const_data_ptr(), + kvcache_indices.const_data_ptr(), q_norm_weight_ptr, k_norm_weight_ptr, + kcache_block_offset, vcache_block_offset, num_req, max_num_kv_block_per_batch, kv_block_size, + num_rows, num_q_heads, num_kv_heads, qk_head_dim, v_head_dim, is_prefill, qk_norm_policy, + stream); + + return out_q; +} + +std::tuple rope_norm_store_kv_fp8_entry( + torch::Tensor &kcache, torch::Tensor &vcache, const torch::Tensor &qkv, + const torch::Tensor &cos_sin, const 
torch::Tensor &num_seqlen_per_req, + const torch::Tensor &q_index, const torch::Tensor &kvcache_indices, bool is_prefill, + const torch::Tensor &k_scale, const torch::Tensor &v_scale, int64_t quant_policy, + int64_t max_seqlens, std::optional upper_max_double, + std::optional q_scale_inv_opt, std::optional q_norm_weight_opt, + std::optional k_norm_weight_opt, std::optional out_q_opt, + std::optional out_k_opt, std::optional out_v_opt, + int64_t qk_norm_policy) { + auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + TORCH_CHECK(qkv.is_contiguous(), "qkv tensor must be contiguous"); + TORCH_CHECK(cos_sin.is_contiguous(), "cos_sin tensor must be contiguous"); + TORCH_CHECK(num_seqlen_per_req.is_contiguous(), "num_seqlen_per_req tensor must be contiguous"); + TORCH_CHECK(kvcache_indices.is_contiguous(), "kvcache_indices tensor must be contiguous"); + TORCH_CHECK(k_scale.dim() == 1 && k_scale.size(0) == 1, "k_scale must contain 1 element"); + TORCH_CHECK(v_scale.dim() == 1 && v_scale.size(0) == 1, "v_scale must contain 1 element"); + TORCH_CHECK(quant_policy == 1 || quant_policy == 2, "quant_policy must be 1 or 2"); + TORCH_CHECK(qkv.scalar_type() == torch::kBFloat16, "qkv must be bfloat16"); + TORCH_CHECK(kcache.dtype().itemsize() == 1, "kcache must be 1-byte dtype"); + TORCH_CHECK(vcache.dtype().itemsize() == 1, "vcache must be 1-byte dtype"); + + TORCH_CHECK(qk_norm_policy >= 0 && qk_norm_policy <= 2, "qk_norm_policy must be 0, 1 or 2"); + + using DType = __nv_bfloat16; + using QType = __nv_fp8_e4m3; + + int num_req = num_seqlen_per_req.size(0); + int num_rows = qkv.size(0); + int num_kv_heads = kcache.size(2); + int qk_head_dim = kcache.size(3); + int v_head_dim = vcache.size(3); + int hidden_size = qkv.size(1); + int num_q_heads = + (hidden_size - num_kv_heads * qk_head_dim - num_kv_heads * v_head_dim) / qk_head_dim; + int kv_block_size = kcache.size(1); + int max_num_kv_block_per_batch = kvcache_indices.size(1); + int kcache_block_offset = 
kcache.stride(0); + int vcache_block_offset = vcache.stride(0); + + float upper_max = static_cast(QType(1000.f)); + if (upper_max_double.has_value()) { + float in_upper_max = static_cast(upper_max_double.value()); + TORCH_CHECK(!(in_upper_max > upper_max), "upper_max should not be larger than fp8_max"); + upper_max = in_upper_max; + } + + // out_q + torch::Tensor out_q; + if (out_q_opt.has_value()) { + out_q = out_q_opt.value(); + TORCH_CHECK(out_q.is_contiguous() && out_q.scalar_type() == torch::kFloat8_e4m3fn); + } else { + out_q = torch::empty({num_rows, num_q_heads, qk_head_dim}, + torch::dtype(torch::kFloat8_e4m3fn).device(qkv.device())); + } + + // q_scale: dqskv allocates real storage, sqskv gets an empty tensor + torch::Tensor q_scale; + float *q_scale_ptr = nullptr; + int max_seqlens_pad128 = 0; + if (quant_policy == 1) { + if (is_prefill) { + max_seqlens_pad128 = ((max_seqlens + 127) / 128) * 128; + q_scale = torch::empty({num_req, num_q_heads, max_seqlens_pad128}, + torch::dtype(torch::kFloat).device(qkv.device())); + } else { + q_scale = + torch::empty({num_rows, num_q_heads}, torch::dtype(torch::kFloat).device(qkv.device())); + } + q_scale_ptr = q_scale.mutable_data_ptr(); + } + + // split_k_flag + torch::Tensor split_k_flag = + torch::empty({num_req, num_kv_heads}, torch::dtype(torch::kInt32).device(qkv.device())); + + // out_k, out_v (nullable bypass) + QType *out_k_ptr = nullptr; + QType *out_v_ptr = nullptr; + if (out_k_opt.has_value()) { + auto out_k = out_k_opt.value(); + TORCH_CHECK(out_k.is_contiguous() && out_k.scalar_type() == torch::kFloat8_e4m3fn); + out_k_ptr = reinterpret_cast(out_k.mutable_data_ptr()); + } + if (out_v_opt.has_value()) { + auto out_v = out_v_opt.value(); + TORCH_CHECK(out_v.is_contiguous() && out_v.scalar_type() == torch::kFloat8_e4m3fn); + out_v_ptr = reinterpret_cast(out_v.mutable_data_ptr()); + } + + const float *q_norm_weight_ptr = nullptr; + const float *k_norm_weight_ptr = nullptr; + if 
(q_norm_weight_opt.has_value()) { + TORCH_CHECK(q_norm_weight_opt.value().scalar_type() == torch::kFloat); + q_norm_weight_ptr = q_norm_weight_opt.value().const_data_ptr(); + } + if (k_norm_weight_opt.has_value()) { + TORCH_CHECK(k_norm_weight_opt.value().scalar_type() == torch::kFloat); + k_norm_weight_ptr = k_norm_weight_opt.value().const_data_ptr(); + } + + const float *q_scale_inv_ptr = nullptr; + if (quant_policy == 2) { + TORCH_CHECK(q_scale_inv_opt.has_value(), "q_scale_inv required for quant_policy=2"); + TORCH_CHECK(q_scale_inv_opt.value().scalar_type() == torch::kFloat); + q_scale_inv_ptr = q_scale_inv_opt.value().const_data_ptr(); + } + + rope_norm_store_kv_fp8_async( + reinterpret_cast(out_q.mutable_data_ptr()), + reinterpret_cast(kcache.mutable_data_ptr()), + reinterpret_cast(vcache.mutable_data_ptr()), out_k_ptr, out_v_ptr, + split_k_flag.mutable_data_ptr(), q_scale_ptr, + reinterpret_cast(qkv.const_data_ptr()), cos_sin.const_data_ptr(), + num_seqlen_per_req.const_data_ptr(), q_index.const_data_ptr(), + kvcache_indices.const_data_ptr(), q_norm_weight_ptr, k_norm_weight_ptr, + k_scale.const_data_ptr(), v_scale.const_data_ptr(), q_scale_inv_ptr, upper_max, + max_seqlens, kcache_block_offset, vcache_block_offset, num_req, max_num_kv_block_per_batch, + kv_block_size, num_rows, num_q_heads, num_kv_heads, qk_head_dim, v_head_dim, is_prefill, + qk_norm_policy, quant_policy, stream); + + return std::make_tuple(out_q, q_scale, split_k_flag); +} + +} // namespace rope +} // namespace hpc + +TORCH_LIBRARY_FRAGMENT(hpc, m) { + m.def( + "rope_norm_store_kv(Tensor! kcache, Tensor! vcache, Tensor qkv, Tensor cos_sin, " + "Tensor num_seqlen_per_req, Tensor q_index, Tensor kvcache_indices, bool is_prefill, " + "Tensor? q_norm_weight, Tensor? k_norm_weight, " + "Tensor? out_q=None, Tensor? out_k=None, Tensor? 
out_v=None, int qk_norm_policy=0) -> "
      "Tensor");
  m.impl("rope_norm_store_kv", torch::kCUDA, &hpc::rope::rope_norm_store_kv_entry);

  m.def(
      "rope_norm_store_kv_fp8(Tensor! kcache, Tensor! vcache, Tensor qkv, "
      "Tensor cos_sin, Tensor num_seqlen_per_req, Tensor q_index, Tensor kvcache_indices, "
      "bool is_prefill, Tensor k_scale, Tensor v_scale, "
      "int quant_policy, int max_seqlens, float? upper_max, Tensor? q_scale_inv, "
      "Tensor? q_norm_weight, Tensor? k_norm_weight, "
      "Tensor? out_q=None, Tensor? out_k=None, Tensor? out_v=None, int qk_norm_policy=0) -> "
      "(Tensor, Tensor, Tensor)");
  m.impl("rope_norm_store_kv_fp8", torch::kCUDA, &hpc::rope::rope_norm_store_kv_fp8_entry);
}
diff --git a/src/rope/rope.cu b/src/rope/rope.cu
new file mode 100644
index 0000000..51aa945
--- /dev/null
+++ b/src/rope/rope.cu
@@ -0,0 +1,802 @@
// Copyright (C) 2026 Tencent.

// NOTE(review): the bracketed system-header names and all angle-bracket
// template parameter lists / vec_t element arguments in this file were
// stripped in transport — restore them from the original source before
// compiling. Code tokens below are otherwise unchanged; only comments and
// line formatting were added.
#include
#include
#include
#include

#include

#include "cutlass/fast_math.h"
#include "src/rope/rope.h"
#include "src/utils/utils.cuh"

namespace hpc {
namespace rope {

namespace kernels {

// Compile-time ceiling division of two positive integer template parameters.
template
__device__ __forceinline__ constexpr int ceil_div() {
  static_assert(kDenominator > 0, "denominator must >0");
  return (kNumerator + kDenominator - 1) / kDenominator;
}

// Epsilon used both inside the RMSNorm denominator and as the abs-max floor.
constexpr float kEps = 1e-6f;

/// In-place rotate a pair of RoPE elements (NeoX version)
__device__ __forceinline__ void rope_rotate_pair(float &x1, float &x2, float cos_val,
                                                 float sin_val) {
  float y1 = x1 * cos_val - x2 * sin_val;
  float y2 = x2 * cos_val + x1 * sin_val;
  x1 = y1;
  x2 = y2;
}

/// RMSNorm in-place: compute RMS over register values, apply weight from shared memory.
/// Each lane holds kNumItemPerThread values; the squared sum is reduced across
/// the warp (warp_reduce_sum_xor) so every lane sees the same inv_rms.
/// Registers are laid out as (even, odd) = (first-half, second-half) element
/// pairs of the head, matching the NeoX split used by rope_rotate_pair.
template
__device__ __forceinline__ void rms_norm_apply(vec_t &data, const float *smem_weight,
                                               int ilane) {
  float sum_sq = 0.f;
#pragma unroll
  for (int i = 0; i < kNumItemPerThread; ++i) sum_sq += data[i] * data[i];
  sum_sq = warp_reduce_sum_xor(sum_sq);
  float inv_rms = rsqrtf(sum_sq / kHeadDim + kEps);
  constexpr int kRoundsHalf = (kHeadDim / 2 + kWarpSize - 1) / kWarpSize;
#pragma unroll
  for (int r = 0; r < kRoundsHalf; ++r) {
    int i = r * kWarpSize + ilane;
    if (i < kHeadDim / 2) {
      data[r * 2] *= inv_rms * smem_weight[i];
      data[r * 2 + 1] *= inv_rms * smem_weight[i + kHeadDim / 2];
    }
  }
}

/// Warp-level max absolute value (floored at kEps to avoid a zero scale).
template
__device__ __forceinline__ float warp_abs_max(vec_t &data) {
  float m = kEps;
#pragma unroll
  for (int i = 0; i < kN; ++i) m = fmaxf(m, fabsf(data[i]));
  return warp_reduce_max_xor(m);
}

/// Zero rows [from_row, to_row) of a KV cache block with one warp.
/// NOTE(review): the kernels below inline their own multi-warp zeroing loop
/// instead of calling this helper; presumably it serves other call sites in
/// the non-visible part of this file — confirm before removing.
template
__device__ __forceinline__ void zero_kv_rows(CacheT *block_start, int from_row, int to_row,
                                             int ilane) {
  constexpr int kItemPerThread = 16 / sizeof(CacheT);
  vec_t zero_vec;
#pragma unroll
  for (int i = 0; i < kItemPerThread; ++i) zero_vec[i] = CacheT(0);
  for (int row = from_row; row < to_row; ++row) {
    CacheT *row_ptr = block_start + row * kElemPerRow;
    for (int idx = ilane * kItemPerThread; idx < kElemPerRow; idx += kWarpSize * kItemPerThread)
      store(row_ptr + idx, zero_vec);
  }
}

// Fused RoPE + optional QK RMSNorm + paged KV-cache store (bf16 path).
// Grid layout: blocks [0, num_compute_blocks) process tokens, one warp per
// qkv row; blocks [num_compute_blocks, num_compute_blocks + num_batch) zero
// the unused tail rows of each request's last KV block.
template
__global__ void rope_norm_store_kv_kernel(
    __nv_bfloat16 *out_q_ptr, __nv_bfloat16 *kcache_ptr, __nv_bfloat16 *vcache_ptr,
    __nv_bfloat16 *out_k_ptr, __nv_bfloat16 *out_v_ptr, const __nv_bfloat16 *in_qkv_ptr,
    const float *cos_sin_ptr, const int *num_seqlen_per_req_ptr, const int *q_index_ptr,
    const int *kvcache_indices_ptr, const float *q_norm_weight_ptr, const float *k_norm_weight_ptr,
    int kcache_block_offset, int vcache_block_offset, int num_batch, int max_num_kv_block_per_batch,
    cutlass::FastDivmod kv_block_size_divider, int num_rows, int num_compute_blocks) {
  using DType = __nv_bfloat16;

  constexpr int kWarpSize = 32;
  // Width of one packed qkv row: Q heads + K heads + V heads.
  constexpr int kNumElemPerRow =
      kNumQHeads * kQKHeadDim + kNumKVHeads * kQKHeadDim + kNumKVHeads * kVHeadDim;
  // Each lane covers kNumRoundsHalf (first-half, second-half) element pairs.
  constexpr int kNumRoundsHalf = ceil_div();
  constexpr int kNumItemPerThread = kNumRoundsHalf * 2;

  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int iwarp = tid / kWarpSize;
  int ilane = tid % kWarpSize;

  // Per-warp staged cos/sin row, block-wide norm weights, and per-warp
  // (batch_id, absolute token position) resolved from q_index.
  __shared__ float smem_cos_sin[kWarpsPerBlock][kQKHeadDim];
  __shared__ float smem_q_norm_w[kQKHeadDim];
  __shared__ float smem_k_norm_w[kQKHeadDim];
  __shared__ int smem_batch_id[kWarpsPerBlock];
  __shared__ int smem_token_pos[kWarpsPerBlock];

  // ---- Clear blocks: bid >= num_compute_blocks → one block per request -----
  if (bid >= num_compute_blocks) {
    int req_id = bid - num_compute_blocks;
    if (req_id >= num_batch) return;

    // Last token of this request determines the clear range
    int last_token_pos = num_seqlen_per_req_ptr[req_id] - 1;
    if (last_token_pos < 0) return;

    int block_idx_in_batch, pos_in_block;
    kv_block_size_divider(block_idx_in_batch, pos_in_block, last_token_pos);
    int phys_block_id =
        kvcache_indices_ptr[req_id * max_num_kv_block_per_batch + block_idx_in_batch];

    // Zero rows past the last written token, up to the end of the block.
    int zero_from = pos_in_block + 1;
    int zero_to = kv_block_size_divider.divisor;
    if (zero_from < zero_to) {
      // Use all kWarpsPerBlock warps cooperatively to zero rows
      for (int row = zero_from + iwarp; row < zero_to; row += kWarpsPerBlock) {
        DType *k_row = kcache_ptr + (int64_t)phys_block_id * (int64_t)kcache_block_offset +
                       row * (kNumKVHeads * kQKHeadDim);
        DType *v_row = vcache_ptr + (int64_t)phys_block_id * (int64_t)vcache_block_offset +
                       row * (kNumKVHeads * kVHeadDim);
        constexpr int kKItemPerThread = 16 / sizeof(DType);
        vec_t zero_vec;
#pragma unroll
        for (int z = 0; z < kKItemPerThread; ++z) zero_vec[z] = DType(0);
        for (int idx = ilane * kKItemPerThread; idx < kNumKVHeads * kQKHeadDim;
             idx += kWarpSize * kKItemPerThread)
          store(k_row + idx, zero_vec);
        for (int idx = ilane * kKItemPerThread; idx < kNumKVHeads * kVHeadDim;
             idx += kWarpSize * kKItemPerThread)
          store(v_row + idx, zero_vec);
      }
    }
    return;
  }

  // Search q_index to find batch_id
  int batch_id = 0;
  int token_id = 0;
  int irow = bid * kWarpsPerBlock + iwarp;

  // First kWarpsPerBlock threads do the q_index search for the whole block
  if (tid < kWarpsPerBlock) {
    int global_row = bid * kWarpsPerBlock + tid;
    if (global_row < num_rows) {
      int b = -1;
      // Linear scan over the prefix sums; row belongs to the first request
      // whose exclusive upper bound exceeds it.
      for (int i = 0; i < num_batch; ++i) {
        if (global_row < q_index_ptr[i + 1]) {
          b = i;
          break;
        }
      }
      if (b >= 0) {
        smem_batch_id[tid] = b;
        // Absolute position in the sequence: the request's total length minus
        // the offset of this row from the end of the request's token span.
        smem_token_pos[tid] = global_row + num_seqlen_per_req_ptr[b] - q_index_ptr[b + 1];
      } else {
        // Padding row: global_row >= q_index[num_batch] (CUDA graph padding)
        smem_batch_id[tid] = -1;
        smem_token_pos[tid] = -1;
      }
    } else {
      smem_batch_id[tid] = -1;
      smem_token_pos[tid] = -1;
    }
  }

  // Load norm weights into shared memory (once per block)
  if constexpr (kNormPolicy > 0) {
    constexpr int kItemPerThread = 16 / sizeof(float);
    constexpr int kNumPacks = kQKHeadDim / kItemPerThread;
    static_assert(kQKHeadDim % kItemPerThread == 0,
                  "kQKHeadDim must be divisible by kItemPerThread");
    static_assert(kItemPerThread * kWarpSize >= kQKHeadDim, "otherwise here should loop");
    if (tid < kNumPacks) {
      int ioffset = tid * kItemPerThread;
      store(smem_q_norm_w + ioffset, load(q_norm_weight_ptr + ioffset));
      store(smem_k_norm_w + ioffset, load(k_norm_weight_ptr + ioffset));
    }
  }

  __syncthreads();

  // Early-exit for invalid rows
  if (irow >= num_rows) return;
  batch_id = smem_batch_id[iwarp];
  token_id = smem_token_pos[iwarp];
  if (token_id < 0) return;

  // Load cos_sin row for this token into the warp's shared slot.
  {
    constexpr int kItemPerThread = 16 / sizeof(float);
    constexpr int kNumPacks = kQKHeadDim / kItemPerThread;
    static_assert(kQKHeadDim % kItemPerThread == 0, "");
    static_assert(kNumPacks <= kWarpSize, "");
    const float *cos_sin_row = cos_sin_ptr + token_id * kQKHeadDim;
    if (ilane < kNumPacks) {
      int ioffset = ilane * kItemPerThread;
      store(&smem_cos_sin[iwarp][0] + ioffset, load(cos_sin_row + ioffset));
    }
    __syncwarp();
  }

  // KV cache block addressing: token position → (logical block, row in block),
  // then translate the logical block through the per-request indices table.
  int block_idx_in_batch, block_row;
  kv_block_size_divider(block_idx_in_batch, block_row, token_id);
  int phys_block_id =
      kvcache_indices_ptr[batch_id * max_num_kv_block_per_batch + block_idx_in_batch];

  DType *k_cache_row_start = kcache_ptr + (int64_t)phys_block_id * (int64_t)kcache_block_offset +
                             block_row * (kNumKVHeads * kQKHeadDim);
  DType *v_cache_row_start = vcache_ptr + (int64_t)phys_block_id * (int64_t)vcache_block_offset +
                             block_row * (kNumKVHeads * kVHeadDim);

  const DType *qkv_row = in_qkv_ptr + irow * kNumElemPerRow;

  // Process Q heads – load from global, optional norm, RoPE, optional norm, store
#pragma unroll
  for (int q_head = 0; q_head < kNumQHeads; ++q_head) {
    const DType *q_src = qkv_row + q_head * kQKHeadDim;
    DType *q_dst = out_q_ptr + irow * kNumQHeads * kQKHeadDim + q_head * kQKHeadDim;

    vec_t data = {0};

    // Load Q head from global memory directly into registers
#pragma unroll
    for (int r = 0; r < kNumRoundsHalf; ++r) {
      int i = r * kWarpSize + ilane;
      if (i < kQKHeadDim / 2) {
        data[r * 2] = __bfloat162float(q_src[i]);
        data[r * 2 + 1] = __bfloat162float(q_src[i + kQKHeadDim / 2]);
      }
    }

    // norm-then-rope
    if constexpr (kNormPolicy == 2) {
      rms_norm_apply(data, smem_q_norm_w, ilane);
    }

    // RoPE rotation
#pragma unroll
    for (int r = 0; r < kNumRoundsHalf; ++r) {
      int i = r * kWarpSize + ilane;
      if (i < kQKHeadDim / 2) {
        rope_rotate_pair(data[r * 2], data[r * 2 + 1], smem_cos_sin[iwarp][i],
                         smem_cos_sin[iwarp][i + kQKHeadDim / 2]);
      }
    }

    // rope-then-norm
    if constexpr (kNormPolicy == 1) {
      rms_norm_apply(data, smem_q_norm_w, ilane);
    }

    // Store Q output
#pragma unroll
    for (int r = 0; r < kNumRoundsHalf; ++r) {
      int i = r * kWarpSize + ilane;
      if (i < kQKHeadDim / 2) {
        q_dst[i] = __float2bfloat16(data[r * 2]);
        q_dst[i + kQKHeadDim / 2] = __float2bfloat16(data[r * 2 + 1]);
      }
    }
  }

  // Process K heads – load from global, optional norm, RoPE, optional norm,
  // write to KV cache (or out_k_ptr if non-null)
#pragma unroll
  for (int kv_head = 0; kv_head < kNumKVHeads; ++kv_head) {
    const DType *k_src = qkv_row + kNumQHeads * kQKHeadDim + kv_head * kQKHeadDim;

    vec_t data = {0};

#pragma unroll
    for (int r = 0; r < kNumRoundsHalf; ++r) {
      int i = r * kWarpSize + ilane;
      if (i < kQKHeadDim / 2) {
        data[r * 2] = __bfloat162float(k_src[i]);
        data[r * 2 + 1] = __bfloat162float(k_src[i + kQKHeadDim / 2]);
      }
    }

    if constexpr (kNormPolicy == 2) {
      rms_norm_apply(data, smem_k_norm_w, ilane);
    }

#pragma unroll
    for (int r = 0; r < kNumRoundsHalf; ++r) {
      int i = r * kWarpSize + ilane;
      if (i < kQKHeadDim / 2) {
        rope_rotate_pair(data[r * 2], data[r * 2 + 1], smem_cos_sin[iwarp][i],
                         smem_cos_sin[iwarp][i + kQKHeadDim / 2]);
      }
    }

    if constexpr (kNormPolicy == 1) {
      rms_norm_apply(data, smem_k_norm_w, ilane);
    }

    // Write K output
    DType *k_dst = (out_k_ptr != nullptr)
                       ? out_k_ptr + irow * kNumKVHeads * kQKHeadDim + kv_head * kQKHeadDim
                       : k_cache_row_start + kv_head * kQKHeadDim;

#pragma unroll
    for (int r = 0; r < kNumRoundsHalf; ++r) {
      int i = r * kWarpSize + ilane;
      if (i < kQKHeadDim / 2) {
        k_dst[i] = __float2bfloat16(data[r * 2]);
        k_dst[i + kQKHeadDim / 2] = __float2bfloat16(data[r * 2 + 1]);
      }
    }
  }

  // Process V heads – no RoPE; plain vectorized copy into the cache (or out_v).
  {
    constexpr int kNumVElemPerRow = kNumKVHeads * kVHeadDim;
    constexpr int kItemPerThread = 16 / sizeof(DType);
    static_assert(kNumVElemPerRow % kItemPerThread == 0,
                  "kNumKVHeads * kVHeadDim must be multiple of kItemPerThread");
    constexpr int kNumPackPerRow = kNumVElemPerRow / kItemPerThread;

    const DType *v_src = qkv_row + (kNumQHeads + kNumKVHeads) * kQKHeadDim;
    DType *v_dst =
        (out_v_ptr != nullptr) ? out_v_ptr + irow * kNumKVHeads * kVHeadDim : v_cache_row_start;

    constexpr int kNumLoadRound = ceil_div();
#pragma unroll
    for (int r = 0; r < kNumLoadRound; ++r) {
      int ioffset = (r * kWarpSize + ilane) * kItemPerThread;
      if (ioffset < kNumVElemPerRow) {
        store(v_dst + ioffset, load(v_src + ioffset));
      }
    }
  }
}

// FP8 variant of the kernel above: same grid layout and token addressing,
// plus Q quantization (dynamic or static) and fp8 cache writes.
// NOTE(review): this definition runs past the end of the visible chunk — the
// remainder below the last line here must be taken from the original source.
template
__global__ void rope_norm_store_kv_fp8_kernel(
    __nv_fp8_e4m3 *out_q_ptr, __nv_fp8_e4m3 *kcache_ptr, __nv_fp8_e4m3 *vcache_ptr,
    __nv_fp8_e4m3 *out_k_ptr, __nv_fp8_e4m3 *out_v_ptr, int32_t *split_k_flag_ptr,
    float *q_scale_ptr, const __nv_bfloat16 *in_qkv_ptr, const float *cos_sin_ptr,
    const int *num_seqlen_per_req_ptr, const int *q_index_ptr, const int *kvcache_indices_ptr,
    const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, const float *k_scale_ptr,
    const float *v_scale_ptr, const float *q_scale_inv_ptr, float upper_max, int max_seqlen_aligned,
    int kcache_block_offset, int vcache_block_offset, int num_batch, int max_num_kv_block_per_batch,
    cutlass::FastDivmod kv_block_size_divider, int num_rows, int num_compute_blocks,
    bool is_prefill) {
  using DType = __nv_bfloat16;
  using QType = __nv_fp8_e4m3;

  constexpr int kWarpSize = 32;
  constexpr int kNumElemPerRow =
      kNumQHeads * kQKHeadDim + kNumKVHeads * kQKHeadDim + kNumKVHeads * kVHeadDim;
  constexpr int kNumRoundsHalf = ceil_div();
  constexpr int kNumItemPerThread = kNumRoundsHalf * 2;

  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int iwarp = tid / kWarpSize;
  int ilane = tid % kWarpSize;

  // Shared memory
  __shared__ float smem_cos_sin[kWarpsPerBlock][kQKHeadDim];
  __shared__ float smem_q_norm_w[kQKHeadDim];
  __shared__ float smem_k_norm_w[kQKHeadDim];
  __shared__ int smem_batch_id[kWarpsPerBlock];
  __shared__ int smem_token_pos[kWarpsPerBlock];

  // ---- Clear blocks: bid >= num_compute_blocks → one block per request -----
  if (bid >= num_compute_blocks) {
    int req_id = bid - num_compute_blocks;
    if (req_id >= num_batch) return;

    int last_token_pos = num_seqlen_per_req_ptr[req_id] - 1;
    if (last_token_pos < 0) return;

    int block_idx_in_batch, pos_in_block;
    kv_block_size_divider(block_idx_in_batch, pos_in_block, last_token_pos);
    int phys_block_id =
        kvcache_indices_ptr[req_id * max_num_kv_block_per_batch + block_idx_in_batch];

    int zero_from = pos_in_block + 1;
    int zero_to = kv_block_size_divider.divisor;
    if (zero_from < zero_to) {
      for (int row = zero_from + iwarp; row < zero_to; row += kWarpsPerBlock) {
        QType *k_row = kcache_ptr + (int64_t)phys_block_id * (int64_t)kcache_block_offset +
                       row * (kNumKVHeads * kQKHeadDim);
        QType *v_row = vcache_ptr + (int64_t)phys_block_id * (int64_t)vcache_block_offset +
                       row * (kNumKVHeads * kVHeadDim);
        constexpr int kKItemPerThread = 16 / sizeof(QType);
        vec_t zero_vec;
#pragma unroll
        for (int z = 0; z < kKItemPerThread; ++z) zero_vec[z] = QType(0);
        for (int idx = ilane * kKItemPerThread; idx < kNumKVHeads * kQKHeadDim;
             idx += kWarpSize * kKItemPerThread)
          store(k_row + idx, zero_vec);
        for (int idx = ilane * kKItemPerThread; idx < kNumKVHeads * kVHeadDim;
             idx += kWarpSize * kKItemPerThread)
          store(v_row + idx, zero_vec);
      }
    }
    return;
  }

  // Determine batch_id and token position — unified for prefill and decode
  int batch_id = 0;
  int token_id = 0;
  int irow = bid * kWarpsPerBlock + iwarp;

  if (tid < kWarpsPerBlock) {
    int global_row = bid * kWarpsPerBlock + tid;
    if (global_row < num_rows) {
      int b = -1;
      for (int i = 0; i < num_batch; ++i) {
        if (global_row < q_index_ptr[i + 1]) {
          b = i;
          break;
        }
      }
      if (b >= 0) {
        smem_batch_id[tid] = b;
        smem_token_pos[tid] = global_row + num_seqlen_per_req_ptr[b] - q_index_ptr[b + 1];
      } else {
        smem_batch_id[tid] = -1;
        smem_token_pos[tid] = -1;
      }
    } else {
      smem_batch_id[tid] = -1;
      smem_token_pos[tid] = -1;
    }
  }

  // Load norm weights
  if constexpr (kNormPolicy > 0) {
    constexpr int kItemPerThread = 16 / sizeof(float);
constexpr int kNumPacks = kQKHeadDim / kItemPerThread; + static_assert(kQKHeadDim % kItemPerThread == 0, ""); + static_assert(kItemPerThread * kWarpSize >= kQKHeadDim, ""); + if (tid < kNumPacks) { + int ioffset = tid * kItemPerThread; + store(smem_q_norm_w + ioffset, load(q_norm_weight_ptr + ioffset)); + store(smem_k_norm_w + ioffset, load(k_norm_weight_ptr + ioffset)); + } + } + + // Single barrier: makes batch_id, token_pos, and norm weights visible + __syncthreads(); + + // Early-exit for invalid/padding rows + if (irow >= num_rows) return; + batch_id = smem_batch_id[iwarp]; + token_id = smem_token_pos[iwarp]; + if (token_id < 0) return; + + // Load cos/sin (per-warp, needs __syncwarp for intra-warp visibility) + { + constexpr int kItemPerThread = 16 / sizeof(float); + constexpr int kNumPacks = kQKHeadDim / kItemPerThread; + const float *cos_sin_row = cos_sin_ptr + token_id * kQKHeadDim; + if (ilane < kNumPacks) { + int ioffset = ilane * kItemPerThread; + store(&smem_cos_sin[iwarp][0] + ioffset, load(cos_sin_row + ioffset)); + } + __syncwarp(); + } + + // KV cache block addressing + int block_idx_in_batch, block_row; + kv_block_size_divider(block_idx_in_batch, block_row, token_id); + int phys_block_id = + kvcache_indices_ptr[batch_id * max_num_kv_block_per_batch + block_idx_in_batch]; + + QType *k_cache_row_start = kcache_ptr + (int64_t)phys_block_id * (int64_t)kcache_block_offset + + block_row * (kNumKVHeads * kQKHeadDim); + QType *v_cache_row_start = vcache_ptr + (int64_t)phys_block_id * (int64_t)vcache_block_offset + + block_row * (kNumKVHeads * kVHeadDim); + + const DType *qkv_row = in_qkv_ptr + irow * kNumElemPerRow; + + // ========= Process Q heads ========= +#pragma unroll + for (int q_head = 0; q_head < kNumQHeads; ++q_head) { + const DType *q_src = qkv_row + q_head * kQKHeadDim; + QType *q_dst = out_q_ptr + irow * kNumQHeads * kQKHeadDim + q_head * kQKHeadDim; + + vec_t data = {0}; + +#pragma unroll + for (int r = 0; r < kNumRoundsHalf; ++r) { + int i 
= r * kWarpSize + ilane; + if (i < kQKHeadDim / 2) { + data[r * 2] = __bfloat162float(q_src[i]); + data[r * 2 + 1] = __bfloat162float(q_src[i + kQKHeadDim / 2]); + } + } + + if constexpr (kNormPolicy == 2) { + rms_norm_apply(data, smem_q_norm_w, ilane); + } + +#pragma unroll + for (int r = 0; r < kNumRoundsHalf; ++r) { + int i = r * kWarpSize + ilane; + if (i < kQKHeadDim / 2) { + rope_rotate_pair(data[r * 2], data[r * 2 + 1], smem_cos_sin[iwarp][i], + smem_cos_sin[iwarp][i + kQKHeadDim / 2]); + } + } + + if constexpr (kNormPolicy == 1) { + rms_norm_apply(data, smem_q_norm_w, ilane); + } + + // Q quantization + float q_mult; + if constexpr (kQuantPolicy == 1) { + // dqskv: dynamic per-token per-head + float max_abs = warp_abs_max(data); + float q_scale_val = max_abs / upper_max; + if (ilane == 0) { + if (is_prefill) { + // Prefill layout: [batch_id, q_head, tok_in_chunk] + int tok_in_chunk = irow - q_index_ptr[batch_id]; + q_scale_ptr[batch_id * kNumQHeads * max_seqlen_aligned + q_head * max_seqlen_aligned + + tok_in_chunk] = q_scale_val; + } else { + // Decode layout: [irow, q_head] + q_scale_ptr[irow * kNumQHeads + q_head] = q_scale_val; + } + } + q_mult = __frcp_rn(q_scale_val); + } else if constexpr (kQuantPolicy == 2) { + // sqskv: static per-tensor + q_mult = q_scale_inv_ptr[0]; + } + + // Store FP8 Q +#pragma unroll + for (int r = 0; r < kNumRoundsHalf; ++r) { + int i = r * kWarpSize + ilane; + if (i < kQKHeadDim / 2) { + q_dst[i] = QType(data[r * 2] * q_mult); + q_dst[i + kQKHeadDim / 2] = QType(data[r * 2 + 1] * q_mult); + } + } + } + + // ========= Process K heads ========= + float k_scale_inv = __frcp_rn(k_scale_ptr[0]); +#pragma unroll + for (int kv_head = 0; kv_head < kNumKVHeads; ++kv_head) { + const DType *k_src = qkv_row + kNumQHeads * kQKHeadDim + kv_head * kQKHeadDim; + + // Zero split_k_flag inside K loop + if (ilane == 0) { + split_k_flag_ptr[batch_id * kNumKVHeads + kv_head] = 0; + } + + vec_t data = {0}; + +#pragma unroll + for (int r = 0; r < 
kNumRoundsHalf; ++r) { + int i = r * kWarpSize + ilane; + if (i < kQKHeadDim / 2) { + data[r * 2] = __bfloat162float(k_src[i]); + data[r * 2 + 1] = __bfloat162float(k_src[i + kQKHeadDim / 2]); + } + } + + if constexpr (kNormPolicy == 2) { + rms_norm_apply(data, smem_k_norm_w, ilane); + } + +#pragma unroll + for (int r = 0; r < kNumRoundsHalf; ++r) { + int i = r * kWarpSize + ilane; + if (i < kQKHeadDim / 2) { + rope_rotate_pair(data[r * 2], data[r * 2 + 1], smem_cos_sin[iwarp][i], + smem_cos_sin[iwarp][i + kQKHeadDim / 2]); + } + } + + if constexpr (kNormPolicy == 1) { + rms_norm_apply(data, smem_k_norm_w, ilane); + } + + QType *k_dst = (out_k_ptr != nullptr) + ? out_k_ptr + irow * kNumKVHeads * kQKHeadDim + kv_head * kQKHeadDim + : k_cache_row_start + kv_head * kQKHeadDim; + +#pragma unroll + for (int r = 0; r < kNumRoundsHalf; ++r) { + int i = r * kWarpSize + ilane; + if (i < kQKHeadDim / 2) { + k_dst[i] = QType(data[r * 2] * k_scale_inv); + k_dst[i + kQKHeadDim / 2] = QType(data[r * 2 + 1] * k_scale_inv); + } + } + } + + // ========= Process V heads (no RoPE, bf16→fp8) ========= + { + float v_scale_inv = __frcp_rn(v_scale_ptr[0]); + using LoadDType = __nv_bfloat162; + using PackQType = __nv_fp8x4_e4m3; + constexpr int kNumVElemPerRow = kNumKVHeads * kVHeadDim; + constexpr int kItemPerThread = 16 / sizeof(DType); + static_assert(kNumVElemPerRow % kItemPerThread == 0, ""); + constexpr int kNumPackPerRow = kNumVElemPerRow / kItemPerThread; + + const DType *v_src = qkv_row + (kNumQHeads + kNumKVHeads) * kQKHeadDim; + QType *v_dst = (out_v_ptr != nullptr) ? 
out_v_ptr + irow * kNumKVHeads * kVHeadDim + : reinterpret_cast(v_cache_row_start); + + constexpr int kNumLoadRound = ceil_div(); +#pragma unroll + for (int r = 0; r < kNumLoadRound; ++r) { + int ioffset = (r * kWarpSize + ilane) * kItemPerThread; + if (ioffset < kNumVElemPerRow) { + auto vec_bf162 = load(v_src + ioffset); + auto vec_float = to(vec_bf162); +#pragma unroll + for (int i = 0; i < size(vec_float); i++) { + vec_float[i] = vec_float[i] * v_scale_inv; + } + store(v_dst + ioffset, to(vec_float)); + } + } + } +} + +} // namespace kernels + +template +void launch_rope_norm_store_kv(__nv_bfloat16 *out_q_ptr, __nv_bfloat16 *kcache_ptr, + __nv_bfloat16 *vcache_ptr, __nv_bfloat16 *out_k_ptr, + __nv_bfloat16 *out_v_ptr, const __nv_bfloat16 *in_qkv_ptr, + const float *cos_sin_ptr, const int *num_seqlen_per_req_ptr, + const int *q_index_ptr, const int *kvcache_indices_ptr, + const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, + int kcache_block_offset, int vcache_block_offset, int num_batch, + int max_num_kv_block_per_batch, + cutlass::FastDivmod kv_block_size_divider, int num_rows, + int qk_norm_policy, cudaStream_t stream) { + constexpr int kWarpsPerBlock = 4; + constexpr int kWarpSize = 32; + + int num_compute_blocks = (num_rows + kWarpsPerBlock - 1) / kWarpsPerBlock; + dim3 block(kWarpsPerBlock * kWarpSize); + dim3 grid(num_compute_blocks + num_batch); // compute blocks + 1 clear block per request + + auto launch = [&](auto norm_tag) { + constexpr int kNP = decltype(norm_tag)::value; + kernels::rope_norm_store_kv_kernel<<>>( + out_q_ptr, kcache_ptr, vcache_ptr, out_k_ptr, out_v_ptr, in_qkv_ptr, cos_sin_ptr, + num_seqlen_per_req_ptr, q_index_ptr, kvcache_indices_ptr, q_norm_weight_ptr, + k_norm_weight_ptr, kcache_block_offset, vcache_block_offset, num_batch, + max_num_kv_block_per_batch, kv_block_size_divider, num_rows, num_compute_blocks); + }; + + if (qk_norm_policy == 1) { + launch(std::integral_constant{}); + } else if (qk_norm_policy == 2) { + 
launch(std::integral_constant{}); + } else { + launch(std::integral_constant{}); + } +} + +void rope_norm_store_kv_async(__nv_bfloat16 *out_q_ptr, __nv_bfloat16 *kcache_ptr, + __nv_bfloat16 *vcache_ptr, __nv_bfloat16 *out_k_ptr, + __nv_bfloat16 *out_v_ptr, const __nv_bfloat16 *in_qkv_ptr, + const float *cos_sin_ptr, const int *num_seqlen_per_req_ptr, + const int *q_index_ptr, const int *kvcache_indices_ptr, + const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, + int kcache_block_offset, int vcache_block_offset, int num_batch, + int max_num_kv_block_per_batch, int kv_block_size, int num_rows, + int num_q_heads, int num_kv_heads, int qk_head_dim, int v_head_dim, + bool is_prefill, int qk_norm_policy, cudaStream_t stream) { + cutlass::FastDivmod kv_block_size_divider(kv_block_size); + + if (num_q_heads == 8 && num_kv_heads == 1 && qk_head_dim == 128 && v_head_dim == 128) { + launch_rope_norm_store_kv<8, 1, 128, 128>( + out_q_ptr, kcache_ptr, vcache_ptr, out_k_ptr, out_v_ptr, in_qkv_ptr, cos_sin_ptr, + num_seqlen_per_req_ptr, q_index_ptr, kvcache_indices_ptr, q_norm_weight_ptr, + k_norm_weight_ptr, kcache_block_offset, vcache_block_offset, num_batch, + max_num_kv_block_per_batch, kv_block_size_divider, num_rows, qk_norm_policy, stream); + } else if (num_q_heads == 64 && num_kv_heads == 8 && qk_head_dim == 128 && v_head_dim == 128) { + launch_rope_norm_store_kv<64, 8, 128, 128>( + out_q_ptr, kcache_ptr, vcache_ptr, out_k_ptr, out_v_ptr, in_qkv_ptr, cos_sin_ptr, + num_seqlen_per_req_ptr, q_index_ptr, kvcache_indices_ptr, q_norm_weight_ptr, + k_norm_weight_ptr, kcache_block_offset, vcache_block_offset, num_batch, + max_num_kv_block_per_batch, kv_block_size_divider, num_rows, qk_norm_policy, stream); + } else { + throw std::invalid_argument("rope_norm_store_kv_async: unsupported config, got: q_heads=" + + std::to_string(num_q_heads) + + ", kv_heads=" + std::to_string(num_kv_heads) + + ", qk_head_dim=" + std::to_string(qk_head_dim) + + ", v_head_dim=" + 
std::to_string(v_head_dim)); + } +} + +// Launch helper – dispatches kQuantPolicy + kNormPolicy at compile time +template +void launch_rope_norm_store_kv_fp8( + __nv_fp8_e4m3 *out_q_ptr, __nv_fp8_e4m3 *kcache_ptr, __nv_fp8_e4m3 *vcache_ptr, + __nv_fp8_e4m3 *out_k_ptr, __nv_fp8_e4m3 *out_v_ptr, int32_t *split_k_flag_ptr, + float *q_scale_ptr, const __nv_bfloat16 *in_qkv_ptr, const float *cos_sin_ptr, + const int *num_seqlen_per_req_ptr, const int *q_index_ptr, const int *kvcache_indices_ptr, + const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, const float *k_scale_ptr, + const float *v_scale_ptr, const float *q_scale_inv_ptr, float upper_max, int max_seqlen_aligned, + int kcache_block_offset, int vcache_block_offset, int num_batch, int max_num_kv_block_per_batch, + cutlass::FastDivmod kv_block_size_divider, int num_rows, int qk_norm_policy, int quant_policy, + bool is_prefill, cudaStream_t stream) { + constexpr int kWarpsPerBlock = 4; + constexpr int kWarpSize = 32; + + int num_compute_blocks = (num_rows + kWarpsPerBlock - 1) / kWarpsPerBlock; + dim3 block(kWarpsPerBlock * kWarpSize); + dim3 grid(num_compute_blocks + num_batch); + + auto launch = [&](auto quant_tag, auto norm_tag) { + constexpr int kQP = decltype(quant_tag)::value; + constexpr int kNP = decltype(norm_tag)::value; + kernels::rope_norm_store_kv_fp8_kernel<<>>( + out_q_ptr, kcache_ptr, vcache_ptr, out_k_ptr, out_v_ptr, split_k_flag_ptr, q_scale_ptr, + in_qkv_ptr, cos_sin_ptr, num_seqlen_per_req_ptr, q_index_ptr, kvcache_indices_ptr, + q_norm_weight_ptr, k_norm_weight_ptr, k_scale_ptr, v_scale_ptr, q_scale_inv_ptr, upper_max, + max_seqlen_aligned, kcache_block_offset, vcache_block_offset, num_batch, + max_num_kv_block_per_batch, kv_block_size_divider, num_rows, num_compute_blocks, + is_prefill); + }; + + auto dispatch_norm = [&](auto quant_tag) { + if (qk_norm_policy == 1) { + launch(quant_tag, std::integral_constant{}); + } else if (qk_norm_policy == 2) { + launch(quant_tag, 
std::integral_constant{}); + } else { + launch(quant_tag, std::integral_constant{}); + } + }; + + if (quant_policy == 1) { + dispatch_norm(std::integral_constant{}); + } else { + dispatch_norm(std::integral_constant{}); + } +} + +void rope_norm_store_kv_fp8_async( + __nv_fp8_e4m3 *out_q_ptr, __nv_fp8_e4m3 *kcache_ptr, __nv_fp8_e4m3 *vcache_ptr, + __nv_fp8_e4m3 *out_k_ptr, __nv_fp8_e4m3 *out_v_ptr, int32_t *split_k_flag_ptr, + float *q_scale_ptr, const __nv_bfloat16 *in_qkv_ptr, const float *cos_sin_ptr, + const int *num_seqlen_per_req_ptr, const int *q_index_ptr, const int *kvcache_indices_ptr, + const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, const float *k_scale_ptr, + const float *v_scale_ptr, const float *q_scale_inv_ptr, float upper_max, int max_seqlens, + int kcache_block_offset, int vcache_block_offset, int num_batch, int max_num_kv_block_per_batch, + int kv_block_size, int num_rows, int num_q_heads, int num_kv_heads, int qk_head_dim, + int v_head_dim, bool is_prefill, int qk_norm_policy, int quant_policy, cudaStream_t stream) { + cutlass::FastDivmod kv_block_size_divider(kv_block_size); + int max_seqlen_aligned = ((max_seqlens + 127) / 128) * 128; + + if (num_q_heads == 8 && num_kv_heads == 1 && qk_head_dim == 128 && v_head_dim == 128) { + launch_rope_norm_store_kv_fp8<8, 1, 128, 128>( + out_q_ptr, kcache_ptr, vcache_ptr, out_k_ptr, out_v_ptr, split_k_flag_ptr, q_scale_ptr, + in_qkv_ptr, cos_sin_ptr, num_seqlen_per_req_ptr, q_index_ptr, kvcache_indices_ptr, + q_norm_weight_ptr, k_norm_weight_ptr, k_scale_ptr, v_scale_ptr, q_scale_inv_ptr, upper_max, + max_seqlen_aligned, kcache_block_offset, vcache_block_offset, num_batch, + max_num_kv_block_per_batch, kv_block_size_divider, num_rows, qk_norm_policy, quant_policy, + is_prefill, stream); + } else if (num_q_heads == 64 && num_kv_heads == 8 && qk_head_dim == 128 && v_head_dim == 128) { + launch_rope_norm_store_kv_fp8<64, 8, 128, 128>( + out_q_ptr, kcache_ptr, vcache_ptr, out_k_ptr, out_v_ptr, 
split_k_flag_ptr, q_scale_ptr, + in_qkv_ptr, cos_sin_ptr, num_seqlen_per_req_ptr, q_index_ptr, kvcache_indices_ptr, + q_norm_weight_ptr, k_norm_weight_ptr, k_scale_ptr, v_scale_ptr, q_scale_inv_ptr, upper_max, + max_seqlen_aligned, kcache_block_offset, vcache_block_offset, num_batch, + max_num_kv_block_per_batch, kv_block_size_divider, num_rows, qk_norm_policy, quant_policy, + is_prefill, stream); + } else { + throw std::invalid_argument("rope_norm_store_kv_fp8_async: unsupported config, got: q_heads=" + + std::to_string(num_q_heads) + + ", kv_heads=" + std::to_string(num_kv_heads) + + ", qk_head_dim=" + std::to_string(qk_head_dim) + + ", v_head_dim=" + std::to_string(v_head_dim)); + } +} + +} // namespace rope +} // namespace hpc diff --git a/src/rope/rope.h b/src/rope/rope.h new file mode 100644 index 0000000..3c0a5f7 --- /dev/null +++ b/src/rope/rope.h @@ -0,0 +1,41 @@ +// Copyright (C) 2026 Tencent. + +#ifndef SRC_ROPE_ROPE_H_ +#define SRC_ROPE_ROPE_H_ + +#include +#include +#include +#include + +#include + +namespace hpc { +namespace rope { + +void rope_norm_store_kv_async(__nv_bfloat16 *out_q_ptr, __nv_bfloat16 *kcache_ptr, + __nv_bfloat16 *vcache_ptr, __nv_bfloat16 *out_k_ptr, + __nv_bfloat16 *out_v_ptr, const __nv_bfloat16 *in_qkv_ptr, + const float *cos_sin_ptr, const int *num_seqlen_per_req_ptr, + const int *q_index_ptr, const int *kvcache_indices_ptr, + const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, + int kcache_block_offset, int vcache_block_offset, int num_batch, + int max_num_kv_block_per_batch, int kv_block_size, int num_rows, + int num_q_heads, int num_kv_heads, int qk_head_dim, int v_head_dim, + bool is_prefill, int qk_norm_policy, cudaStream_t stream); + +void rope_norm_store_kv_fp8_async( + __nv_fp8_e4m3 *out_q_ptr, __nv_fp8_e4m3 *kcache_ptr, __nv_fp8_e4m3 *vcache_ptr, + __nv_fp8_e4m3 *out_k_ptr, __nv_fp8_e4m3 *out_v_ptr, int32_t *split_k_flag_ptr, + float *q_scale_ptr, const __nv_bfloat16 *in_qkv_ptr, const float *cos_sin_ptr, + 
const int *num_seqlen_per_req_ptr, const int *q_index_ptr, const int *kvcache_indices_ptr, + const float *q_norm_weight_ptr, const float *k_norm_weight_ptr, const float *k_scale_ptr, + const float *v_scale_ptr, const float *q_scale_inv_ptr, float upper_max, int max_seqlens, + int kcache_block_offset, int vcache_block_offset, int num_batch, int max_num_kv_block_per_batch, + int kv_block_size, int num_rows, int num_q_heads, int num_kv_heads, int qk_head_dim, + int v_head_dim, bool is_prefill, int qk_norm_policy, int quant_policy, cudaStream_t stream); + +} // namespace rope +} // namespace hpc + +#endif // SRC_ROPE_ROPE_H_ diff --git a/tests/test_rope.py b/tests/test_rope.py new file mode 100644 index 0000000..e92d4fa --- /dev/null +++ b/tests/test_rope.py @@ -0,0 +1,1254 @@ +import os +import sys +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import numpy as np +import pytest +import torch +import torch.nn.functional as F + +import hpc +from utils import allclose + + +def generate_cos_sin_cache(max_position, head_dim, base=10000.0): + """Generate RoPE cos/sin cache.""" + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + t = torch.arange(max_position).float() + freqs = torch.outer(t, inv_freq) # [max_position, head_dim/2] + + cos = freqs.cos() # [max_position, head_dim/2] + sin = freqs.sin() # [max_position, head_dim/2] + + # Concatenate cos and sin: [max_position, head_dim] + cos_sin = torch.cat([cos, sin], dim=-1) + return cos_sin + + +def generate_kv_block_indices(kcache, req_length: list): + + num_req = len(req_length) + total_blocks_in_pool = kcache.shape[0] + kv_block_size = kcache.shape[1] + num_blocks_per_req = [(length + kv_block_size - 1) // kv_block_size for length in req_length] + total_blocks_used = sum(num_blocks_per_req) + max_blocks_used = max(num_blocks_per_req) + + shuffled_blocks = torch.randperm(total_blocks_in_pool) + + # +4 for testing + 
kv_indices = torch.ones(num_req, max_blocks_used + 4, dtype=torch.int32) * -1 + + block_offset = 0 + for i in range(num_req): + kv_indices[i, : num_blocks_per_req[i]] = shuffled_blocks[ + block_offset : block_offset + num_blocks_per_req[i] + ] + block_offset += num_blocks_per_req[i] + + assert block_offset == total_blocks_used + + return kv_indices + + +def apply_rms_norm_reference(x, weight, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(variance + eps) + return x_normed * weight + + +def apply_rotary_pos_emb_neox_reference(x, cos_sin): + num_tokens, num_heads, head_dim = x.shape + half_dim = head_dim // 2 + + # Split x into two halves + x1 = x[..., :half_dim] # [num_tokens, num_heads, half_dim] + x2 = x[..., half_dim:] # [num_tokens, num_heads, half_dim] + + # Extract cos and sin from cos_sin tensor + cos_half = cos_sin[:, :half_dim].unsqueeze(1) # [num_tokens, 1, half_dim] + sin_half = cos_sin[:, half_dim:].unsqueeze(1) # [num_tokens, 1, half_dim] + + # Apply rotation (neox version) + o1 = x1 * cos_half - x2 * sin_half + o2 = x2 * cos_half + x1 * sin_half + + # Concatenate + output = torch.cat([o1, o2], dim=-1) + return output + + +def prepare_prefill_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype=torch.bfloat16, + device="cuda", +): + + if req_length is None: + req_length = torch.randint(20, 200, (num_req,)).tolist() + if isinstance(req_length, int): + req_length = [req_length] * num_req + total_rows = sum(req_length) + qkv = torch.randn( + total_rows, + num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + num_kv_heads * v_head_dim, + dtype=dtype, + device=device, + ) + cos_sin = generate_cos_sin_cache(max_rope_position, qk_head_dim).to( + dtype=torch.float32, device=device + ) + kcache = torch.randn( + max_num_kv_blocks, kv_block_size, num_kv_heads, qk_head_dim, dtype=dtype, device=device + ) + vcache = 
torch.randn( + max_num_kv_blocks, kv_block_size, num_kv_heads, v_head_dim, dtype=dtype, device=device + ) + + kv_indices = generate_kv_block_indices(kcache, req_length).to(device) + + q_norm_weight = torch.randn(qk_head_dim, dtype=torch.float32, device=device) + k_norm_weight = torch.randn(qk_head_dim, dtype=torch.float32, device=device) + + num_seqlen_per_req = torch.tensor(req_length, dtype=torch.int32, device=device) + + return ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) + + +def prepare_decode_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype=torch.bfloat16, + device="cuda", +): + if req_length is None: + req_length = torch.randint(20, 200, (num_req,)).tolist() + if isinstance(req_length, int): + req_length = [req_length] * num_req + # input req length is the existing length, not the new length, we add 1 here for kvcache update + req_length = [x + 1 for x in req_length] + total_rows = num_req + qkv = torch.randn( + total_rows, + num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + num_kv_heads * v_head_dim, + dtype=dtype, + device=device, + ) + cos_sin = generate_cos_sin_cache(max_rope_position, qk_head_dim).to( + dtype=torch.float32, device=device + ) + kcache = torch.randn( + max_num_kv_blocks, kv_block_size, num_kv_heads, qk_head_dim, dtype=dtype, device=device + ) + vcache = torch.randn( + max_num_kv_blocks, kv_block_size, num_kv_heads, v_head_dim, dtype=dtype, device=device + ) + + kv_indices = generate_kv_block_indices(kcache, req_length).cuda() + + q_norm_weight = torch.randn(qk_head_dim, dtype=torch.float32, device=device) + k_norm_weight = torch.randn(qk_head_dim, dtype=torch.float32, device=device) + + num_seqlen_per_req = torch.tensor(req_length, dtype=torch.int32, device=device) + + return ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + 
q_norm_weight, + k_norm_weight, + ) + + +def torch_rope_norm_blocked_prefill( + kcache, + vcache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + is_prefill=True, + use_qknorm=False, + q_norm_weight=None, + k_norm_weight=None, + qk_norm_policy=1, + clear_kv_tail=False, +): + """Test RoPE prefill mode with PyTorch reference implementation.""" + assert is_prefill + assert ( + q_index.shape[0] == num_seqlen_per_req.shape[0] + 1 + ) # q_index is a prefix sum of each len + dtype = qkv.dtype + num_kv_heads = kcache.shape[2] + v_head_dim = vcache.shape[3] + qk_head_dim = kcache.shape[3] + num_q_heads = ( + qkv.shape[1] - num_kv_heads * qk_head_dim - num_kv_heads * v_head_dim + ) // qk_head_dim + q_seq_lens = (q_index[1:] - q_index[:-1]).tolist() + + num_rows = q_index[-1].item() + num_req = num_seqlen_per_req.shape[0] + q_input = qkv[:, : num_q_heads * qk_head_dim].to(torch.float32) + k_input = qkv[ + :, num_q_heads * qk_head_dim : num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + ].to(torch.float32) + v_input = qkv[:, num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim :] + + # Build cos_sin for each token + cos_sin_for_tokens = torch.zeros(num_rows, qk_head_dim, dtype=torch.float32, device="cuda") + token_offset = 0 + for batch_idx in range(num_req): + seq_len = num_seqlen_per_req[batch_idx].item() + q_seq_len = q_seq_lens[batch_idx] + cos_sin_for_tokens[token_offset : token_offset + q_seq_len] = cos_sin[ + seq_len - q_seq_len : seq_len + ] + token_offset += q_seq_len + + q_ref = q_input.view(num_rows, num_q_heads, qk_head_dim) + k_ref = k_input.view(num_rows, num_kv_heads, qk_head_dim) + v_ref = v_input.view(num_rows, num_kv_heads, v_head_dim) + # Compute reference Q and K + + if use_qknorm and qk_norm_policy == 2: + q_ref = apply_rms_norm_reference(q_ref, q_norm_weight) + k_ref = apply_rms_norm_reference(k_ref, k_norm_weight) + + q_ref = apply_rotary_pos_emb_neox_reference(q_ref, cos_sin_for_tokens) + k_ref = 
apply_rotary_pos_emb_neox_reference(k_ref, cos_sin_for_tokens) + + if use_qknorm and qk_norm_policy == 1: + q_ref = apply_rms_norm_reference(q_ref, q_norm_weight) + k_ref = apply_rms_norm_reference(k_ref, k_norm_weight) + + # update kvcache + kv_block_size = kcache.shape[1] + token_idx = 0 + # breakpoint() + for req_idx in range(num_req): + seq_len = num_seqlen_per_req[req_idx].item() + q_seq_len = q_seq_lens[req_idx] + for pos_in_seq in range(seq_len - q_seq_len, seq_len): + block_idx_in_req = pos_in_seq // kv_block_size + pos_in_block = pos_in_seq % kv_block_size + cache_block_idx = kv_indices[req_idx, block_idx_in_req].item() + assert cache_block_idx >= 0, f"Invalid cache block index: {cache_block_idx}" + # Update K cache + kcache[cache_block_idx, pos_in_block, :, :] = k_ref[token_idx, :, :].to(dtype) + # Update V cache + vcache[cache_block_idx, pos_in_block, :, :] = v_ref[token_idx, :, :].to(dtype) + # Clear rows [pos_in_block+1, kv_block_size) for last token of each request + if clear_kv_tail and pos_in_seq == seq_len - 1 and pos_in_block + 1 < kv_block_size: + kcache[cache_block_idx, pos_in_block + 1 :, :, :] = 0 + vcache[cache_block_idx, pos_in_block + 1 :, :, :] = 0 + token_idx += 1 + + out_q = q_ref.to(dtype) + out_k = k_ref.to(dtype) + return out_q, out_k, kcache, vcache + + +def torch_rope_norm_blocked_decode( + kcache, + vcache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + is_prefill=False, + use_qknorm=False, + q_norm_weight=None, + k_norm_weight=None, + qk_norm_policy=1, + clear_kv_tail=False, +): + """Test RoPE decode mode with PyTorch reference implementation.""" + assert not is_prefill + dtype = qkv.dtype + num_kv_heads = kcache.shape[2] + v_head_dim = vcache.shape[3] + qk_head_dim = kcache.shape[3] + num_q_heads = ( + qkv.shape[1] - num_kv_heads * qk_head_dim - num_kv_heads * v_head_dim + ) // qk_head_dim + + num_req = num_seqlen_per_req.shape[0] + num_rows = num_req + q_input = qkv[:, : num_q_heads * 
qk_head_dim].to(torch.float32) + k_input = qkv[ + :, num_q_heads * qk_head_dim : num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + ].to(torch.float32) + v_input = qkv[:, num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim :] + + # Build cos_sin for each token + cos_sin_for_tokens = torch.zeros(num_rows, qk_head_dim, dtype=torch.float32, device="cuda") + for batch_idx in range(num_req): + seq_len = num_seqlen_per_req[batch_idx].item() + position = seq_len - 1 + cos_sin_for_tokens[batch_idx] = cos_sin[position] + + q_ref = q_input.view(num_rows, num_q_heads, qk_head_dim) + k_ref = k_input.view(num_rows, num_kv_heads, qk_head_dim) + v_ref = v_input.view(num_rows, num_kv_heads, v_head_dim) + + if use_qknorm and qk_norm_policy == 2: + q_ref = apply_rms_norm_reference(q_ref, q_norm_weight) + k_ref = apply_rms_norm_reference(k_ref, k_norm_weight) + + # Compute reference Q and K + q_ref = apply_rotary_pos_emb_neox_reference(q_ref, cos_sin_for_tokens) + k_ref = apply_rotary_pos_emb_neox_reference(k_ref, cos_sin_for_tokens) + + if use_qknorm and qk_norm_policy == 1: + q_ref = apply_rms_norm_reference(q_ref, q_norm_weight) + k_ref = apply_rms_norm_reference(k_ref, k_norm_weight) + + # update kvcache + kv_block_size = kcache.shape[1] + token_idx = 0 + for req_idx in range(num_req): + seq_len = num_seqlen_per_req[req_idx].item() + pos_in_seq = seq_len - 1 + block_idx_in_req = pos_in_seq // kv_block_size + pos_in_block = pos_in_seq % kv_block_size + cache_block_idx = kv_indices[req_idx, block_idx_in_req].item() + assert cache_block_idx >= 0, f"Invalid cache block index: {cache_block_idx}" + # Update K cache + kcache[cache_block_idx, pos_in_block, :, :] = k_ref[token_idx, :, :].to(dtype) + # Update V cache + vcache[cache_block_idx, pos_in_block, :, :] = v_ref[token_idx, :, :].to(dtype) + + # Clear KV cache tail rows + if clear_kv_tail: + # New unified clearing: always clear [pos_in_block+1, kv_block_size) + if pos_in_block + 1 < kv_block_size: + kcache[cache_block_idx, 
pos_in_block + 1 :, :, :] = 0 + vcache[cache_block_idx, pos_in_block + 1 :, :, :] = 0 + else: + # Old behavior: clear only when pos_in_block == 0 + if pos_in_block == 0: + kcache[cache_block_idx, 1:, :, :] = 0 + vcache[cache_block_idx, 1:, :, :] = 0 + + token_idx += 1 + + out_q = q_ref.to(dtype) + out_k = k_ref.to(dtype) + # out_qkv = torch.cat([out_q, out_k, out_v], dim=1) + return out_q, out_k, kcache, vcache + + +def sample_and_extract_qkv(req_length, qkv): + + device = qkv.device + req_length = torch.tensor(req_length).to(device) + batch_size = req_length.size(0) + + # rand a ratio + rand_factors = torch.rand(batch_size, device=device) + q_length = (rand_factors * req_length).long() + 1 + q_length = torch.min(q_length, req_length) + + # ensure not larger + req_cumsum = torch.cumsum(req_length, dim=0) + + slices = [] + + for i in range(batch_size): + curr_original_end = req_cumsum[i].item() + curr_new_len = q_length[i].item() + slice_start = curr_original_end - curr_new_len + slice_end = curr_original_end + slices.append(qkv[slice_start:slice_end]) + + qkv_new = torch.cat(slices, dim=0) + + q_cumsum = torch.cumsum(q_length, dim=0) + + # add a zero + zero_pad = torch.tensor([0], device=device, dtype=q_cumsum.dtype) + q_index = torch.cat((zero_pad, q_cumsum), dim=0) + + return q_index.to(torch.int32), qkv_new + + +@pytest.mark.parametrize("num_req", [7]) +@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [1]) +def test_rope_norm_store_kv_prefill( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy +): + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_prefill_input( + num_req, + req_length, + 
num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + q_index, qkv_new = sample_and_extract_qkv(req_length, qkv) + + qkv_ref = qkv_new.clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + + my_out_q = hpc.rope_norm_store_kv( + kcache, + vcache, + qkv_new, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + True, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + torch_out_q, torch_out_k, torch_kcache, torch_vcache = torch_rope_norm_blocked_prefill( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + is_prefill=True, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, my_out_q, atol=8e-2) + assert allclose(torch_kcache, kcache, atol=8e-2) + assert allclose(torch_vcache, vcache, atol=8e-2) + + +def pad_decode_inputs_to_align8(qkv, num_seqlen_per_req, q_index, kv_indices): + """Pad decode inputs so total rows and num_batch are aligned to 8. + Simulates CUDA graph padding: extra batches have q_index[i+1]-q_index[i]=0 + and num_seqlen_per_req[i]=0. 
+ """ + num_rows = qkv.shape[0] + num_batch = num_seqlen_per_req.shape[0] + hidden = qkv.shape[1] + + padded_batch = (num_batch + 7) // 8 * 8 + pad_batch = padded_batch - num_batch + padded_rows = (num_rows + 7) // 8 * 8 + pad_rows = padded_rows - num_rows + + if pad_rows > 0: + qkv = torch.cat([qkv, torch.zeros(pad_rows, hidden, dtype=qkv.dtype, device=qkv.device)]) + + if pad_batch > 0: + num_seqlen_per_req = torch.cat( + [ + num_seqlen_per_req, + torch.zeros( + pad_batch, dtype=num_seqlen_per_req.dtype, device=num_seqlen_per_req.device + ), + ] + ) + + # q_index: original ends at num_rows, padding batches have 0 tokens each, + # but we assign all pad_rows to the first padding batch so q_index covers padded_rows + last_val = q_index[-1] # == num_rows + if pad_batch > 0: + pad_q = torch.full((pad_batch,), padded_rows, dtype=q_index.dtype, device=q_index.device) + q_index = torch.cat([q_index, pad_q]) + + if pad_batch > 0: + kv_indices = torch.cat( + [ + kv_indices, + torch.zeros( + pad_batch, kv_indices.shape[1], dtype=kv_indices.dtype, device=kv_indices.device + ), + ] + ) + + return qkv, num_seqlen_per_req, q_index, kv_indices, num_rows + + +@pytest.mark.parametrize("num_req", [8]) +@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [1]) +def test_rope_norm_store_kv_decode(num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy): + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_decode_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + + q_index_decode = torch.arange(num_req + 1, 
dtype=torch.int32, device=qkv.device) + + # Pad to align-8 (simulates CUDA graph padding) + qkv, num_seqlen_per_req, q_index_decode, kv_indices, real_rows = pad_decode_inputs_to_align8( + qkv, num_seqlen_per_req, q_index_decode, kv_indices + ) + + qkv_ref = qkv[:real_rows].clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + q_index_ref = torch.arange(num_req + 1, dtype=torch.int32, device=qkv.device) + + my_out_q = hpc.rope_norm_store_kv( + kcache, + vcache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index_decode, + kv_indices, + False, # is prefill + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + torch_out_q, torch_out_k, torch_kcache, torch_vcache = torch_rope_norm_blocked_decode( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req[:num_req], + q_index_ref, + kv_indices[:num_req], + is_prefill=False, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, my_out_q[:real_rows], atol=5e-2) + assert allclose(torch_kcache, kcache, atol=5e-2) + assert allclose(torch_vcache, vcache, atol=5e-2) + + +@pytest.mark.skipif(bool(os.getenv("SANITIZER_CHECK")), reason="skip sanitizer") +@pytest.mark.parametrize("num_req", [7]) +@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [1, 2]) +def test_rope_norm_store_kv_fp8_prefill_dqskv( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy +): + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = 
prepare_prefill_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + q_index, qkv_new = sample_and_extract_qkv(req_length, qkv) + + qkv_ref = qkv_new.clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + + k_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv_new.device) + v_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv_new.device) + kcache_fp8 = kcache.to(torch.float8_e4m3fn) + vcache_fp8 = vcache.to(torch.float8_e4m3fn) + + seqlens = q_index[1:] - q_index[:-1] + max_seqlens = seqlens.max().item() + + q_fp8, q_scale, split_k_flag = hpc.rope_norm_store_kv_fp8( + key_cache=kcache_fp8, + value_cache=vcache_fp8, + qkv=qkv_new, + cos_sin=cos_sin, + num_seqlen_per_req=num_seqlen_per_req, + q_index=q_index, + kvcache_indices=kv_indices, + is_prefill=True, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=1, # 1 for dqskv , 2 for sqskv + max_seqlens=max_seqlens, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, # 0 for no norm, 1 for rope first, 2 for norm first + ) + + mask = torch.arange(q_scale.shape[2]).expand( + q_scale.shape[0], q_scale.shape[2] + ).cuda() < seqlens.unsqueeze(1) + qk_scale_normal = q_scale.permute(0, 2, 1)[mask].cuda() + q_bf16 = (q_fp8.to(torch.bfloat16) * qk_scale_normal[:, :, None]).to(torch.bfloat16) + + torch_out_q, _, _, _ = torch_rope_norm_blocked_prefill( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + is_prefill=True, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, q_bf16, atol=0.5) + + +@pytest.mark.parametrize("num_req", [8]) +@pytest.mark.parametrize( + 
"num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [0, 1]) +def test_rope_norm_store_kv_fp8_decode_dqskv( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy +): + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_decode_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + + q_index_decode = torch.arange(num_req + 1, dtype=torch.int32, device=qkv.device) + + # Pad to align-8 + qkv, num_seqlen_per_req, q_index_decode, kv_indices, real_rows = pad_decode_inputs_to_align8( + qkv, num_seqlen_per_req, q_index_decode, kv_indices + ) + + qkv_ref = qkv[:real_rows].clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + q_index_ref = torch.arange(num_req + 1, dtype=torch.int32, device=qkv.device) + + k_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv.device) + v_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv.device) + kcache_fp8 = kcache.to(torch.float8_e4m3fn) + vcache_fp8 = vcache.to(torch.float8_e4m3fn) + + q_fp8, q_scale, split_k_flag = hpc.rope_norm_store_kv_fp8( + key_cache=kcache_fp8, + value_cache=vcache_fp8, + qkv=qkv, + cos_sin=cos_sin, + num_seqlen_per_req=num_seqlen_per_req, + q_index=q_index_decode, + kvcache_indices=kv_indices, + is_prefill=False, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=1, + max_seqlens=1, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + q_bf16 = (q_fp8[:real_rows].to(torch.bfloat16) * q_scale[:real_rows, :, None]).to( + 
torch.bfloat16 + ) + + torch_out_q, _, _, _ = torch_rope_norm_blocked_decode( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req[:num_req], + q_index_ref, + kv_indices[:num_req], + is_prefill=False, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, q_bf16, atol=0.5) + + +@pytest.mark.skipif(bool(os.getenv("SANITIZER_CHECK")), reason="skip sanitizer") +@pytest.mark.parametrize("num_req", [7]) +@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [0, 2]) +def test_rope_norm_store_kv_fp8_prefill_sqskv( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy +): + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_prefill_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + q_index, qkv_new = sample_and_extract_qkv(req_length, qkv) + + qkv_ref = qkv_new.clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + + k_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv_new.device) + v_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv_new.device) + kcache_fp8 = kcache.to(torch.float8_e4m3fn) + vcache_fp8 = vcache.to(torch.float8_e4m3fn) + + seqlens = q_index[1:] - q_index[:-1] + max_seqlens = seqlens.max().item() + + q_scale_val = 2 + q_scale_inv = torch.tensor([1 / q_scale_val], dtype=torch.float32, device=qkv_new.device) + + q_fp8, q_scale, split_k_flag = hpc.rope_norm_store_kv_fp8( + key_cache=kcache_fp8, + 
value_cache=vcache_fp8, + qkv=qkv_new, + cos_sin=cos_sin, + num_seqlen_per_req=num_seqlen_per_req, + q_index=q_index, + kvcache_indices=kv_indices, + is_prefill=True, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=2, + max_seqlens=max_seqlens, + q_scale_inv=q_scale_inv, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + assert q_scale is None + q_bf16 = (q_fp8.to(torch.float32) * q_scale_val).to(torch.bfloat16) + + torch_out_q, _, _, _ = torch_rope_norm_blocked_prefill( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + is_prefill=True, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, q_bf16, atol=0.5) + + +@pytest.mark.parametrize("num_req", [8]) +@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [0, 1, 2]) +def test_rope_norm_store_kv_fp8_decode_sqskv( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy +): + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_decode_input( + num_req, + req_length, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + + q_index_decode = torch.arange(num_req + 1, dtype=torch.int32, device=qkv.device) + + # Pad to align-8 + qkv, num_seqlen_per_req, q_index_decode, kv_indices, real_rows = pad_decode_inputs_to_align8( + qkv, num_seqlen_per_req, q_index_decode, kv_indices + ) + + qkv_ref = 
qkv[:real_rows].clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + q_index_ref = torch.arange(num_req + 1, dtype=torch.int32, device=qkv.device) + + k_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv.device) + v_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv.device) + kcache_fp8 = kcache.to(torch.float8_e4m3fn) + vcache_fp8 = vcache.to(torch.float8_e4m3fn) + + q_scale_val = 2 + q_scale_inv = torch.tensor([1 / q_scale_val], dtype=torch.float32, device=qkv.device) + + q_fp8, q_scale, split_k_flag = hpc.rope_norm_store_kv_fp8( + key_cache=kcache_fp8, + value_cache=vcache_fp8, + qkv=qkv, + cos_sin=cos_sin, + num_seqlen_per_req=num_seqlen_per_req, + q_index=q_index_decode, + kvcache_indices=kv_indices, + is_prefill=False, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=2, + max_seqlens=1, + q_scale_inv=q_scale_inv, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + assert q_scale is None + q_bf16 = (q_fp8[:real_rows].to(torch.float32) * q_scale_val).to(torch.bfloat16) + + torch_out_q, _, _, _ = torch_rope_norm_blocked_decode( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req[:num_req], + q_index_ref, + kv_indices[:num_req], + is_prefill=False, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, q_bf16, atol=0.5) + + +def prepare_mtp_decode_input( + num_req, + req_length, + mtp_steps, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype=torch.bfloat16, + device="cuda", +): + """Prepare decode input with MTP (multi-token prediction). + + Each request contributes `mtp_steps` rows instead of 1. + req_length[i] is the existing kv length (before this decode step). 
+ The new tokens occupy positions [req_length[i], req_length[i] + mtp_steps). + """ + if req_length is None: + req_length = torch.randint(20, 200, (num_req,)).tolist() + if isinstance(req_length, int): + req_length = [req_length] * num_req + updated_req_length = [x + mtp_steps for x in req_length] + total_rows = num_req * mtp_steps + qkv = torch.randn( + total_rows, + num_q_heads * qk_head_dim + num_kv_heads * qk_head_dim + num_kv_heads * v_head_dim, + dtype=dtype, + device=device, + ) + cos_sin = generate_cos_sin_cache(max_rope_position, qk_head_dim).to( + dtype=torch.float32, device=device + ) + kcache = torch.randn( + max_num_kv_blocks, kv_block_size, num_kv_heads, qk_head_dim, dtype=dtype, device=device + ) + vcache = torch.randn( + max_num_kv_blocks, kv_block_size, num_kv_heads, v_head_dim, dtype=dtype, device=device + ) + kv_indices = generate_kv_block_indices(kcache, updated_req_length).to(device) + q_norm_weight = torch.randn(qk_head_dim, dtype=torch.float32, device=device) + k_norm_weight = torch.randn(qk_head_dim, dtype=torch.float32, device=device) + num_seqlen_per_req = torch.tensor(updated_req_length, dtype=torch.int32, device=device) + q_lengths = [mtp_steps] * num_req + q_cumsum = torch.cumsum(torch.tensor(q_lengths, device=device), dim=0) + q_index = torch.cat((torch.tensor([0], device=device, dtype=q_cumsum.dtype), q_cumsum)).to( + torch.int32 + ) + return ( + qkv, + num_seqlen_per_req, + q_index, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) + + +@pytest.mark.parametrize("num_req", [8]) +@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [0, 1, 2]) +@pytest.mark.parametrize("mtp_steps", [1, 2]) +def test_rope_norm_store_kv_mtp_decode( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy, mtp_steps +): + """MTP decode: each request has mtp_steps tokens (not just 1).""" + req_length = torch.randint(20, 200, 
(num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + q_index, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_mtp_decode_input( + num_req, + req_length, + mtp_steps, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + # Pad to align-8 + qkv, num_seqlen_per_req, q_index, kv_indices, real_rows = pad_decode_inputs_to_align8( + qkv, num_seqlen_per_req, q_index, kv_indices + ) + + qkv_ref = qkv[:real_rows].clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + q_index_ref = torch.cat( + ( + torch.tensor([0], device=qkv.device, dtype=torch.int32), + torch.cumsum(torch.tensor([mtp_steps] * num_req, device=qkv.device), dim=0).to( + torch.int32 + ), + ) + ) + + my_out_q = hpc.rope_norm_store_kv( + kcache, + vcache, + qkv, + cos_sin, + num_seqlen_per_req, + q_index, + kv_indices, + False, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + # Use prefill reference (handles multi-token per request correctly) + torch_out_q, torch_out_k, torch_kcache, torch_vcache = torch_rope_norm_blocked_prefill( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req[:num_req], + q_index_ref, + kv_indices[:num_req], + is_prefill=True, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, my_out_q[:real_rows], atol=5e-2) + assert allclose(torch_kcache, kcache, atol=5e-2) + assert allclose(torch_vcache, vcache, atol=5e-2) + + +@pytest.mark.skipif(bool(os.getenv("SANITIZER_CHECK")), reason="skip sanitizer") +@pytest.mark.parametrize("num_req", [8]) 
+@pytest.mark.parametrize( + "num_q_heads,num_kv_heads,qk_head_dim", + [(8, 1, 128), (64, 8, 128)], +) +@pytest.mark.parametrize("qk_norm_policy", [0, 1, 2]) +@pytest.mark.parametrize("mtp_steps", [1, 2]) +def test_rope_norm_store_kv_fp8_mtp_decode_dqskv( + num_req, num_q_heads, num_kv_heads, qk_head_dim, qk_norm_policy, mtp_steps +): + """MTP decode + FP8 dqskv.""" + req_length = torch.randint(20, 200, (num_req,)).tolist() + v_head_dim = qk_head_dim + kv_block_size = 64 + max_num_kv_blocks = 1024 + max_rope_position = 2048 + dtype = torch.bfloat16 + ( + qkv, + num_seqlen_per_req, + q_index, + cos_sin, + kcache, + vcache, + kv_indices, + q_norm_weight, + k_norm_weight, + ) = prepare_mtp_decode_input( + num_req, + req_length, + mtp_steps, + num_q_heads, + num_kv_heads, + qk_head_dim, + v_head_dim, + kv_block_size, + max_num_kv_blocks, + max_rope_position, + dtype, + ) + # Pad to align-8 + qkv, num_seqlen_per_req, q_index, kv_indices, real_rows = pad_decode_inputs_to_align8( + qkv, num_seqlen_per_req, q_index, kv_indices + ) + + qkv_ref = qkv[:real_rows].clone() + kcache_ref = kcache.clone() + vcache_ref = vcache.clone() + q_index_ref = torch.cat( + ( + torch.tensor([0], device=qkv.device, dtype=torch.int32), + torch.cumsum(torch.tensor([mtp_steps] * num_req, device=qkv.device), dim=0).to( + torch.int32 + ), + ) + ) + + k_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv.device) + v_scale = torch.tensor([0.1], dtype=torch.float32, device=qkv.device) + kcache_fp8 = kcache.to(torch.float8_e4m3fn) + vcache_fp8 = vcache.to(torch.float8_e4m3fn) + + q_fp8, q_scale, split_k_flag = hpc.rope_norm_store_kv_fp8( + key_cache=kcache_fp8, + value_cache=vcache_fp8, + qkv=qkv, + cos_sin=cos_sin, + num_seqlen_per_req=num_seqlen_per_req, + q_index=q_index, + kvcache_indices=kv_indices, + is_prefill=False, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=1, + max_seqlens=mtp_steps, + q_norm_weight=q_norm_weight if qk_norm_policy > 0 else None, + 
k_norm_weight=k_norm_weight if qk_norm_policy > 0 else None, + qk_norm_policy=qk_norm_policy, + ) + + # q_scale for decode is [num_rows, num_q_heads] + q_bf16 = (q_fp8[:real_rows].to(torch.bfloat16) * q_scale[:real_rows, :, None]).to( + torch.bfloat16 + ) + + torch_out_q, _, _, _ = torch_rope_norm_blocked_prefill( + kcache_ref, + vcache_ref, + qkv_ref, + cos_sin, + num_seqlen_per_req[:num_req], + q_index_ref, + kv_indices[:num_req], + is_prefill=True, + use_qknorm=(qk_norm_policy > 0), + q_norm_weight=q_norm_weight, + k_norm_weight=k_norm_weight, + qk_norm_policy=qk_norm_policy, + clear_kv_tail=True, + ) + + assert allclose(torch_out_q, q_bf16, atol=0.5)