From f639c6ec8292721761190ec92233b85e5ee79e92 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 11 May 2026 14:53:30 -0400 Subject: [PATCH 01/18] add rmsnorm perf benchmark --- tests/cpp/operator/CMakeLists.txt | 3 +- tests/cpp/operator/test_rmsnorm_perf.cu | 133 ++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/operator/test_rmsnorm_perf.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..cea447854 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -39,7 +39,8 @@ if(USE_CUDA) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu - test_cast_mxfp4_transpose.cu) + test_cast_mxfp4_transpose.cu + test_rmsnorm_perf.cu) endif() if(USE_CUDA) diff --git a/tests/cpp/operator/test_rmsnorm_perf.cu b/tests/cpp/operator/test_rmsnorm_perf.cu new file mode 100644 index 000000000..69be3ce72 --- /dev/null +++ b/tests/cpp/operator/test_rmsnorm_perf.cu @@ -0,0 +1,133 @@ +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +std::vector> tensor_dims = { + {8192, 1536}, + {8192, 7168}, +}; + +template +double time_ms(Fn&& fn, int warmup = 20, int iters = 100) { + for (int i = 0; i < warmup; ++i) { + fn(); + } + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + cudaEvent_t start, stop; + NVTE_CHECK_CUDA(cudaEventCreate(&start)); + NVTE_CHECK_CUDA(cudaEventCreate(&stop)); + + NVTE_CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < iters; ++i) { + fn(); + } + NVTE_CHECK_CUDA(cudaEventRecord(stop)); + NVTE_CHECK_CUDA(cudaEventSynchronize(stop)); + + float ms = 0.f; + NVTE_CHECK_CUDA(cudaEventElapsedTime(&ms, start, stop)); + + NVTE_CHECK_CUDA(cudaEventDestroy(start)); + NVTE_CHECK_CUDA(cudaEventDestroy(stop)); + + return static_cast(ms) / iters; +} + +template +void performTest(const size_t N, const size_t H) { + using InputType = OutputType; + using WeightType = OutputType; + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + const DType wtype = TypeInfo::dtype; + + float epsilon = 1e-5; + + Tensor input("input", std::vector{ N, H }, itype); + Tensor z("z", std::vector{ N, H }, otype); + Tensor gamma("gamma", std::vector{ H }, wtype); + Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); + Tensor dz("dz", std::vector{ N, H }, wtype); + Tensor dx("dx", std::vector{ N, H }, itype); + Tensor dgamma("dgamma", std::vector{ H }, wtype); + Tensor workspace_fwd, workspace_bwd; + + fillUniform(&input); + fillUniform(&gamma); + setRandomScale(&z); + fillUniform(&dz); + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, false, 0); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + double fwd_ms = time_ms([&] { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, false, 0); + }); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, + false, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + double bwd_ms = time_ms([&] { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, false, 0); + }); + + std::cout << "RMSNORM_PERF" + << " N=" << N + << " H=" << H + << " dtype=" << typeName(otype) + << " fwd_ms=" << fwd_ms + << " bwd_ms=" << bwd_ms + << std::endl; + +} + +} // namespace + +class RMSNormPerfTestSuite + : public ::testing::TestWithParam< + std::tuple, transformer_engine::DType>> {}; + +TEST_P(RMSNormPerfTestSuite, TestRMSNormPerf) { + const auto tensor_shape = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + + const size_t N = tensor_shape.first; + const size_t H = tensor_shape.second; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + performTest(N, H); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + RMSNormPerfTestSuite, + ::testing::Combine( + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + test::typeName(std::get<1>(info.param)); + return name; + }); \ No newline at end of file From b5720e970369c1ea8f94a40b6c888b2f0de988f3 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 11 May 2026 22:56:29 +0000 Subject: [PATCH 02/18] add rmsnorm perf benchmark and missing tuned DS configs --- tests/cpp/operator/test_rmsnorm_perf.cu | 20 +++++++++++++++++-- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 12 +++++++++++ .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 18 +++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_rmsnorm_perf.cu b/tests/cpp/operator/test_rmsnorm_perf.cu index 69be3ce72..59c5d22bb 100644 --- a/tests/cpp/operator/test_rmsnorm_perf.cu +++ b/tests/cpp/operator/test_rmsnorm_perf.cu @@ -10,8 +10,9 @@ using namespace test; namespace { std::vector> tensor_dims = { - {8192, 1536}, - {8192, 7168}, + // {8192, 128}, // Qwen + {8192, 1536}, // DS + {8192, 7168}, // DS }; template @@ -90,12 +91,27 @@ void performTest(const size_t N, const size_t H) { prop.multiProcessorCount, false, 0); }); + // Effective bandwidth using a simple algorithmic convention: + // FWD counts one full read of x and one full write of z. + // BWD counts the dominant full-tensor streams: read dz, read x, + // read gamma/logical weight stream, and write dx. + const double elem_bytes = static_cast(sizeof(OutputType)); + const double numel = static_cast(N) * static_cast(H); + + const double fwd_bytes = 2.0 * numel * elem_bytes; + const double bwd_bytes = 4.0 * numel * elem_bytes; + + const double fwd_gbps = fwd_bytes / (fwd_ms * 1.0e-3) / 1.0e9; + const double bwd_gbps = bwd_bytes / (bwd_ms * 1.0e-3) / 1.0e9; + std::cout << "RMSNORM_PERF" << " N=" << N << " H=" << H << " dtype=" << typeName(otype) << " fwd_ms=" << fwd_ms + << " fwd_GBps=" << fwd_gbps << " bwd_ms=" << bwd_ms + << " bwd_GBps=" << bwd_gbps << std::endl; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index da940254c..ce9bb7939 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -167,6 +167,12 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +#ifdef __HIP_PLATFORM_AMD__ +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +#endif + REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); @@ -175,6 +181,12 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +#ifdef __HIP_PLATFORM_AMD__ +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +#endif + REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 3a28ebf13..596e50726 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -164,6 +164,15 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +#ifdef __HIP_PLATFORM_AMD__ +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +#endif + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); @@ -178,6 +187,15 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +#ifdef __HIP_PLATFORM_AMD__ +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +#endif + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); From 6c2cd286add365bd5547150535a4cbc22571db93 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 12:38:17 +0000 Subject: [PATCH 03/18] add missing tuned shape for Qwen and update benchmark --- tests/cpp/operator/test_rmsnorm_perf.cu | 2 +- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 6 ++++++ .../normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_rmsnorm_perf.cu b/tests/cpp/operator/test_rmsnorm_perf.cu index 59c5d22bb..8ddff124a 100644 --- a/tests/cpp/operator/test_rmsnorm_perf.cu +++ b/tests/cpp/operator/test_rmsnorm_perf.cu @@ -10,7 +10,7 @@ using namespace test; namespace { std::vector> tensor_dims = { - // {8192, 128}, // Qwen + {8192, 128}, // Qwen {8192, 1536}, // DS {8192, 7168}, // DS }; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index ce9bb7939..0c3d51b74 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -155,6 +155,12 @@ void launch_rmsnorm_bwd_general_(LaunchParams &launch_para // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +#ifdef __HIP_PLATFORM_AMD__ +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); +#endif + REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 596e50726..f8a709752 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -143,6 +143,15 @@ void launch_rmsnorm_fwd_general_(LaunchParams &launch_param // Create rmsnorm tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG +#ifdef __HIP_PLATFORM_AMD__ +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); +#endif + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); From 912c62bc56acadd52ff6303d8dcf22c1e1f55104 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 15:29:40 +0000 Subject: [PATCH 04/18] move benchmark to benchmarks folder and rewrite in google benchmark style --- benchmarks/cpp/CMakeLists.txt | 1 + .../cpp/normalization/bench_rmsnorm.cpp | 223 ++++++++++++++++++ tests/cpp/operator/CMakeLists.txt | 3 +- tests/cpp/operator/test_rmsnorm_perf.cu | 149 ------------ 4 files changed, 225 insertions(+), 151 deletions(-) create mode 100644 benchmarks/cpp/normalization/bench_rmsnorm.cpp delete mode 100644 tests/cpp/operator/test_rmsnorm_perf.cu diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index 6071f9083..c660b8b06 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -86,3 +86,4 @@ add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) +add_te_benchmark(bench_rmsnorm normalization/bench_rmsnorm.cpp) diff --git a/benchmarks/cpp/normalization/bench_rmsnorm.cpp b/benchmarks/cpp/normalization/bench_rmsnorm.cpp new file mode 100644 index 000000000..d5bfc2ce3 --- /dev/null +++ b/benchmarks/cpp/normalization/bench_rmsnorm.cpp @@ -0,0 +1,223 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include + +#include "benchmark_utils.h" + +#include "transformer_engine/normalization_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; + +#define RMSNORM_SHAPES \ + ->Args({8192, 128}) \ + ->Args({8192, 1536}) \ + ->Args({8192, 7168}) + +template +constexpr DType dtype_of() { + if constexpr (std::is_same_v) { + return DType::kFloat32; + } else if constexpr (std::is_same_v) { + return DType::kBFloat16; + } else { + return DType::kFloat16; + } +} + +template +static void BM_RMSNormForward(benchmark::State& state) { + const size_t N = state.range(0); + const size_t H = state.range(1); + const float epsilon = 1e-5f; + + const DType wtype = dtype_of(); + const DType itype = dtype_of(); + const DType otype = dtype_of(); + + test::Tensor input("input", std::vector{N, H}, itype); + test::Tensor output("output", std::vector{N, H}, otype); + test::Tensor gamma("gamma", std::vector{H}, wtype); + test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); + test::Tensor workspace; + + test::fillUniform(&input); + test::fillUniform(&gamma); + test::setRandomScale(&output); + + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, false, stream); + + workspace = test::Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, false, stream); + + HIP_CHECK(hipStreamSynchronize(stream)); + warmup_gpu(); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, false, stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0.0f; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipStreamDestroy(stream)); + + // Algorithmic byte traffic by tensor role: + // read x + gamma, write z + rsigma. + const size_t bytes_read = + N * H * sizeof(IType) + + H * sizeof(WType); + + const size_t bytes_write = + N * H * sizeof(OType) + + N * sizeof(CType); + + set_bytes_processed(state, bytes_read + bytes_write); +} + +template +static void BM_RMSNormBackward(benchmark::State& state) { + const size_t N = state.range(0); + const size_t H = state.range(1); + const float epsilon = 1e-5f; + + const DType wtype = dtype_of(); + const DType itype = dtype_of(); + const DType otype = dtype_of(); + + test::Tensor input("input", std::vector{N, H}, itype); + test::Tensor output("output", std::vector{N, H}, otype); + test::Tensor gamma("gamma", std::vector{H}, wtype); + test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); + test::Tensor dz("dz", std::vector{N, H}, otype); + test::Tensor dx("dx", std::vector{N, H}, itype); + test::Tensor dgamma("dgamma", std::vector{H}, wtype); + test::Tensor workspace_fwd; + test::Tensor workspace_bwd; + + test::fillUniform(&input); + test::fillUniform(&gamma); + test::setRandomScale(&output); + test::fillUniform(&dz); + + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, false, stream); + + workspace_fwd = test::Tensor("workspace_fwd", + workspace_fwd.rowwise_shape(), + workspace_fwd.dtype()); + + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, false, stream); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, false, stream); + + workspace_bwd = test::Tensor("workspace_bwd", + workspace_bwd.rowwise_shape(), + workspace_bwd.dtype()); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, false, stream); + + HIP_CHECK(hipStreamSynchronize(stream)); + warmup_gpu(); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, false, stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0.0f; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipStreamDestroy(stream)); + + // Algorithmic byte traffic by tensor role: + // read dz + x + rsigma + gamma, write dx + dgamma. + const size_t bytes_read = + N * H * sizeof(OType) + + N * H * sizeof(IType) + + N * sizeof(CType) + + H * sizeof(WType); + + const size_t bytes_write = + N * H * sizeof(IType) + + H * sizeof(WType); + + set_bytes_processed(state, bytes_read + bytes_write); +} + +#define REGISTER_RMSNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + BENCHMARK_TEMPLATE(BM_RMSNormForward, WTYPE, ITYPE, OTYPE, CTYPE) \ + ->Name("BM_RMSNormForward/" NAME) \ + RMSNORM_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_RMSNormBackward, WTYPE, ITYPE, OTYPE, CTYPE) \ + ->Name("BM_RMSNormBackward/" NAME) \ + RMSNORM_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_RMSNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32") +REGISTER_RMSNORM(half, half, half, float, "FP16_FP16_FP16_FP32") +REGISTER_RMSNORM(float, float, float, float, "FP32_FP32_FP32_FP32") + +BENCHMARK_MAIN(); diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index cea447854..0ebd7fdfe 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -39,8 +39,7 @@ if(USE_CUDA) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu - test_cast_mxfp4_transpose.cu - test_rmsnorm_perf.cu) + test_cast_mxfp4_transpose.cu) endif() if(USE_CUDA) diff --git a/tests/cpp/operator/test_rmsnorm_perf.cu b/tests/cpp/operator/test_rmsnorm_perf.cu deleted file mode 100644 index 8ddff124a..000000000 --- a/tests/cpp/operator/test_rmsnorm_perf.cu +++ /dev/null @@ -1,149 +0,0 @@ -#include "../test_common.h" -#include "transformer_engine/transformer_engine.h" -#include -#include -#include - -using namespace transformer_engine; -using namespace test; - -namespace { - -std::vector> tensor_dims = { - {8192, 128}, // Qwen - {8192, 1536}, // DS - {8192, 7168}, // DS -}; - -template -double time_ms(Fn&& fn, int warmup = 20, int iters = 100) { - for (int i = 0; i < warmup; ++i) { - fn(); - } - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - cudaEvent_t start, stop; - NVTE_CHECK_CUDA(cudaEventCreate(&start)); - NVTE_CHECK_CUDA(cudaEventCreate(&stop)); - - NVTE_CHECK_CUDA(cudaEventRecord(start)); - for (int i = 0; i < iters; ++i) { - fn(); - } - NVTE_CHECK_CUDA(cudaEventRecord(stop)); - NVTE_CHECK_CUDA(cudaEventSynchronize(stop)); - - float ms = 0.f; - NVTE_CHECK_CUDA(cudaEventElapsedTime(&ms, start, stop)); - - NVTE_CHECK_CUDA(cudaEventDestroy(start)); - NVTE_CHECK_CUDA(cudaEventDestroy(stop)); - - return static_cast(ms) / iters; -} - -template -void performTest(const size_t N, const size_t H) { - using InputType = OutputType; - using WeightType = OutputType; - const DType itype = TypeInfo::dtype; - const DType otype = TypeInfo::dtype; - const DType wtype = TypeInfo::dtype; - - float epsilon = 1e-5; - - Tensor input("input", std::vector{ N, H }, itype); - Tensor z("z", std::vector{ N, H }, otype); - Tensor gamma("gamma", std::vector{ H }, wtype); - Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); - Tensor dz("dz", std::vector{ N, H }, wtype); - Tensor dx("dx", std::vector{ N, H }, itype); - Tensor dgamma("dgamma", std::vector{ H }, wtype); - Tensor workspace_fwd, workspace_bwd; - - fillUniform(&input); - fillUniform(&gamma); - setRandomScale(&z); - fillUniform(&dz); - - cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - z.data(), rsigma.data(), workspace_fwd.data(), - prop.multiProcessorCount, false, 0); - workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - double fwd_ms = time_ms([&] { - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - z.data(), rsigma.data(), workspace_fwd.data(), - prop.multiProcessorCount, false, 0); - }); - - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, - false, 0); - workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - double bwd_ms = time_ms([&] { - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), workspace_bwd.data(), - prop.multiProcessorCount, false, 0); - }); - - // Effective bandwidth using a simple algorithmic convention: - // FWD counts one full read of x and one full write of z. - // BWD counts the dominant full-tensor streams: read dz, read x, - // read gamma/logical weight stream, and write dx. - const double elem_bytes = static_cast(sizeof(OutputType)); - const double numel = static_cast(N) * static_cast(H); - - const double fwd_bytes = 2.0 * numel * elem_bytes; - const double bwd_bytes = 4.0 * numel * elem_bytes; - - const double fwd_gbps = fwd_bytes / (fwd_ms * 1.0e-3) / 1.0e9; - const double bwd_gbps = bwd_bytes / (bwd_ms * 1.0e-3) / 1.0e9; - - std::cout << "RMSNORM_PERF" - << " N=" << N - << " H=" << H - << " dtype=" << typeName(otype) - << " fwd_ms=" << fwd_ms - << " fwd_GBps=" << fwd_gbps - << " bwd_ms=" << bwd_ms - << " bwd_GBps=" << bwd_gbps - << std::endl; - -} - -} // namespace - -class RMSNormPerfTestSuite - : public ::testing::TestWithParam< - std::tuple, transformer_engine::DType>> {}; - -TEST_P(RMSNormPerfTestSuite, TestRMSNormPerf) { - const auto tensor_shape = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - - const size_t N = tensor_shape.first; - const size_t H = tensor_shape.second; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, - performTest(N, H); - ); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - RMSNormPerfTestSuite, - ::testing::Combine( - ::testing::ValuesIn(tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), - [](const testing::TestParamInfo& info) { - std::string name = - std::to_string(std::get<0>(info.param).first) + "X" + - std::to_string(std::get<0>(info.param).second) + "X" + - test::typeName(std::get<1>(info.param)); - return name; - }); \ No newline at end of file From c24a0915f825f151ec06380fd140dfeeada5aaa8 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 19:44:51 +0000 Subject: [PATCH 05/18] add matching tuned configs for layernorm --- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 12 +++++++++ .../layernorm/ln_fwd_cuda_kernel.cu | 13 ++++++++- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 27 ++++++++++++++----- .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 17 +++++++----- 4 files changed, 56 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 8f0a8a14b..9129100a7 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -154,6 +154,12 @@ void launch_ln_bwd_general_(LaunchParams &launch_params, // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, fp16, fp32, fp16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, bf16, fp32, bf16, fp32, 1, 4, 1, 4, 4); + REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); @@ -214,6 +220,12 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index f76fcf582..95df4e707 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -142,6 +142,7 @@ void launch_ln_fwd_general_(LaunchParams &launch_params, // Create tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); @@ -168,6 +169,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); @@ -194,6 +196,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); @@ -220,6 +223,12 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, bf16, fp32, 1, 4, 1, 4); + REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); @@ -423,7 +432,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32 #ifdef __HIP_PLATFORM_AMD__ // ROCM uses TE normalization for e5m2 - +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); @@ -450,6 +459,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e5m2, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e5m2, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e5m2, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); @@ -476,6 +486,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e5m2, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e5m2, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e5m2, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 0c3d51b74..40fdde753 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -155,11 +155,9 @@ void launch_rmsnorm_bwd_general_(LaunchParams &launch_para // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL -#ifdef __HIP_PLATFORM_AMD__ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); -#endif REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); @@ -173,11 +171,9 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -#ifdef __HIP_PLATFORM_AMD__ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -#endif REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); @@ -187,11 +183,9 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -#ifdef __HIP_PLATFORM_AMD__ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -#endif REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); @@ -235,6 +229,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4, + true); + REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, @@ -256,6 +257,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32 REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, @@ -270,6 +278,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32 REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index f8a709752..3218752f8 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -143,14 +143,12 @@ void launch_rmsnorm_fwd_general_(LaunchParams &launch_param // Create rmsnorm tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG -#ifdef __HIP_PLATFORM_AMD__ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); -#endif REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); @@ -173,14 +171,12 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -#ifdef __HIP_PLATFORM_AMD__ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -#endif REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); @@ -196,14 +192,12 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -#ifdef __HIP_PLATFORM_AMD__ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -#endif REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); @@ -265,6 +259,9 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, #ifdef __HIP_PLATFORM_AMD__ // ROCM uses TE normalization for e5m2 +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); @@ -278,6 +275,10 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, fp8e5m2, fp32, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); @@ -286,6 +287,10 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, fp8e5m2, fp32, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); From 3d8e1dee94b5f7d5d7628c65a5f3554143b9099e Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 20:11:03 +0000 Subject: [PATCH 06/18] add fallback warning print if tuned config not found for normalization --- transformer_engine/common/normalization/common.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 70584fac3..2aff55d9a 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "../common.h" #ifndef __HIP_PLATFORM_AMD__ @@ -233,7 +234,16 @@ class TeNormalizationRegistry { if (is_tuned) { auto it = instance.tuned_function_map.find(key); if (it != instance.tuned_function_map.end()) return it->second; - } + + static thread_local std::unordered_set warned_keys; + if (warned_keys.insert(key).second) { + NVTE_WARN("Falling back to general normalization kernel because no tuned kernel " + "is available for this config. batch_size=", + batch_size, + ", hidden_size=", + hidden_size); + } + } if (instance.general_function_map.count(general_key) == 0) { NVTE_ERROR("Unavailable kernel for this normalization config."); } From 856346deeefb55b361cdbfad29e0e53c378ea137 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 20:41:57 +0000 Subject: [PATCH 07/18] add fallback warning message when tuned normalization kernel config not found --- .../common/normalization/common.h | 53 +++++++++++++++++-- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 10 ++-- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 4 +- .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 4 +- 4 files changed, 57 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 2aff55d9a..3ae30b9d8 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -203,6 +203,43 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, bool training = true, bool gamma_in_weight_dtype = false); +inline DType decode_itype(uint64_t general_key) { + return static_cast(general_key & 0x1F); +} + +inline DType decode_otype(uint64_t general_key) { + return static_cast((general_key >> 5) & 0x1F); +} + +inline DType decode_ctype(uint64_t general_key) { + return static_cast((general_key >> 10) & 0x1F); +} + +inline DType decode_wtype(uint64_t general_key) { + return static_cast((general_key >> 15) & 0x1F); +} + +inline const char* dtype_to_string(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return "fp32"; + case DType::kFloat16: + return "fp16"; + case DType::kBFloat16: + return "bf16"; + case DType::kFloat8E4M3: + return "fp8e4m3"; + case DType::kFloat8E5M2: + return "fp8e5m2"; + case DType::kByte: + return "byte"; + case DType::kInt32: + return "int32"; + default: + return "unknown"; + } +} + template class TeNormalizationRegistry { private: @@ -227,7 +264,7 @@ class TeNormalizationRegistry { getInstance().general_function_map[general_key].emplace(hidden_size, Function(func)); return 0; } - + static Function getKernel(TupleKeyType key) { auto& instance = getInstance(); auto [general_key, batch_size, hidden_size, is_tuned] = key; @@ -238,10 +275,16 @@ class TeNormalizationRegistry { static thread_local std::unordered_set warned_keys; if (warned_keys.insert(key).second) { NVTE_WARN("Falling back to general normalization kernel because no tuned kernel " - "is available for this config. batch_size=", - batch_size, - ", hidden_size=", - hidden_size); + "is available for this config. hidden_size=", + hidden_size, + ", wtype=", + dtype_to_string(decode_wtype(general_key)), + ", itype=", + dtype_to_string(decode_itype(general_key)), + ", otype=", + dtype_to_string(decode_otype(general_key)), + ", ctype=", + dtype_to_string(decode_ctype(general_key))); } } if (instance.general_function_map.count(general_key) == 0) { diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 9129100a7..8914ac18b 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -220,11 +220,11 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp32, fp16, fp32, 1, 1, 7, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, fp32, bf16, fp32, 1, 1, 7, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 40fdde753..3b2b6316f 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -155,9 +155,9 @@ void launch_rmsnorm_bwd_general_(LaunchParams &launch_para // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL -REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); +/*REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4);*/ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 3218752f8..6ab3a1c6f 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -143,12 +143,12 @@ void launch_rmsnorm_fwd_general_(LaunchParams &launch_param // Create rmsnorm tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG -REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); +/*REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4);*/ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); From d5293cb851476a8cc2d3fd2a3001abfb01753fcd Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 20:48:49 +0000 Subject: [PATCH 08/18] uncomment qwen configs --- .../normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 4 ++-- .../common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 3b2b6316f..40fdde753 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -155,9 +155,9 @@ void launch_rmsnorm_bwd_general_(LaunchParams &launch_para // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL -/*REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4);*/ +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 6ab3a1c6f..3218752f8 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -143,12 +143,12 @@ void launch_rmsnorm_fwd_general_(LaunchParams &launch_param // Create rmsnorm tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG -/*REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4);*/ +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); From 78f9aa5fb807f9d058fb77d4747e729908e7ee67 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 12 May 2026 21:06:01 +0000 Subject: [PATCH 09/18] generalization rms norm benchmark to also include layer norm, and add missing configs for layer norm --- benchmarks/cpp/CMakeLists.txt | 2 +- .../cpp/normalization/bench_normalization.cpp | 299 ++++++++++++++++++ .../cpp/normalization/bench_rmsnorm.cpp | 223 ------------- .../common/normalization/common.h | 19 +- .../layernorm/ln_fwd_cuda_kernel.cu | 12 + 5 files changed, 330 insertions(+), 225 deletions(-) create mode 100644 benchmarks/cpp/normalization/bench_normalization.cpp delete mode 100644 benchmarks/cpp/normalization/bench_rmsnorm.cpp diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index c660b8b06..e8ad3526f 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -86,4 +86,4 @@ add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) -add_te_benchmark(bench_rmsnorm normalization/bench_rmsnorm.cpp) +add_te_benchmark(bench_normalization normalization/bench_normalization.cpp) diff --git a/benchmarks/cpp/normalization/bench_normalization.cpp b/benchmarks/cpp/normalization/bench_normalization.cpp new file mode 100644 index 000000000..ef6fdab91 --- /dev/null +++ b/benchmarks/cpp/normalization/bench_normalization.cpp @@ -0,0 +1,299 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include + +#include "benchmark_utils.h" + +#include "transformer_engine/normalization_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; + +#define NORM_SHAPES \ + ->Args({8192, 128}) \ + ->Args({8192, 1536}) \ + ->Args({8192, 7168}) + +enum class BenchNormType { + LayerNorm, + RMSNorm, +}; + +template +constexpr DType dtype_of() { + if constexpr (std::is_same_v) { + return DType::kFloat32; + } else if constexpr (std::is_same_v) { + return DType::kBFloat16; + } else { + return DType::kFloat16; + } +} + +template +static void BM_NormForward(benchmark::State& state) { + const size_t N = state.range(0); + const size_t H = state.range(1); + const float epsilon = 1e-5f; + constexpr bool zero_centered_gamma = false; + + const DType wtype = dtype_of(); + const DType itype = dtype_of(); + const DType otype = dtype_of(); + + test::Tensor input("input", std::vector{N, H}, itype); + test::Tensor output("output", std::vector{N, H}, otype); + test::Tensor gamma("gamma", std::vector{H}, wtype); + test::Tensor beta("beta", std::vector{H}, wtype); + test::Tensor mu("mu", std::vector{N}, DType::kFloat32); + test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); + test::Tensor workspace; + + test::fillUniform(&input); + test::fillUniform(&gamma); + test::fillUniform(&beta); + test::setRandomScale(&output); + + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + output.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + workspace = test::Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + output.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + HIP_CHECK(hipStreamSynchronize(stream)); + warmup_gpu(); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + output.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0.0f; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipStreamDestroy(stream)); + + size_t bytes_read = + N * H * sizeof(IType) + // x + H * sizeof(WType); // gamma + + size_t bytes_write = + N * H * sizeof(OType) + // z + N * sizeof(float); // rsigma + + if constexpr (Norm == BenchNormType::LayerNorm) { + bytes_read += H * sizeof(WType); // beta + bytes_write += N * sizeof(float); // mu + } + + set_bytes_processed(state, bytes_read + bytes_write); +} + +template +static void BM_NormBackward(benchmark::State& state) { + const size_t N = state.range(0); + const size_t H = state.range(1); + const float epsilon = 1e-5f; + constexpr bool zero_centered_gamma = false; + + const DType wtype = dtype_of(); + const DType itype = dtype_of(); + const DType otype = dtype_of(); + + test::Tensor input("input", std::vector{N, H}, itype); + test::Tensor output("output", std::vector{N, H}, otype); + test::Tensor gamma("gamma", std::vector{H}, wtype); + test::Tensor beta("beta", std::vector{H}, wtype); + test::Tensor mu("mu", std::vector{N}, DType::kFloat32); + test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); + test::Tensor dz("dz", std::vector{N, H}, otype); + test::Tensor dx("dx", std::vector{N, H}, itype); + test::Tensor dgamma("dgamma", std::vector{H}, wtype); + test::Tensor dbeta("dbeta", std::vector{H}, wtype); + test::Tensor workspace_fwd; + test::Tensor workspace_bwd; + + test::fillUniform(&input); + test::fillUniform(&gamma); + test::fillUniform(&beta); + test::setRandomScale(&output); + test::fillUniform(&dz); + + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + output.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + workspace_fwd = test::Tensor("workspace_fwd", + workspace_fwd.rowwise_shape(), + workspace_fwd.dtype()); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + output.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + + nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + output.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + workspace_bwd = test::Tensor("workspace_bwd", + workspace_bwd.rowwise_shape(), + workspace_bwd.dtype()); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + HIP_CHECK(hipStreamSynchronize(stream)); + warmup_gpu(); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0.0f; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipStreamDestroy(stream)); + + size_t bytes_read = + N * H * sizeof(OType) + // dz + N * H * sizeof(IType) + // x + N * sizeof(float) + // rsigma + H * sizeof(WType); // gamma + + size_t bytes_write = + N * H * sizeof(IType) + // dx + H * sizeof(WType); // dgamma + + if constexpr (Norm == BenchNormType::LayerNorm) { + bytes_read += N * sizeof(float); // mu + bytes_write += H * sizeof(WType); // dbeta + } + + set_bytes_processed(state, bytes_read + bytes_write); +} + +#define REGISTER_NORM_BENCH(NORM_ENUM, NORM_NAME, WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + BENCHMARK_TEMPLATE(BM_NormForward, NORM_ENUM, WTYPE, ITYPE, OTYPE, CTYPE) \ + ->Name("BM_" NORM_NAME "Forward/" NAME) \ + NORM_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_NormBackward, NORM_ENUM, WTYPE, ITYPE, OTYPE, CTYPE) \ + ->Name("BM_" NORM_NAME "Backward/" NAME) \ + NORM_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_RMSNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + REGISTER_NORM_BENCH(BenchNormType::RMSNorm, "RMSNorm", WTYPE, ITYPE, OTYPE, CTYPE, NAME) + +#define REGISTER_LAYERNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + REGISTER_NORM_BENCH(BenchNormType::LayerNorm, "LayerNorm", WTYPE, ITYPE, OTYPE, CTYPE, NAME) + +REGISTER_RMSNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32") +REGISTER_RMSNORM(half, half, half, float, "FP16_FP16_FP16_FP32") +REGISTER_RMSNORM(float, float, float, float, "FP32_FP32_FP32_FP32") + +REGISTER_LAYERNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32") +REGISTER_LAYERNORM(half, half, half, float, "FP16_FP16_FP16_FP32") +REGISTER_LAYERNORM(float, float, float, float, "FP32_FP32_FP32_FP32") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/normalization/bench_rmsnorm.cpp b/benchmarks/cpp/normalization/bench_rmsnorm.cpp deleted file mode 100644 index d5bfc2ce3..000000000 --- a/benchmarks/cpp/normalization/bench_rmsnorm.cpp +++ /dev/null @@ -1,223 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#include -#include -#include -#include - -#include "benchmark_utils.h" - -#include "transformer_engine/normalization_hip.h" -#include "transformer_engine/transformer_engine_hip.h" - -using namespace te_bench; -using namespace transformer_engine; - -#define RMSNORM_SHAPES \ - ->Args({8192, 128}) \ - ->Args({8192, 1536}) \ - ->Args({8192, 7168}) - -template -constexpr DType dtype_of() { - if constexpr (std::is_same_v) { - return DType::kFloat32; - } else if constexpr (std::is_same_v) { - return DType::kBFloat16; - } else { - return DType::kFloat16; - } -} - -template -static void BM_RMSNormForward(benchmark::State& state) { - const size_t N = state.range(0); - const size_t H = state.range(1); - const float epsilon = 1e-5f; - - const DType wtype = dtype_of(); - const DType itype = dtype_of(); - const DType otype = dtype_of(); - - test::Tensor input("input", std::vector{N, H}, itype); - test::Tensor output("output", std::vector{N, H}, otype); - test::Tensor gamma("gamma", std::vector{H}, wtype); - test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); - test::Tensor workspace; - - test::fillUniform(&input); - test::fillUniform(&gamma); - test::setRandomScale(&output); - - hipDeviceProp_t prop; - HIP_CHECK(hipGetDeviceProperties(&prop, 0)); - - hipStream_t stream; - HIP_CHECK(hipStreamCreate(&stream)); - - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - output.data(), rsigma.data(), workspace.data(), - prop.multiProcessorCount, false, stream); - - workspace = test::Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - output.data(), rsigma.data(), workspace.data(), - prop.multiProcessorCount, false, stream); - - HIP_CHECK(hipStreamSynchronize(stream)); - warmup_gpu(); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) { - HIP_CHECK(hipEventRecord(start, stream)); - - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - output.data(), rsigma.data(), workspace.data(), - prop.multiProcessorCount, false, stream); - - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float ms = 0.0f; - HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); - state.SetIterationTime(ms / 1000.0); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - HIP_CHECK(hipStreamDestroy(stream)); - - // Algorithmic byte traffic by tensor role: - // read x + gamma, write z + rsigma. - const size_t bytes_read = - N * H * sizeof(IType) + - H * sizeof(WType); - - const size_t bytes_write = - N * H * sizeof(OType) + - N * sizeof(CType); - - set_bytes_processed(state, bytes_read + bytes_write); -} - -template -static void BM_RMSNormBackward(benchmark::State& state) { - const size_t N = state.range(0); - const size_t H = state.range(1); - const float epsilon = 1e-5f; - - const DType wtype = dtype_of(); - const DType itype = dtype_of(); - const DType otype = dtype_of(); - - test::Tensor input("input", std::vector{N, H}, itype); - test::Tensor output("output", std::vector{N, H}, otype); - test::Tensor gamma("gamma", std::vector{H}, wtype); - test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); - test::Tensor dz("dz", std::vector{N, H}, otype); - test::Tensor dx("dx", std::vector{N, H}, itype); - test::Tensor dgamma("dgamma", std::vector{H}, wtype); - test::Tensor workspace_fwd; - test::Tensor workspace_bwd; - - test::fillUniform(&input); - test::fillUniform(&gamma); - test::setRandomScale(&output); - test::fillUniform(&dz); - - hipDeviceProp_t prop; - HIP_CHECK(hipGetDeviceProperties(&prop, 0)); - - hipStream_t stream; - HIP_CHECK(hipStreamCreate(&stream)); - - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - output.data(), rsigma.data(), workspace_fwd.data(), - prop.multiProcessorCount, false, stream); - - workspace_fwd = test::Tensor("workspace_fwd", - workspace_fwd.rowwise_shape(), - workspace_fwd.dtype()); - - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, - output.data(), rsigma.data(), workspace_fwd.data(), - prop.multiProcessorCount, false, stream); - - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), workspace_bwd.data(), - prop.multiProcessorCount, false, stream); - - workspace_bwd = test::Tensor("workspace_bwd", - workspace_bwd.rowwise_shape(), - workspace_bwd.dtype()); - - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), workspace_bwd.data(), - prop.multiProcessorCount, false, stream); - - HIP_CHECK(hipStreamSynchronize(stream)); - warmup_gpu(); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) { - HIP_CHECK(hipEventRecord(start, stream)); - - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), workspace_bwd.data(), - prop.multiProcessorCount, false, stream); - - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float ms = 0.0f; - HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); - state.SetIterationTime(ms / 1000.0); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - HIP_CHECK(hipStreamDestroy(stream)); - - // Algorithmic byte traffic by tensor role: - // read dz + x + rsigma + gamma, write dx + dgamma. - const size_t bytes_read = - N * H * sizeof(OType) + - N * H * sizeof(IType) + - N * sizeof(CType) + - H * sizeof(WType); - - const size_t bytes_write = - N * H * sizeof(IType) + - H * sizeof(WType); - - set_bytes_processed(state, bytes_read + bytes_write); -} - -#define REGISTER_RMSNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ - BENCHMARK_TEMPLATE(BM_RMSNormForward, WTYPE, ITYPE, OTYPE, CTYPE) \ - ->Name("BM_RMSNormForward/" NAME) \ - RMSNORM_SHAPES \ - ->Unit(benchmark::kMicrosecond) \ - ->UseManualTime(); \ - BENCHMARK_TEMPLATE(BM_RMSNormBackward, WTYPE, ITYPE, OTYPE, CTYPE) \ - ->Name("BM_RMSNormBackward/" NAME) \ - RMSNORM_SHAPES \ - ->Unit(benchmark::kMicrosecond) \ - ->UseManualTime(); - -REGISTER_RMSNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32") -REGISTER_RMSNORM(half, half, half, float, "FP16_FP16_FP16_FP32") -REGISTER_RMSNORM(float, float, float, float, "FP32_FP32_FP32_FP32") - -BENCHMARK_MAIN(); diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 3ae30b9d8..c476f1dc5 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -219,6 +219,10 @@ inline DType decode_wtype(uint64_t general_key) { return static_cast((general_key >> 15) & 0x1F); } +inline NVTE_Norm_Type decode_norm_type(uint64_t general_key) { + return static_cast((general_key >> 20) & 0x3); +} + inline const char* dtype_to_string(DType dtype) { switch (dtype) { case DType::kFloat32: @@ -240,6 +244,17 @@ inline const char* dtype_to_string(DType dtype) { } } +inline const char* norm_type_to_string(NVTE_Norm_Type norm_type) { + switch (norm_type) { + case NVTE_Norm_Type::LayerNorm: + return "LayerNorm"; + case NVTE_Norm_Type::RMSNorm: + return "RMSNorm"; + default: + return "unknown"; + } +} + template class TeNormalizationRegistry { private: @@ -275,7 +290,9 @@ class TeNormalizationRegistry { static thread_local std::unordered_set warned_keys; if (warned_keys.insert(key).second) { NVTE_WARN("Falling back to general normalization kernel because no tuned kernel " - "is available for this config. hidden_size=", + "is available for this config. norm_type=", + norm_type_to_string(decode_norm_type(general_key)), + ", hidden_size=", hidden_size, ", wtype=", dtype_to_string(decode_wtype(general_key)), diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 95df4e707..79a45cd17 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -153,6 +153,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); @@ -180,6 +181,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); @@ -207,6 +209,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); @@ -289,6 +292,12 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); @@ -443,6 +452,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e5m2, fp32, 2, 1, 4, 16); @@ -470,6 +480,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e5m2, fp32, 2, 1, 4, 16); @@ -497,6 +508,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e5m2, fp32, 2, 1, 4, 16); From 4256e3c867f4461a1fd9ff1eebec37dbe78ba14c Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 15 May 2026 16:00:08 +0000 Subject: [PATCH 10/18] remove redundant synchronization before gpu warmup in bench_normalization.cpp --- benchmarks/cpp/normalization/bench_normalization.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmarks/cpp/normalization/bench_normalization.cpp b/benchmarks/cpp/normalization/bench_normalization.cpp index ef6fdab91..fbbfe7cd4 100644 --- a/benchmarks/cpp/normalization/bench_normalization.cpp +++ b/benchmarks/cpp/normalization/bench_normalization.cpp @@ -90,7 +90,6 @@ static void BM_NormForward(benchmark::State& state) { prop.multiProcessorCount, zero_centered_gamma, stream); } - HIP_CHECK(hipStreamSynchronize(stream)); warmup_gpu(); hipEvent_t start, stop; @@ -220,7 +219,6 @@ static void BM_NormBackward(benchmark::State& state) { prop.multiProcessorCount, zero_centered_gamma, stream); } - HIP_CHECK(hipStreamSynchronize(stream)); warmup_gpu(); hipEvent_t start, stop; From 2f9ff47c9eea28c4947463a61b7e6193f8dbc437 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 15 May 2026 16:03:42 +0000 Subject: [PATCH 11/18] address nit: move unordered_set to after unordered_map --- transformer_engine/common/normalization/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index c476f1dc5..34f1d2314 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include "../common.h" #ifndef __HIP_PLATFORM_AMD__ From d548d541a1138048d54b585c44fa3621a948c404 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 15 May 2026 16:37:25 +0000 Subject: [PATCH 12/18] share normalization key bit layout constants --- .../common/normalization/common.cpp | 18 ++++++---- .../common/normalization/common.h | 33 ++++++++++++++++--- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index d6aa55b37..7e8659bea 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -58,12 +58,18 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, static_assert(NVTE_INVALID_SCALING < 1024, "This function assumes at most 10 bits used in the scaling mode."); static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType"); - uint64_t general_key = static_cast(itype) | (static_cast(otype) << 5) | - (static_cast(ctype) << 10) | - (static_cast(wtype) << 15) | (uint64_t(NormType) << 20) | - (uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) | - (uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) | - (uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38); + uint64_t general_key = + (static_cast(itype) << norm_key::kItypeShift) | + (static_cast(otype) << norm_key::kOtypeShift) | + (static_cast(ctype) << norm_key::kCtypeShift) | + (static_cast(wtype) << norm_key::kWtypeShift) | + (static_cast(NormType) << norm_key::kNormTypeShift) | + (static_cast(NormStage) << norm_key::kNormStageShift) | + (static_cast(NormBackend) << norm_key::kNormBackendShift) | + (static_cast(zero_centered_gamma) << norm_key::kZeroCenteredGammaShift) | + (static_cast(mode) << norm_key::kScalingModeShift) | + (static_cast(training) << norm_key::kTrainingShift) | + (static_cast(gamma_in_weight_dtype) << norm_key::kGammaInWeightDtypeShift); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 34f1d2314..e9890b42f 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -38,6 +38,27 @@ namespace transformer_engine { namespace normalization { +// Bit layout for the packed normalization general_key. +// Keep these constants in sync with get_key() encoding in common.cpp. +namespace norm_key { + +constexpr int kItypeShift = 0; +constexpr int kOtypeShift = 5; +constexpr int kCtypeShift = 10; +constexpr int kWtypeShift = 15; +constexpr int kNormTypeShift = 20; +constexpr int kNormStageShift = 22; +constexpr int kNormBackendShift = 24; +constexpr int kZeroCenteredGammaShift = 26; +constexpr int kScalingModeShift = 27; +constexpr int kTrainingShift = 37; +constexpr int kGammaInWeightDtypeShift = 38; + +constexpr uint64_t kDTypeMask = 0x1F; +constexpr uint64_t kNormTypeMask = 0x3; + +} // namespace norm_key + #ifndef __HIP_PLATFORM_AMD__ namespace fe = cudnn_frontend; #endif @@ -203,24 +224,26 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, bool training = true, bool gamma_in_weight_dtype = false); +// Decode helpers assume the same general_key bit layout used by get_key(). inline DType decode_itype(uint64_t general_key) { - return static_cast(general_key & 0x1F); + return static_cast((general_key >> norm_key::kItypeShift) & norm_key::kDTypeMask); } inline DType decode_otype(uint64_t general_key) { - return static_cast((general_key >> 5) & 0x1F); + return static_cast((general_key >> norm_key::kOtypeShift) & norm_key::kDTypeMask); } inline DType decode_ctype(uint64_t general_key) { - return static_cast((general_key >> 10) & 0x1F); + return static_cast((general_key >> norm_key::kCtypeShift) & norm_key::kDTypeMask); } inline DType decode_wtype(uint64_t general_key) { - return static_cast((general_key >> 15) & 0x1F); + return static_cast((general_key >> norm_key::kWtypeShift) & norm_key::kDTypeMask); } inline NVTE_Norm_Type decode_norm_type(uint64_t general_key) { - return static_cast((general_key >> 20) & 0x3); + return static_cast( + (general_key >> norm_key::kNormTypeShift) & norm_key::kNormTypeMask); } inline const char* dtype_to_string(DType dtype) { From b73062c7ddb826c0677d288139ceb9f776f95e7c Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 18 May 2026 13:30:29 +0000 Subject: [PATCH 13/18] remove unneeded line splits and deduplicate norm benchmark epsilon --- .../cpp/normalization/bench_normalization.cpp | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/benchmarks/cpp/normalization/bench_normalization.cpp b/benchmarks/cpp/normalization/bench_normalization.cpp index fbbfe7cd4..92ac3c946 100644 --- a/benchmarks/cpp/normalization/bench_normalization.cpp +++ b/benchmarks/cpp/normalization/bench_normalization.cpp @@ -22,6 +22,8 @@ using namespace transformer_engine; ->Args({8192, 1536}) \ ->Args({8192, 7168}) +constexpr float kNormEpsilon = 1e-5f; + enum class BenchNormType { LayerNorm, RMSNorm, @@ -42,7 +44,6 @@ template (); @@ -69,11 +70,11 @@ static void BM_NormForward(benchmark::State& state) { HIP_CHECK(hipStreamCreate(&stream)); if constexpr (Norm == BenchNormType::LayerNorm) { - nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, output.data(), mu.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } else { - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, output.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } @@ -81,11 +82,11 @@ static void BM_NormForward(benchmark::State& state) { workspace = test::Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); if constexpr (Norm == BenchNormType::LayerNorm) { - nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, output.data(), mu.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } else { - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, output.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } @@ -100,11 +101,11 @@ static void BM_NormForward(benchmark::State& state) { HIP_CHECK(hipEventRecord(start, stream)); if constexpr (Norm == BenchNormType::LayerNorm) { - nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, output.data(), mu.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } else { - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, output.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } @@ -121,13 +122,9 @@ static void BM_NormForward(benchmark::State& state) { HIP_CHECK(hipEventDestroy(stop)); HIP_CHECK(hipStreamDestroy(stream)); - size_t bytes_read = - N * H * sizeof(IType) + // x - H * sizeof(WType); // gamma + size_t bytes_read = N * H * sizeof(IType) + H * sizeof(WType); - size_t bytes_write = - N * H * sizeof(OType) + // z - N * sizeof(float); // rsigma + size_t bytes_write = N * H * sizeof(OType) + N * sizeof(float); if constexpr (Norm == BenchNormType::LayerNorm) { bytes_read += H * sizeof(WType); // beta @@ -141,7 +138,6 @@ template (); @@ -174,11 +170,11 @@ static void BM_NormBackward(benchmark::State& state) { HIP_CHECK(hipStreamCreate(&stream)); if constexpr (Norm == BenchNormType::LayerNorm) { - nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, output.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } else { - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, output.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } @@ -188,7 +184,7 @@ static void BM_NormBackward(benchmark::State& state) { workspace_fwd.dtype()); if constexpr (Norm == BenchNormType::LayerNorm) { - nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, output.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, stream); @@ -196,7 +192,7 @@ static void BM_NormBackward(benchmark::State& state) { dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, stream); } else { - nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, output.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, stream); From 0949b9accdda3d582a386055b20ff64b1eca4ecb Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 18 May 2026 13:50:40 +0000 Subject: [PATCH 14/18] restore norm key encoding and document decode coupling --- .../common/normalization/common.cpp | 20 ++++------- .../common/normalization/common.h | 35 ++++--------------- 2 files changed, 14 insertions(+), 41 deletions(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 7e8659bea..f7195a375 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -50,6 +50,7 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { } #endif //#ifndef __HIP_PLATFORM_AMD__ +// Keep this bit layout in sync with the decode helpers in common.h. TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, @@ -58,18 +59,12 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, static_assert(NVTE_INVALID_SCALING < 1024, "This function assumes at most 10 bits used in the scaling mode."); static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType"); - uint64_t general_key = - (static_cast(itype) << norm_key::kItypeShift) | - (static_cast(otype) << norm_key::kOtypeShift) | - (static_cast(ctype) << norm_key::kCtypeShift) | - (static_cast(wtype) << norm_key::kWtypeShift) | - (static_cast(NormType) << norm_key::kNormTypeShift) | - (static_cast(NormStage) << norm_key::kNormStageShift) | - (static_cast(NormBackend) << norm_key::kNormBackendShift) | - (static_cast(zero_centered_gamma) << norm_key::kZeroCenteredGammaShift) | - (static_cast(mode) << norm_key::kScalingModeShift) | - (static_cast(training) << norm_key::kTrainingShift) | - (static_cast(gamma_in_weight_dtype) << norm_key::kGammaInWeightDtypeShift); + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 5) | + (static_cast(ctype) << 10) | + (static_cast(wtype) << 15) | (uint64_t(NormType) << 20) | + (uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) | + (uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) | + (uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } @@ -615,4 +610,3 @@ void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; } #endif - diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index e9890b42f..bd06e2ea9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -38,27 +38,6 @@ namespace transformer_engine { namespace normalization { -// Bit layout for the packed normalization general_key. -// Keep these constants in sync with get_key() encoding in common.cpp. -namespace norm_key { - -constexpr int kItypeShift = 0; -constexpr int kOtypeShift = 5; -constexpr int kCtypeShift = 10; -constexpr int kWtypeShift = 15; -constexpr int kNormTypeShift = 20; -constexpr int kNormStageShift = 22; -constexpr int kNormBackendShift = 24; -constexpr int kZeroCenteredGammaShift = 26; -constexpr int kScalingModeShift = 27; -constexpr int kTrainingShift = 37; -constexpr int kGammaInWeightDtypeShift = 38; - -constexpr uint64_t kDTypeMask = 0x1F; -constexpr uint64_t kNormTypeMask = 0x3; - -} // namespace norm_key - #ifndef __HIP_PLATFORM_AMD__ namespace fe = cudnn_frontend; #endif @@ -224,26 +203,26 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, bool training = true, bool gamma_in_weight_dtype = false); -// Decode helpers assume the same general_key bit layout used by get_key(). +// These decode helpers assume the same general_key bit layout used by get_key() +// in common.cpp. If get_key() changes, update these shifts/masks accordingly. inline DType decode_itype(uint64_t general_key) { - return static_cast((general_key >> norm_key::kItypeShift) & norm_key::kDTypeMask); + return static_cast(general_key & 0x1F); } inline DType decode_otype(uint64_t general_key) { - return static_cast((general_key >> norm_key::kOtypeShift) & norm_key::kDTypeMask); + return static_cast((general_key >> 5) & 0x1F); } inline DType decode_ctype(uint64_t general_key) { - return static_cast((general_key >> norm_key::kCtypeShift) & norm_key::kDTypeMask); + return static_cast((general_key >> 10) & 0x1F); } inline DType decode_wtype(uint64_t general_key) { - return static_cast((general_key >> norm_key::kWtypeShift) & norm_key::kDTypeMask); + return static_cast((general_key >> 15) & 0x1F); } inline NVTE_Norm_Type decode_norm_type(uint64_t general_key) { - return static_cast( - (general_key >> norm_key::kNormTypeShift) & norm_key::kNormTypeMask); + return static_cast((general_key >> 20) & 0x3); } inline const char* dtype_to_string(DType dtype) { From e8afcc8996b78b8f5f70e6ff59144018ec35e8a4 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 18 May 2026 14:59:33 +0000 Subject: [PATCH 15/18] use more optimal BYTES_PER_LDG=16 for layer norm backward tuned config for H=7168 --- .../normalization/layernorm/ln_bwd_semi_cuda_kernel.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 8914ac18b..743a8f209 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -220,11 +220,11 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 8, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 8, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp32, fp16, fp32, 1, 1, 7, 8, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 8, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, fp32, bf16, fp32, 1, 1, 7, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp32, fp16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, fp32, bf16, fp32, 1, 1, 7, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); From 3d22d82a67530b5fd3834a59efc6b280d949e9f5 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 18 May 2026 15:35:54 +0000 Subject: [PATCH 16/18] use more optimal 7 warps config for H=7168, for fp16/fp32 layernorm fwd tuned --- .../layernorm/ln_fwd_cuda_kernel.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 79a45cd17..c628fd547 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -181,7 +181,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); @@ -209,7 +209,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); @@ -292,11 +292,11 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp16, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, bf16, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); @@ -480,7 +480,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e5m2, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e5m2, fp32, 2, 1, 4, 16); @@ -508,7 +508,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e5m2, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e5m2, fp32, 2, 1, 4, 16); From 63d0e264caad9429ffbf4e9dd8ec5ef03740aa8d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 18 May 2026 16:48:53 +0000 Subject: [PATCH 17/18] mirror layernorm bwd h=7168 uned config to rmsnorm --- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 40fdde753..fb3f2862f 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -183,9 +183,9 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); @@ -278,11 +278,11 @@ REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32 REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, true); -REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16, 4, true); -REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16, 4, true); -REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, From db7b01795d12579af86d5dee0f4e2a706c95f9ff Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 18 May 2026 19:23:39 +0000 Subject: [PATCH 18/18] Add normalization key layout round-trip check --- .../common/normalization/common.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index f7195a375..f189dd72c 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -68,6 +68,25 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } +namespace { + +[[maybe_unused]] const bool kNormKeyLayoutCheck = [] { + const uint64_t key = std::get<0>(get_key( + NVTE_Norm_Backend::Te, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, + DType::kFloat16, DType::kBFloat16, DType::kFloat8E4M3, DType::kFloat32, + 1, 1, false, false)); + + NVTE_CHECK(decode_itype(key) == DType::kBFloat16); + NVTE_CHECK(decode_otype(key) == DType::kFloat8E4M3); + NVTE_CHECK(decode_ctype(key) == DType::kFloat32); + NVTE_CHECK(decode_wtype(key) == DType::kFloat16); + NVTE_CHECK(decode_norm_type(key) == NVTE_Norm_Type::RMSNorm); + + return true; +}(); + +} // namespace + template TeNormalizationPlan::TeNormalizationPlan( NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype,