Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions hpc/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
from torch import Tensor
from typing import Union, Tuple, Optional


def fused_rmsnorm_with_scale(
    a: Tensor,
    weight: Tensor,
    eps: float = torch.finfo(torch.float32).eps,
    scale: Optional[Tensor] = None,
    is_moe: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Perform RMSNorm on ``a``, divide by ``scale``, and emit fp8_e4m3 results.

    Executes type conversion in a custom GPU kernel for optimized performance.

    Args:
        a: Input tensor: [batch_size, hidden_states]. We only support bfloat16
            type and hidden_states = 5120 and 320 now.
        weight: [1, hidden_states]. Weight in RMSNorm.
        eps: a value added to the denominator for numerical stability.
        scale: scales for divide. Defaults to a scalar tensor of 1.0 on ``a``'s
            device (i.e. no rescaling).
        is_moe: when True, the kernel additionally produces a float32 copy of
            the RMSNorm output and a second fp8 output.
    Returns:
        RMSNorm(a) / scale in fp8_e4m3, or the tuple
        (RMSNorm(a) in fp32, RMSNorm(a) / scale in fp8, second fp8 output)
        when ``is_moe`` is True.
    """
    # NOTE: the default is created per call (not at import time) so it lands on
    # the right device and is never shared/mutated across calls.
    if scale is None:
        scale = torch.tensor([1], dtype=torch.float32, device=a.device)
    elif scale.device != a.device:
        scale = scale.to(a.device)
    output_fp8, output_fp32, output_fp8_scale2 = torch.ops.hpc.fused_rmsnorm_with_scale(
        a, weight, scale, eps, is_moe
    )
    return (output_fp32, output_fp8, output_fp8_scale2) if is_moe else output_fp8


def fused_rmsnorm_blockwise_quant(
    x: Tensor,
    weight: Tensor,
    eps: float = torch.finfo(torch.float32).eps,
    with_blockwise_quant: bool = False,
    block_size: int = 128,
    dual_output: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
    """Perform RMSNorm and optional block-wise fp8 quantization.

    Args:
        x: Input tensor.
            Shape: [num_tokens, dim]
            Dtype: torch.bfloat16
        weight: Weight for RMSNorm.
            Shape: [1, dim]
            Dtype: torch.bfloat16
        eps: a value added to the denominator for numerical stability.
            Shape: scalar
            Dtype: float
        with_blockwise_quant: whether to quantize the output of RMSNorm.
        block_size: quantization block size; now only 128 is supported.
        dual_output: if True, return both the bf16 RMSNorm output and its
            quantization (requires ``with_blockwise_quant``).
    Returns:
        When with_blockwise_quant and dual_output are both True:
            (rmsnorm(x), quant(rmsnorm(x)), fp32_scale)
        When only with_blockwise_quant is True:
            (quant(rmsnorm(x)), fp32_scale)
        Otherwise:
            rmsnorm(x)
    """
    assert block_size == 128, "now only support blockwise == 128"

    if with_blockwise_quant and dual_output:
        y_bf16, y_fp8, y_scale = torch.ops.hpc.fused_rmsnorm_blockwise_quant(
            x, weight, eps, with_blockwise_quant, block_size, dual_output
        )
        return y_bf16, y_fp8, y_scale
    elif with_blockwise_quant:
        y_fp8, y_scale, _ = torch.ops.hpc.fused_rmsnorm_blockwise_quant(
            x, weight, eps, with_blockwise_quant, block_size, dual_output
        )
        return y_fp8, y_scale
    else:
        # Quantized outputs are unused in this mode; the op still returns the
        # three-slot tuple declared in its schema.
        y_bf16, _, _ = torch.ops.hpc.fused_rmsnorm_blockwise_quant(
            x, weight, eps, with_blockwise_quant, block_size, dual_output
        )
        return y_bf16


def fused_rmsnorm_rope(
    positions: Tensor,
    q: Tensor,
    q_weight: Optional[Tensor],
    k: Optional[Tensor],
    k_weight: Optional[Tensor],
    cos_sin_cache: Optional[Tensor],
    eps: float = torch.finfo(torch.float32).eps,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Apply RMSNorm followed by rotary position embedding (RoPE).

    Args:
        positions: Position indices for each sequence element, used to lookup
            corresponding rotation angles from cos_sin_cache.
            Shape: [batch]
            Dtype: torch.int64
        q: Input tensor.
            Shape: [batch, num_q_heads, dim]
            Dtype: torch.bfloat16
        q_weight: Weight for q in RMSNorm.
            Shape: [1, dim]
            Dtype: torch.bfloat16
        k: Input tensor.
            Shape: [batch, num_k_heads, dim] or [batch, dim]
            Dtype: torch.bfloat16
        k_weight: Weight for k in RMSNorm.
            Shape: [1, dim]
            Dtype: torch.bfloat16
        cos_sin_cache: cos and sin cache for rope; values should be interleaved.
            Shape: [1, rope_dim]
            Dtype: torch.bfloat16
        eps: a value added to the denominator for numerical stability.
            Shape: scalar
            Dtype: float

    Returns:
        Tuple of Rope(RMSNorm(q)) and Rope(RMSNorm(k)) (the latter is None
        when k is not provided).
    """
    # Thin wrapper: all validation and computation happen inside the custom op.
    result = torch.ops.hpc.fused_rmsnorm_rope(
        positions, q, q_weight, k, k_weight, cos_sin_cache, eps
    )
    return result


@torch.library.register_fake("hpc::fused_rmsnorm_with_scale")
def fused_rmsnorm_with_scale_fake(a, weight, scale, eps, is_moe):
    # Fake (meta) kernel for shape/dtype propagation under torch.compile.
    # Parameter order must match the registered schema:
    #   (Tensor input, Tensor weight, Tensor scale, float eps, bool is_moe)
    # — the previous version had `eps` and `scale` swapped.
    return (
        torch.empty_like(a, dtype=torch.float8_e4m3fn),
        torch.empty_like(a, dtype=torch.float32),
        torch.empty_like(a, dtype=torch.float8_e4m3fn),
    )


@torch.library.register_fake("hpc::fused_rmsnorm_blockwise_quant")
def fused_rmsnorm_blockwise_quant_fake(
    x, weight, eps, with_blockwise_quant, block_size, dual_output
):
    # Fake (meta) kernel for shape/dtype propagation under torch.compile.
    # The schema declares three returns (Tensor, Tensor?, Tensor?), so every
    # branch yields a 3-tuple. Scales are float32 with one value per
    # `block_size` chunk of the hidden dimension (matches the CUDA entry).
    scale_shape = (x.size(0), x.size(-1) // block_size)
    if with_blockwise_quant and dual_output:
        return (
            torch.empty_like(x, dtype=torch.bfloat16),
            torch.empty_like(x, dtype=torch.float8_e4m3fn),
            # empty_like only accepts a tensor; use new_empty for a new shape.
            x.new_empty(scale_shape, dtype=torch.float32),
        )
    elif with_blockwise_quant:
        return (
            torch.empty_like(x, dtype=torch.float8_e4m3fn),
            x.new_empty(scale_shape, dtype=torch.float32),
            None,
        )
    else:
        return (torch.empty_like(x, dtype=torch.bfloat16), None, None)


@torch.library.register_fake("hpc::fused_rmsnorm_rope")
def fused_rmsnorm_rope_fake(positions, q, q_weight, k, k_weight, cos_sin_cache, eps):
    # Fake (meta) kernel for shape/dtype propagation under torch.compile.
    # The schema declares two returns (Tensor, Tensor?), and the CUDA entry
    # always produces bfloat16 outputs even for float32 inputs — the previous
    # version returned a single tensor that inherited q's dtype.
    y_q = torch.empty_like(q, dtype=torch.bfloat16)
    y_k = torch.empty_like(k, dtype=torch.bfloat16) if k is not None else None
    return y_q, y_k
3 changes: 2 additions & 1 deletion src/attention/decode/smallm_splitk_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel(
int num_blocks = (num_seq_kv + kBlockSize - 1) / kBlockSize;
int num_blocks_per_chunk = (num_seq_per_chunk + kBlockSize - 1) / kBlockSize;

float *lse_batch = lse_ptr + ibatch * kSplitK * num_head_q + ichunk * num_head_q + ihead_kv * heads_per_group;
float *lse_batch =
lse_ptr + ibatch * kSplitK * num_head_q + ichunk * num_head_q + ihead_kv * heads_per_group;

const int *block_ids =
block_ids_ptr + ibatch * num_seq_max_blocks + ichunk * num_blocks_per_chunk;
Expand Down
207 changes: 207 additions & 0 deletions src/normalization/entry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
// Copyright 2025 hpc-ops authors

#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include <torch/library.h>

#include <limits>
#include <tuple>

#include "cutlass/float8.h"
#include "src/normalization/fused_rmsnorm_blockwise_quant.h"
#include "src/normalization/fused_rmsnorm_rope.h"
#include "src/normalization/fused_rmsnorm_with_scale.h"

namespace hpc {
namespace normalization {

// CUDA entry for hpc::fused_rmsnorm_with_scale.
// Computes RMSNorm(input) / scale in fp8_e4m3; when `is_moe` is set, also
// fills a float32 copy and a second fp8 output. Always returns all three
// tensors (the extra two are left uninitialized when !is_moe).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fused_rmsnorm_with_scale_entry(
    const torch::Tensor &input, const torch::Tensor &weight, const torch::Tensor &scale, double eps,
    bool is_moe) {
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  TORCH_CHECK(input.scalar_type() == torch::kBFloat16 && weight.scalar_type() == torch::kBFloat16,
              "input and weight must be bfloat16.");

  torch::Tensor output = torch::empty_like(input, torch::kFloat8_e4m3fn);
  torch::Tensor output_fp32 = torch::empty_like(input, torch::kFloat32);
  torch::Tensor output_scale2 = torch::empty_like(input, torch::kFloat8_e4m3fn);

  // The kernel only writes the extra outputs when given non-null pointers,
  // so they stay null (and the tensors uninitialized) unless is_moe.
  void *output_fp32_ptr = nullptr;
  void *output_scale2_ptr = nullptr;
  if (is_moe) {
    output_fp32_ptr = output_fp32.mutable_data_ptr();
    output_scale2_ptr = output_scale2.mutable_data_ptr();
  }

  // Flatten all leading dims into a batch; normalize over the last dim.
  int hidden_state = input.size(input.dim() - 1);
  int batch_size = input.numel() / hidden_state;

  const auto *input_ptr = input.const_data_ptr();
  const auto *weight_ptr = weight.const_data_ptr();
  const auto *scale_ptr = scale.const_data_ptr();
  auto *output_ptr = output.mutable_data_ptr();

  // Asynchronous launch on the current stream; returns false on launch failure.
  auto running = fused_rmsnorm_with_scale_async(input_ptr, weight_ptr, output_ptr, output_fp32_ptr,
                                                output_scale2_ptr, scale_ptr, eps, batch_size,
                                                hidden_state, is_moe, stream);

  TORCH_CHECK(running, "fused_rmsnorm_with_scale_async launch failed!");

  return std::make_tuple(output, output_fp32, output_scale2);
}

// CUDA entry for hpc::fused_rmsnorm_blockwise_quant.
// Computes RMSNorm(input) and, when requested, a block-wise (128-wide) fp8
// quantization with per-block float32 scales. Exactly ONE kernel launch per
// call — the previous version launched twice in the dual_output path, once
// without the bf16 output and then again with it.
// NOTE(review): `quant_size` is currently unused here (the Python wrapper
// asserts block_size == 128) — confirm whether the kernel should consume it.
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>>
fused_rmsnorm_blockwise_quant_entry(const torch::Tensor &input, const torch::Tensor &weight,
                                    double eps, const bool with_blockwise_quant,
                                    const int64_t quant_size, const bool dual_output) {
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  TORCH_CHECK(input.dim() == 2, "input dim must be 2");
  TORCH_CHECK(input.size(-1) == 128 || input.size(-1) == 512 || input.size(-1) == 1024 ||
                  input.size(-1) == 4096,
              "now only support dim 128/512/1024/4096");
  TORCH_CHECK(input.dtype() == torch::kBFloat16, "input dtype must be bfloat16");
  TORCH_CHECK(weight.dtype() == torch::kBFloat16, "weight dtype must be bfloat16");
  TORCH_CHECK(weight.size(-1) == input.size(-1), "weight.size(-1) == input.size(-1) must be true");
  if (dual_output) {
    TORCH_CHECK(with_blockwise_quant == true,
                "when dual_output is set, with_blockwise_quant must be true");
  }
  const auto *input_ptr = input.const_data_ptr();
  const auto *weight_ptr = weight.const_data_ptr();

  int m = input.size(0);
  int hidden_size = input.size(1);
  TORCH_CHECK(hidden_size % 128 == 0, "hidden_size % 128 == 0 must be true");

  auto options = input.options();

  if (with_blockwise_quant) {
    auto y_fp8 = torch::empty({m, hidden_size}, options.dtype(torch::kFloat8_e4m3fn));
    // One fp32 scale per 128-element block of the hidden dimension.
    auto y_scale = torch::empty({m, hidden_size / 128}, options.dtype(torch::kFloat32));
    auto *y_fp8_ptr = y_fp8.mutable_data_ptr();
    auto *y_scale_ptr = y_scale.mutable_data_ptr();
    if (dual_output) {
      // Single launch fills bf16 output, fp8 output, and scales together.
      auto y_bf16 = torch::empty({m, hidden_size}, options.dtype(torch::kBFloat16));
      fused_rmsnorm_blockwise_quant_async(y_bf16.mutable_data_ptr(), y_fp8_ptr, y_scale_ptr,
                                          input_ptr, weight_ptr, m, hidden_size, eps,
                                          with_blockwise_quant, stream);
      return std::make_tuple(y_bf16, y_fp8, y_scale);
    }
    // Quant-only: the kernel skips the bf16 output when given nullptr.
    fused_rmsnorm_blockwise_quant_async(nullptr, y_fp8_ptr, y_scale_ptr, input_ptr, weight_ptr, m,
                                        hidden_size, eps, with_blockwise_quant, stream);
    return std::make_tuple(y_fp8, y_scale, std::nullopt);
  }
  // RMSNorm only, no quantization.
  auto y_bf16 = torch::empty({m, hidden_size}, options);
  fused_rmsnorm_blockwise_quant_async(y_bf16.mutable_data_ptr(), nullptr, nullptr, input_ptr,
                                      weight_ptr, m, hidden_size, eps, with_blockwise_quant,
                                      stream);
  return std::make_tuple(y_bf16, std::nullopt, std::nullopt);
}

// CUDA entry for hpc::fused_rmsnorm_rope.
// Applies RMSNorm (when weights are given) followed by rotary position
// embedding to q and, optionally, k. Outputs are always bfloat16, even for
// float32 inputs.
// Fixes over the previous version: the inner `auto *y_k_ptr` no longer
// shadows the outer `void *y_k_ptr`, the launch/dtype code is no longer
// duplicated across the two paths, and the TORCH_CHECK missing its
// terminating semicolon is repaired.
std::tuple<torch::Tensor, std::optional<torch::Tensor>> fused_rmsnorm_rope_entry(
    const torch::Tensor &positions, const torch::Tensor &q, std::optional<torch::Tensor> q_weight,
    std::optional<torch::Tensor> k, std::optional<torch::Tensor> k_weight,
    const torch::Tensor &cos_sin_cache, const double eps) {
  auto stream = at::cuda::getCurrentCUDAStream(q.get_device());

  int num_tokens = q.size(0);
  int num_q_heads = q.size(1);
  int dim = q.size(2);
  int num_k_heads = 1;
  TORCH_CHECK(positions.dim() == 1, "position.dim() == 1 must be true");
  TORCH_CHECK(positions.size(0) == q.size(0), "positions.size(0) == q.size(0) must be true");
  TORCH_CHECK(q.dim() == 3, "q.dim() == 3 must be true");
  TORCH_CHECK(q.dtype() == torch::kBFloat16 || q.dtype() == torch::kFloat32,
              "now only support bfloat16 or float32 for q");
  TORCH_CHECK(positions.dtype() == torch::kInt64, "positions dtype must be int64");
  if (q_weight.has_value()) {
    TORCH_CHECK(q_weight.value().size(-1) == q.size(-1),
                "q_weight.size(-1) == q.size(-1) must be true");
    TORCH_CHECK(q_weight.value().dtype() == q.dtype(), "q_weight must has same dtype with q");
  }
  if (k.has_value()) {
    TORCH_CHECK(k.value().dtype() == q.dtype(), "k must has same dtype with q");
    TORCH_CHECK(k.value().dim() == 3, "k.dim() must be 3");
    TORCH_CHECK(k.value().size(-1) == q.size(-1), "k.size(-1) == q.size(-1) must be true");
    TORCH_CHECK(k.value().size(0) == q.size(0), "k.size(0) == q.size(0) must be true");
    num_k_heads = k.value().size(1);
    TORCH_CHECK(num_q_heads >= num_k_heads, "now only support num_q_heads >= num_k_heads");
  }
  if (k_weight.has_value()) {
    TORCH_CHECK(k.has_value(), "when k_weight is given, k must be provided");
    TORCH_CHECK(k_weight.value().size(-1) == k.value().size(-1),
                "k_weight.size(-1) == k.size(-1) must be true");
    TORCH_CHECK(k_weight.value().dtype() == k.value().dtype(),
                "k_weight must has same dtype with k");
  }

  TORCH_CHECK(cos_sin_cache.dtype() == torch::kBFloat16, "cos_sim_cache's dtype must be bfloat16");

  TORCH_CHECK(dim == 128 || dim == 512, "now only support dim 128/512");
  int rope_dim = cos_sin_cache.size(-1);

  auto options = q.options();
  // output always is bfloat16
  auto y_q = torch::empty({num_tokens, num_q_heads, dim}, options.dtype(torch::kBFloat16));

  const auto *q_ptr = q.const_data_ptr();
  const auto *pos_ptr = positions.const_data_ptr();
  auto *y_q_ptr = y_q.mutable_data_ptr();
  void *y_k_ptr = nullptr;
  const void *q_weight_ptr = nullptr;
  const void *k_ptr = nullptr;
  const void *k_weight_ptr = nullptr;
  const auto *cos_sin_ptr = cos_sin_cache.const_data_ptr();

  if (q_weight.has_value()) {
    q_weight_ptr = q_weight.value().const_data_ptr();
  }
  std::optional<torch::Tensor> y_k = std::nullopt;
  if (k.has_value()) {
    if (k_weight.has_value()) {
      k_weight_ptr = k_weight.value().const_data_ptr();
    }
    k_ptr = k.value().const_data_ptr();
    // output always is bfloat16
    y_k = torch::empty({num_tokens, num_k_heads, dim}, options.dtype(torch::kBFloat16));
    y_k_ptr = y_k->mutable_data_ptr();
  }

  // Kernel dtype tag: 0 = bfloat16 inputs, 1 = float32 inputs.
  const int dtype = q.dtype() == torch::kBFloat16 ? 0 : 1;
  fused_rmsnorm_rope_async(y_q_ptr, y_k_ptr, q_ptr, q_weight_ptr, k_ptr, k_weight_ptr, pos_ptr,
                           cos_sin_ptr, num_tokens, dim, rope_dim, num_q_heads, num_k_heads, eps,
                           dtype, stream);
  return std::make_tuple(y_q, y_k);
}
} // namespace normalization
} // namespace hpc

// Register op schemas and CUDA implementations in the `hpc` namespace.
// Schemas must stay in sync with the Python-side register_fake kernels in
// hpc/normalization.py (same argument order and return arity).
TORCH_LIBRARY_FRAGMENT(hpc, m) {
  // (input, weight, scale, eps, is_moe) -> (fp8, fp32, fp8_scale2)
  m.def(
      "fused_rmsnorm_with_scale(Tensor input, Tensor weight, Tensor scale, float eps, bool "
      "is_moe) -> (Tensor, Tensor, Tensor)");
  m.impl("fused_rmsnorm_with_scale", torch::kCUDA,
         &hpc::normalization::fused_rmsnorm_with_scale_entry);

  // Optional returns: (y_bf16 or y_fp8, y_scale?, y_scale/extra?)
  m.def(
      "fused_rmsnorm_blockwise_quant(Tensor input, Tensor weight,"
      "float eps, bool with_blockwise_quant, int block_size, bool dual_output) -> (Tensor, Tensor "
      "?, Tensor ?)");
  m.impl("fused_rmsnorm_blockwise_quant", torch::kCUDA,
         &hpc::normalization::fused_rmsnorm_blockwise_quant_entry);

  // Second return is present only when k is provided.
  m.def(
      "fused_rmsnorm_rope(Tensor positions, Tensor q, Tensor ? q_weight,"
      "Tensor ? k, Tensor ? k_weight, Tensor cos_sin_cache, float eps) -> (Tensor, Tensor ?)");
  m.impl("fused_rmsnorm_rope", torch::kCUDA, &hpc::normalization::fused_rmsnorm_rope_entry);
}
Loading