diff --git a/hpc/normalization.py b/hpc/normalization.py
new file mode 100644
index 0000000..f5ac97d
--- /dev/null
+++ b/hpc/normalization.py
@@ -0,0 +1,58 @@
+import torch
+from torch import Tensor
+from typing import Union, Tuple, Optional
+
+
+def fused_rmsnorm_with_scale(
+    a: Tensor,
+    weight: Tensor,
+    eps: float = torch.finfo(torch.float32).eps,
+    scale: Optional[Tensor] = None,
+    is_moe: bool = False,
+) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
+    """Perform RMSNorm for input and divide scales, output the fp8_e4m3 results.
+
+    Executes type conversion in a custom GPU kernel for optimized performance.
+
+    Args:
+        a: Input tensor. We only support bfloat16 type and hidden_states = 5120/4096/320 now.
+            Shape: [batch_size, hidden_states].
+            Dtype: torch.bfloat16
+        weight: Weight in RMSNorm.
+            Shape: [hidden_states].
+            Dtype: torch.bfloat16.
+        eps: a value added to the denominator for numerical stability.
+            Shape: scalar
+            Dtype: float
+        scale: scales for divide. Defaults to a ones tensor of shape [1].
+            Shape: [1] or [2]
+            Dtype: float
+        is_moe: Whether the operation after this rmsnorm is moe,
+            if is True, the scale shape is [2],
+    Returns:
+        if is_moe is True, return (RMSNorm(a), RMSNorm(a) / scale[0], RMSNorm(a) / scale[1])
+        else return RMSNorm(a) / scale[0]
+    """
+    if scale is None:
+        # Build the default lazily: a tensor default argument would be created
+        # once at import time and shared (and pinned to CPU) across all calls.
+        scale = torch.ones(1, dtype=torch.float32, device=a.device)
+    elif scale.device != a.device:
+        scale = scale.to(a.device)
+    output_fp8, output_fp32, output_fp8_scale2 = torch.ops.hpc.fused_rmsnorm_with_scale(
+        a, weight, scale, eps, is_moe
+    )
+    return (output_fp32, output_fp8, output_fp8_scale2) if is_moe else output_fp8
+
+
+@torch.library.register_fake("hpc::fused_rmsnorm_with_scale")
+def fused_rmsnorm_with_scale_fake(a, weight, scale, eps, is_moe):
+    # The fake must mirror the registered C++ schema exactly: argument order
+    # (input, weight, scale, eps, is_moe) and a 3-tensor result in the real
+    # op's output order (fp8 output, fp32 output, fp8-with-scale[1] output),
+    # regardless of is_moe.  The eager wrapper above does the moe reordering.
+    return (
+        torch.empty_like(a, dtype=torch.float8_e4m3fn),
+        torch.empty_like(a, dtype=torch.float32),
+        torch.empty_like(a, dtype=torch.float8_e4m3fn),
+    )
diff --git a/src/attention/decode/smallm_splitk_kernels.cuh
b/src/attention/decode/smallm_splitk_kernels.cuh index 5f7cfa6..b12b973 100644 --- a/src/attention/decode/smallm_splitk_kernels.cuh +++ b/src/attention/decode/smallm_splitk_kernels.cuh @@ -256,7 +256,8 @@ __global__ void attention_decode_bf16_multistage_ws_smallm_splitk_kernel( int num_blocks = (num_seq_kv + kBlockSize - 1) / kBlockSize; int num_blocks_per_chunk = (num_seq_per_chunk + kBlockSize - 1) / kBlockSize; - float *lse_batch = lse_ptr + ibatch * kSplitK * num_head_q + ichunk * num_head_q + ihead_kv * heads_per_group; + float *lse_batch = + lse_ptr + ibatch * kSplitK * num_head_q + ichunk * num_head_q + ihead_kv * heads_per_group; const int *block_ids = block_ids_ptr + ibatch * num_seq_max_blocks + ichunk * num_blocks_per_chunk; diff --git a/src/normalization/entry.cc b/src/normalization/entry.cc new file mode 100644 index 0000000..d05bc6f --- /dev/null +++ b/src/normalization/entry.cc @@ -0,0 +1,65 @@ +// Copyright 2025 hpc-ops authors + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "cutlass/float8.h" +#include "src/normalization/fused_rmsnorm_with_scale.h" + +namespace hpc { +namespace normalization { + +std::tuple fused_rmsnorm_with_scale_entry( + const torch::Tensor &input, const torch::Tensor &weight, const torch::Tensor &scale, double eps, + bool is_moe) { + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + TORCH_CHECK(input.scalar_type() == torch::kBFloat16 && weight.scalar_type() == torch::kBFloat16, + "input and weight must be bfloat16."); + + torch::Tensor output = torch::empty_like(input, torch::kFloat8_e4m3fn); + torch::Tensor output_fp32 = torch::empty_like(input, torch::kFloat32); + torch::Tensor output_scale2 = torch::empty_like(input, torch::kFloat8_e4m3fn); + + void *output_fp32_ptr = nullptr; + void *output_scale2_ptr = nullptr; + if (is_moe) { + output_fp32_ptr = output_fp32.mutable_data_ptr(); + output_scale2_ptr = output_scale2.mutable_data_ptr(); + } + + int 
hidden_state = input.size(input.dim() - 1); + int batch_size = input.numel() / hidden_state; + + const auto *input_ptr = input.const_data_ptr(); + const auto *weight_ptr = weight.const_data_ptr(); + const auto *scale_ptr = scale.const_data_ptr(); + auto *output_ptr = output.mutable_data_ptr(); + + auto running = fused_rmsnorm_with_scale_async(input_ptr, weight_ptr, output_ptr, output_fp32_ptr, + output_scale2_ptr, scale_ptr, eps, batch_size, + hidden_state, is_moe, stream); + + TORCH_CHECK(running, "fused_rmsnorm_with_scale_async launch failed!"); + + return std::make_tuple(output, output_fp32, output_scale2); +} + +} // namespace normalization +} // namespace hpc + +TORCH_LIBRARY_FRAGMENT(hpc, m) { + m.def( + "fused_rmsnorm_with_scale(Tensor input, Tensor weight, Tensor scale, float eps, bool " + "is_moe) -> (Tensor, Tensor, Tensor)"); + m.impl("fused_rmsnorm_with_scale", torch::kCUDA, + &hpc::normalization::fused_rmsnorm_with_scale_entry); +} diff --git a/src/normalization/fused_rmsnorm_with_scale.cu b/src/normalization/fused_rmsnorm_with_scale.cu new file mode 100644 index 0000000..eb32898 --- /dev/null +++ b/src/normalization/fused_rmsnorm_with_scale.cu @@ -0,0 +1,228 @@ +// Copyright 2025 hpc-ops authors +#include + +#include + +#include "src/normalization/fused_rmsnorm_with_scale.h" +#include "src/utils/utils.cuh" + +namespace hpc { +namespace normalization { + +namespace kernels { + +template +__global__ void fused_rmsnorm_with_scale(const __nv_bfloat16 *input_ptr, + const __nv_bfloat16 *weight_ptr, float *output_ptr_fp32, + __nv_fp8_e4m3 *output_ptr_fp8, + __nv_fp8_e4m3 *output_ptr_fp8_scale2, const float *scale, + float eps, int batch_size) { + constexpr int kWarpSize = 32; + constexpr int kItemPer16B = 8; + constexpr float kInvHiddenStates = 1.0f / kHiddenStates; + constexpr int kIterPerBatch = (kHiddenStates + kWarpPerBatch * kWarpSize * kItemPer16B - 1) / + (kWarpPerBatch * kWarpSize * kItemPer16B); + float inv_scale = rcpf_ftz(scale[0]); + float 
inv_scale2 = 1; + if constexpr (kIsMoe) { + inv_scale2 = rcpf_ftz(scale[1]); + inv_scale *= scale[1]; + } + + const int iwarp = threadIdx.x / kWarpSize; + const int ilane = threadIdx.x % kWarpSize; + + const uint64_t ibatch = blockIdx.x * kBatchPerBlock + iwarp / kWarpPerBatch; + const uint64_t icol_thr_offset = ((iwarp % kWarpPerBatch) * kWarpSize + ilane) * kItemPer16B; + + __shared__ float smem_sum[kWarpPerBatch]; + + vec_t reg_input[kIterPerBatch]; + vec_t reg_weight[kIterPerBatch]; + +#pragma unroll + for (int i = 0; i < kIterPerBatch; i++) { +#pragma unroll + for (int j = 0; j < kItemPer16B; j++) { + reg_input[i][j] = 0; + reg_weight[i][j] = 0; + } + } + +#pragma unroll + for (int iter = 0; iter < kIterPerBatch; iter++) { + uint64_t icol = icol_thr_offset + iter * kWarpPerBatch * kWarpSize * kItemPer16B; + + if (icol < kHiddenStates) { + reg_weight[iter] = to(load<__nv_bfloat162, kItemPer16B / 2>(&weight_ptr[icol])); + } + } + + float local_mean = 0.0f; + + if (ibatch < batch_size) { +#pragma unroll + for (int iter = 0; iter < kIterPerBatch; iter++) { + uint64_t icol = icol_thr_offset + iter * kWarpPerBatch * kWarpSize * kItemPer16B; + + if (icol < kHiddenStates) { + reg_input[iter] = to( + load<__nv_bfloat162, kItemPer16B / 2>(&input_ptr[ibatch * kHiddenStates + icol])); +#pragma unroll + for (int i = 0; i < kItemPer16B; i++) { + local_mean += reg_input[iter][i] * reg_input[iter][i]; + } + } + } + } + + // warp reduce + local_mean = warp_reduce_sum_xor(local_mean); + + if (ilane == 0) { + smem_sum[iwarp] = local_mean; + } + __syncthreads(); + + int first_warp_in_batch = (iwarp / kWarpPerBatch) * kWarpPerBatch; +#pragma unroll + for (int iwarp_in_batch = 0; iwarp_in_batch < kWarpPerBatch; iwarp_in_batch++) { + int reduce_warp = first_warp_in_batch + iwarp_in_batch; + if (iwarp != reduce_warp) { + local_mean += smem_sum[reduce_warp]; + } + } + + local_mean = rsqrtf_ftz(local_mean * kInvHiddenStates + eps); + + if constexpr (!kIsMoe) { + local_mean *= 
inv_scale; + } + + if (ibatch < batch_size) { +#pragma unroll + for (int iter = 0; iter < kIterPerBatch; iter++) { + uint64_t icol = icol_thr_offset + iter * kWarpPerBatch * kWarpSize * kItemPer16B; + uint64_t istore = ibatch * kHiddenStates + icol; + + if (icol < kHiddenStates) { +#pragma unroll + for (int i = 0; i < kItemPer16B; i++) { + reg_input[iter][i] = reg_input[iter][i] * reg_weight[iter][i] * local_mean; + } + + auto &split_input = reshape<2, kItemPer16B / 2>(reg_input[iter]); + if constexpr (kIsMoe) { + store(&output_ptr_fp32[istore], split_input[0]); + store(&output_ptr_fp32[istore + kItemPer16B / 2], split_input[1]); +#pragma unroll + for (int i = 0; i < kItemPer16B; i++) { + reg_input[iter][i] *= inv_scale2; + } + + store(&output_ptr_fp8_scale2[istore], to<__nv_fp8x4_e4m3>(split_input[0])); + store(&output_ptr_fp8_scale2[istore + kItemPer16B / 2], + to<__nv_fp8x4_e4m3>(split_input[1])); +#pragma unroll + for (int i = 0; i < kItemPer16B; i++) { + reg_input[iter][i] *= inv_scale; + } + } + + store(&output_ptr_fp8[istore], to<__nv_fp8x4_e4m3>(split_input[0])); + store(&output_ptr_fp8[istore + kItemPer16B / 2], to<__nv_fp8x4_e4m3>(split_input[1])); + } + } + } +} + +} // namespace kernels + +bool fused_rmsnorm_with_scale_async(const void *input_ptr, const void *weight_ptr, void *output_ptr, + void *output_fp32_ptr, void *output_fp8_scale2_ptr, + const void *scale, float eps, int batch_size, int hidden_states, + bool is_moe, cudaStream_t stream) { + constexpr int kWarpCount = 4; + constexpr int kWarpSize = 32; + + using Tin = const __nv_bfloat16; + using Tout = __nv_fp8_e4m3; + + dim3 block(kWarpSize * kWarpCount); + if (hidden_states == 5120) { + constexpr int kHiddenStates = 5120; + constexpr int kWarpPerBatch = 4; + constexpr int kBatchPerBlock = kWarpCount / kWarpPerBatch; + dim3 grid((batch_size + kBatchPerBlock - 1) / kBatchPerBlock); + if (is_moe) { + constexpr bool kIsMoe = true; + kernels::fused_rmsnorm_with_scale + <<>>( + 
reinterpret_cast(input_ptr), reinterpret_cast(weight_ptr), + reinterpret_cast(output_fp32_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(output_fp8_scale2_ptr), + reinterpret_cast(scale), eps, batch_size); + } else { + constexpr bool kIsMoe = false; + kernels::fused_rmsnorm_with_scale + <<>>( + reinterpret_cast(input_ptr), reinterpret_cast(weight_ptr), + reinterpret_cast(output_fp32_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(output_fp8_scale2_ptr), + reinterpret_cast(scale), eps, batch_size); + } + } else if (hidden_states == 4096) { + constexpr int kHiddenStates = 4096; + constexpr int kWarpPerBatch = 4; + constexpr int kBatchPerBlock = kWarpCount / kWarpPerBatch; + dim3 grid((batch_size + kBatchPerBlock - 1) / kBatchPerBlock); + if (is_moe) { + constexpr bool kIsMoe = true; + kernels::fused_rmsnorm_with_scale + <<>>( + reinterpret_cast(input_ptr), reinterpret_cast(weight_ptr), + reinterpret_cast(output_fp32_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(output_fp8_scale2_ptr), + reinterpret_cast(scale), eps, batch_size); + } else { + constexpr bool kIsMoe = false; + kernels::fused_rmsnorm_with_scale + <<>>( + reinterpret_cast(input_ptr), reinterpret_cast(weight_ptr), + reinterpret_cast(output_fp32_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(output_fp8_scale2_ptr), + reinterpret_cast(scale), eps, batch_size); + } + } else if (hidden_states == 320) { + constexpr int kHiddenStates = 320; + constexpr int kWarpPerBatch = 1; + constexpr int kBatchPerBlock = kWarpCount / kWarpPerBatch; + dim3 grid((batch_size + kBatchPerBlock - 1) / kBatchPerBlock); + if (is_moe) { + constexpr bool kIsMoe = true; + kernels::fused_rmsnorm_with_scale + <<>>( + reinterpret_cast(input_ptr), reinterpret_cast(weight_ptr), + reinterpret_cast(output_fp32_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(output_fp8_scale2_ptr), + reinterpret_cast(scale), eps, batch_size); + } else { + constexpr bool kIsMoe = false; + kernels::fused_rmsnorm_with_scale + 
<<>>( + reinterpret_cast(input_ptr), reinterpret_cast(weight_ptr), + reinterpret_cast(output_fp32_ptr), reinterpret_cast(output_ptr), + reinterpret_cast(output_fp8_scale2_ptr), + reinterpret_cast(scale), eps, batch_size); + } + } else { + std::cout << "not supported hidden_size for fused_rmsnorm_with_scale_async:" << hidden_states + << std::endl; + return false; + } + + return true; +} + +} // namespace normalization +} // namespace hpc diff --git a/src/normalization/fused_rmsnorm_with_scale.h b/src/normalization/fused_rmsnorm_with_scale.h new file mode 100644 index 0000000..5e63494 --- /dev/null +++ b/src/normalization/fused_rmsnorm_with_scale.h @@ -0,0 +1,27 @@ +// Copyright 2025 hpc-ops authors + +#ifndef SRC_NORMALIZATION_FUSED_RMSNORM_WITH_SCALE_H_ +#define SRC_NORMALIZATION_FUSED_RMSNORM_WITH_SCALE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace hpc { +namespace normalization { + +bool fused_rmsnorm_with_scale_async(const void* input_ptr, const void* weight_ptr, void* output_ptr, + void* output_fp32_ptr, void* output_fp8_scale2_ptr, + const void* scale, float eps, int batch_size, int hidden_state, + bool is_moe, cudaStream_t stream); + +} // namespace normalization +} // namespace hpc + +#endif // SRC_NORMALIZATION_FUSED_RMSNORM_WITH_SCALE_H_ diff --git a/tests/test_normalization.py b/tests/test_normalization.py new file mode 100644 index 0000000..52ba9ab --- /dev/null +++ b/tests/test_normalization.py @@ -0,0 +1,66 @@ +import sys +import os +import pytest +from pathlib import Path + +sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0])) + +import hpc +import torch +from utils import allclose + + +def reference_torch_rmsnorm_with_scale(x, weight, scale, eps): + rms = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps) + x_normalized = x * rms + if weight is not None: + x_normalized = x_normalized * weight.float() + inv_scale = 1.0 / scale + return 
(x_normalized * inv_scale).to(torch.float8_e4m3fn).to(torch.bfloat16) + + +# torch impl for rmsnorm +def reference_torch_rmsnorm(x, weight, eps): + rms = torch.rsqrt(torch.mean(x.float().pow(2), dim=-1, keepdim=True) + eps) + x_normalized = x * rms + if weight is not None: + x_normalized = x_normalized * weight.float() + return x_normalized + + +@pytest.mark.parametrize("batch_size", [1, 2, 4, 5, 8, 14, 16, 17, 32, 64]) +@pytest.mark.parametrize("hidden_states", [5120, 320, 4096]) +@pytest.mark.parametrize("scale", [2.5]) +@pytest.mark.parametrize("is_moe", [False, True]) +def test_fused_rmsnorm_with_scale(batch_size, hidden_states, scale, is_moe): + torch.manual_seed(0) + rmsnorm_weight = torch.rand((1, hidden_states), dtype=torch.bfloat16).cuda() + x = torch.randn(batch_size, hidden_states, dtype=torch.bfloat16).cuda() + if is_moe: + scale_tensor = torch.tensor([scale, 2 * scale], dtype=torch.float32).cuda() + else: + scale_tensor = torch.tensor([scale], dtype=torch.float32).cuda() + eps = 1e-6 + + if not rmsnorm_weight.is_contiguous(): + rmsnorm_weight = rmsnorm_weight.contiguous() + + gt = reference_torch_rmsnorm_with_scale(x, rmsnorm_weight, scale_tensor[0], eps) + if is_moe: + gt_2 = reference_torch_rmsnorm_with_scale(x, rmsnorm_weight, scale_tensor[1], eps) + else: + gt_2 = gt + + gt_fp32 = reference_torch_rmsnorm(x, rmsnorm_weight, eps) + output = hpc.normalization.fused_rmsnorm_with_scale( + x, rmsnorm_weight, scale=scale_tensor, eps=eps, is_moe=is_moe + ) + + if is_moe: + y_fp32, y_fp8, y_fp8_2 = output + else: + y_fp32, y_fp8, y_fp8_2 = gt_fp32, output, output + + assert allclose(gt_fp32, y_fp32) + assert allclose(gt_2, y_fp8_2.to(torch.bfloat16), atol=0.15, rtol=0.0125) + assert allclose(gt, y_fp8.to(torch.bfloat16), atol=0.15, rtol=0.0125)