Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
f639c6e
add rmsnorm perf benchmark
aris134 May 11, 2026
b5720e9
add rmsnorm perf benchmark and missing tuned DS configs
aris134 May 11, 2026
6c2cd28
add missing tuned shape for Qwen and update benchmark
aris134 May 12, 2026
912c62b
move benchmark to benchmarks folder and rewrite in google benchmark s…
aris134 May 12, 2026
c24a091
add matching tuned configs for layernorm
aris134 May 12, 2026
3d8e1de
add fallback warning print if tuned config not found for normalization
aris134 May 12, 2026
856346d
add fallback warning message when tuned normalization kernel config n…
aris134 May 12, 2026
d5293cb
uncomment qwen configs
aris134 May 12, 2026
78f9aa5
generalization rms norm benchmark to also include layer norm, and add…
aris134 May 12, 2026
4256e3c
remove redundant synchronization before gpu warmup in bench_normaliza…
aris134 May 15, 2026
2f9ff47
address nit: move unordered_set to after unordered_map
aris134 May 15, 2026
d548d54
share normalization key bit layout constants
aris134 May 15, 2026
b73062c
remove unneeded line splits and deduplicate norm benchmark epsilon
aris134 May 18, 2026
0949b9a
restore norm key encoding and document decode coupling
aris134 May 18, 2026
e8afcc8
use more optimal BYTES_PER_LDG=16 for layer norm backward tuned confi…
aris134 May 18, 2026
3d22d82
use more optimal 7 warps config for H=7168, for fp16/fp32 layernorm f…
aris134 May 18, 2026
63d0e26
mirror layernorm bwd h=7168 uned config to rmsnorm
aris134 May 18, 2026
db7b017
Add normalization key layout round-trip check
aris134 May 18, 2026
62d38e1
Merge remote-tracking branch 'origin/dev' into amartin/rmsnorm
aris134 May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,4 @@ add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp)
add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp)
add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp)
add_te_benchmark(bench_normalization normalization/bench_normalization.cpp)
293 changes: 293 additions & 0 deletions benchmarks/cpp/normalization/bench_normalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
/*************************************************************************
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

#include <benchmark/benchmark.h>
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

#include "benchmark_utils.h"

#include "transformer_engine/normalization_hip.h"
#include "transformer_engine/transformer_engine_hip.h"

using namespace te_bench;
using namespace transformer_engine;

#define NORM_SHAPES \
->Args({8192, 128}) \
->Args({8192, 1536}) \
->Args({8192, 7168})

constexpr float kNormEpsilon = 1e-5f;

enum class BenchNormType {
LayerNorm,
RMSNorm,
};

template <typename T>
constexpr DType dtype_of() {
if constexpr (std::is_same_v<T, float>) {
return DType::kFloat32;
} else if constexpr (std::is_same_v<T, hip_bfloat16>) {
return DType::kBFloat16;
} else {
return DType::kFloat16;
}
}

template <BenchNormType Norm, typename WType, typename IType, typename OType, typename CType>
static void BM_NormForward(benchmark::State& state) {
const size_t N = state.range(0);
const size_t H = state.range(1);
constexpr bool zero_centered_gamma = false;

const DType wtype = dtype_of<WType>();
const DType itype = dtype_of<IType>();
const DType otype = dtype_of<OType>();

test::Tensor input("input", std::vector<size_t>{N, H}, itype);
test::Tensor output("output", std::vector<size_t>{N, H}, otype);
test::Tensor gamma("gamma", std::vector<size_t>{H}, wtype);
test::Tensor beta("beta", std::vector<size_t>{H}, wtype);
test::Tensor mu("mu", std::vector<size_t>{N}, DType::kFloat32);
test::Tensor rsigma("rsigma", std::vector<size_t>{N}, DType::kFloat32);
test::Tensor workspace;

test::fillUniform(&input);
test::fillUniform(&gamma);
test::fillUniform(&beta);
test::setRandomScale(&output);

hipDeviceProp_t prop;
HIP_CHECK(hipGetDeviceProperties(&prop, 0));

hipStream_t stream;
HIP_CHECK(hipStreamCreate(&stream));

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon,
output.data(), mu.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon,
output.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

workspace = test::Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon,
output.data(), mu.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon,
output.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

warmup_gpu();

hipEvent_t start, stop;
HIP_CHECK(hipEventCreate(&start));
HIP_CHECK(hipEventCreate(&stop));

for (auto _ : state) {
HIP_CHECK(hipEventRecord(start, stream));

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon,
output.data(), mu.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon,
output.data(), rsigma.data(), workspace.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

HIP_CHECK(hipEventRecord(stop, stream));
HIP_CHECK(hipEventSynchronize(stop));

float ms = 0.0f;
HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
state.SetIterationTime(ms / 1000.0);
}

HIP_CHECK(hipEventDestroy(start));
HIP_CHECK(hipEventDestroy(stop));
HIP_CHECK(hipStreamDestroy(stream));

size_t bytes_read = N * H * sizeof(IType) + H * sizeof(WType);

size_t bytes_write = N * H * sizeof(OType) + N * sizeof(float);

if constexpr (Norm == BenchNormType::LayerNorm) {
bytes_read += H * sizeof(WType); // beta
bytes_write += N * sizeof(float); // mu
}

set_bytes_processed(state, bytes_read + bytes_write);
}

template <BenchNormType Norm, typename WType, typename IType, typename OType, typename CType>
static void BM_NormBackward(benchmark::State& state) {
const size_t N = state.range(0);
const size_t H = state.range(1);
constexpr bool zero_centered_gamma = false;

const DType wtype = dtype_of<WType>();
const DType itype = dtype_of<IType>();
const DType otype = dtype_of<OType>();

test::Tensor input("input", std::vector<size_t>{N, H}, itype);
test::Tensor output("output", std::vector<size_t>{N, H}, otype);
test::Tensor gamma("gamma", std::vector<size_t>{H}, wtype);
test::Tensor beta("beta", std::vector<size_t>{H}, wtype);
test::Tensor mu("mu", std::vector<size_t>{N}, DType::kFloat32);
test::Tensor rsigma("rsigma", std::vector<size_t>{N}, DType::kFloat32);
test::Tensor dz("dz", std::vector<size_t>{N, H}, otype);
test::Tensor dx("dx", std::vector<size_t>{N, H}, itype);
test::Tensor dgamma("dgamma", std::vector<size_t>{H}, wtype);
test::Tensor dbeta("dbeta", std::vector<size_t>{H}, wtype);
test::Tensor workspace_fwd;
test::Tensor workspace_bwd;

test::fillUniform(&input);
test::fillUniform(&gamma);
test::fillUniform(&beta);
test::setRandomScale(&output);
test::fillUniform(&dz);

hipDeviceProp_t prop;
HIP_CHECK(hipGetDeviceProperties(&prop, 0));

hipStream_t stream;
HIP_CHECK(hipStreamCreate(&stream));

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon,
output.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon,
output.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

workspace_fwd = test::Tensor("workspace_fwd",
workspace_fwd.rowwise_shape(),
workspace_fwd.dtype());

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon,
output.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);

nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon,
output.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);

nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

workspace_bwd = test::Tensor("workspace_bwd",
workspace_bwd.rowwise_shape(),
workspace_bwd.dtype());

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

warmup_gpu();

hipEvent_t start, stop;
HIP_CHECK(hipEventCreate(&start));
HIP_CHECK(hipEventCreate(&stop));

for (auto _ : state) {
HIP_CHECK(hipEventRecord(start, stream));

if constexpr (Norm == BenchNormType::LayerNorm) {
nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
} else {
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, stream);
}

HIP_CHECK(hipEventRecord(stop, stream));
HIP_CHECK(hipEventSynchronize(stop));

float ms = 0.0f;
HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
state.SetIterationTime(ms / 1000.0);
}

HIP_CHECK(hipEventDestroy(start));
HIP_CHECK(hipEventDestroy(stop));
HIP_CHECK(hipStreamDestroy(stream));

size_t bytes_read =
N * H * sizeof(OType) + // dz
N * H * sizeof(IType) + // x
N * sizeof(float) + // rsigma
H * sizeof(WType); // gamma

size_t bytes_write =
N * H * sizeof(IType) + // dx
H * sizeof(WType); // dgamma

if constexpr (Norm == BenchNormType::LayerNorm) {
bytes_read += N * sizeof(float); // mu
bytes_write += H * sizeof(WType); // dbeta
}

set_bytes_processed(state, bytes_read + bytes_write);
}

#define REGISTER_NORM_BENCH(NORM_ENUM, NORM_NAME, WTYPE, ITYPE, OTYPE, CTYPE, NAME) \
BENCHMARK_TEMPLATE(BM_NormForward, NORM_ENUM, WTYPE, ITYPE, OTYPE, CTYPE) \
->Name("BM_" NORM_NAME "Forward/" NAME) \
NORM_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_NormBackward, NORM_ENUM, WTYPE, ITYPE, OTYPE, CTYPE) \
->Name("BM_" NORM_NAME "Backward/" NAME) \
NORM_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

#define REGISTER_RMSNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \
REGISTER_NORM_BENCH(BenchNormType::RMSNorm, "RMSNorm", WTYPE, ITYPE, OTYPE, CTYPE, NAME)

#define REGISTER_LAYERNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \
REGISTER_NORM_BENCH(BenchNormType::LayerNorm, "LayerNorm", WTYPE, ITYPE, OTYPE, CTYPE, NAME)

REGISTER_RMSNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32")
REGISTER_RMSNORM(half, half, half, float, "FP16_FP16_FP16_FP32")
REGISTER_RMSNORM(float, float, float, float, "FP32_FP32_FP32_FP32")

REGISTER_LAYERNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32")
REGISTER_LAYERNORM(half, half, half, float, "FP16_FP16_FP16_FP32")
REGISTER_LAYERNORM(float, float, float, float, "FP32_FP32_FP32_FP32")

BENCHMARK_MAIN();
21 changes: 20 additions & 1 deletion transformer_engine/common/normalization/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
}
#endif //#ifndef __HIP_PLATFORM_AMD__

// Keep this bit layout in sync with the decode helpers in common.h.
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
Expand All @@ -67,6 +68,25 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
}

namespace {

[[maybe_unused]] const bool kNormKeyLayoutCheck = [] {
const uint64_t key = std::get<0>(get_key(
NVTE_Norm_Backend::Te, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward,
DType::kFloat16, DType::kBFloat16, DType::kFloat8E4M3, DType::kFloat32,
1, 1, false, false));

NVTE_CHECK(decode_itype(key) == DType::kBFloat16);
NVTE_CHECK(decode_otype(key) == DType::kFloat8E4M3);
NVTE_CHECK(decode_ctype(key) == DType::kFloat32);
NVTE_CHECK(decode_wtype(key) == DType::kFloat16);
NVTE_CHECK(decode_norm_type(key) == NVTE_Norm_Type::RMSNorm);

return true;
}();

} // namespace

template <typename KernelParamsType>
TeNormalizationPlan<KernelParamsType>::TeNormalizationPlan(
NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype,
Expand Down Expand Up @@ -609,4 +629,3 @@ void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
}
#endif

Loading