diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index 6071f9083..e8ad3526f 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -86,3 +86,4 @@ add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) +add_te_benchmark(bench_normalization normalization/bench_normalization.cpp) diff --git a/benchmarks/cpp/normalization/bench_normalization.cpp b/benchmarks/cpp/normalization/bench_normalization.cpp new file mode 100644 index 000000000..92ac3c946 --- /dev/null +++ b/benchmarks/cpp/normalization/bench_normalization.cpp @@ -0,0 +1,293 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include + +#include "benchmark_utils.h" + +#include "transformer_engine/normalization_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; + +#define NORM_SHAPES \ + ->Args({8192, 128}) \ + ->Args({8192, 1536}) \ + ->Args({8192, 7168}) + +constexpr float kNormEpsilon = 1e-5f; + +enum class BenchNormType { + LayerNorm, + RMSNorm, +}; + +template +constexpr DType dtype_of() { + if constexpr (std::is_same_v) { + return DType::kFloat32; + } else if constexpr (std::is_same_v) { + return DType::kBFloat16; + } else { + return DType::kFloat16; + } +} + +template +static void BM_NormForward(benchmark::State& state) { + const size_t N = state.range(0); + const size_t H = state.range(1); + constexpr bool zero_centered_gamma = false; + + const DType wtype = dtype_of(); + const DType itype = dtype_of(); + const DType otype = dtype_of(); + + test::Tensor input("input", std::vector{N, H}, itype); + test::Tensor output("output", std::vector{N, H}, otype); + test::Tensor gamma("gamma", std::vector{H}, wtype); + test::Tensor beta("beta", std::vector{H}, wtype); + test::Tensor mu("mu", std::vector{N}, DType::kFloat32); + test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); + test::Tensor workspace; + + test::fillUniform(&input); + test::fillUniform(&gamma); + test::fillUniform(&beta); + test::setRandomScale(&output); + + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, + output.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + workspace = test::Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, + output.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + warmup_gpu(); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, + output.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, + output.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0.0f; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipStreamDestroy(stream)); + + size_t bytes_read = N * H * sizeof(IType) + H * sizeof(WType); + + size_t bytes_write = N * H * sizeof(OType) + N * sizeof(float); + + if constexpr (Norm == BenchNormType::LayerNorm) { + bytes_read += H * sizeof(WType); // beta + bytes_write += N * sizeof(float); // mu + } + + set_bytes_processed(state, bytes_read + bytes_write); +} + +template +static void BM_NormBackward(benchmark::State& state) { + const size_t N = state.range(0); + const size_t H = state.range(1); + constexpr bool zero_centered_gamma = false; + + const DType wtype = dtype_of(); + const DType itype = dtype_of(); + const DType otype = dtype_of(); + + test::Tensor input("input", std::vector{N, H}, itype); + test::Tensor output("output", std::vector{N, H}, otype); + test::Tensor gamma("gamma", std::vector{H}, wtype); + test::Tensor beta("beta", std::vector{H}, wtype); + test::Tensor mu("mu", std::vector{N}, DType::kFloat32); + test::Tensor rsigma("rsigma", std::vector{N}, DType::kFloat32); + test::Tensor dz("dz", std::vector{N, H}, otype); + test::Tensor dx("dx", std::vector{N, H}, itype); + test::Tensor dgamma("dgamma", std::vector{H}, wtype); + test::Tensor dbeta("dbeta", std::vector{H}, wtype); + test::Tensor workspace_fwd; + test::Tensor workspace_bwd; + + test::fillUniform(&input); + test::fillUniform(&gamma); + test::fillUniform(&beta); + test::setRandomScale(&output); + test::fillUniform(&dz); + + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, + output.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, + output.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + workspace_fwd = test::Tensor("workspace_fwd", + workspace_fwd.rowwise_shape(), + workspace_fwd.dtype()); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), kNormEpsilon, + output.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + + nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), kNormEpsilon, + output.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + workspace_bwd = test::Tensor("workspace_bwd", + workspace_bwd.rowwise_shape(), + workspace_bwd.dtype()); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + warmup_gpu(); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (Norm == BenchNormType::LayerNorm) { + nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } else { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0.0f; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipStreamDestroy(stream)); + + size_t bytes_read = + N * H * sizeof(OType) + // dz + N * H * sizeof(IType) + // x + N * sizeof(float) + // rsigma + H * sizeof(WType); // gamma + + size_t bytes_write = + N * H * sizeof(IType) + // dx + H * sizeof(WType); // dgamma + + if constexpr (Norm == BenchNormType::LayerNorm) { + bytes_read += N * sizeof(float); // mu + bytes_write += H * sizeof(WType); // dbeta + } + + set_bytes_processed(state, bytes_read + bytes_write); +} + +#define REGISTER_NORM_BENCH(NORM_ENUM, NORM_NAME, WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + BENCHMARK_TEMPLATE(BM_NormForward, NORM_ENUM, WTYPE, ITYPE, OTYPE, CTYPE) \ + ->Name("BM_" NORM_NAME "Forward/" NAME) \ + NORM_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_NormBackward, NORM_ENUM, WTYPE, ITYPE, OTYPE, CTYPE) \ + ->Name("BM_" NORM_NAME "Backward/" NAME) \ + NORM_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_RMSNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + REGISTER_NORM_BENCH(BenchNormType::RMSNorm, "RMSNorm", WTYPE, ITYPE, OTYPE, CTYPE, NAME) + +#define REGISTER_LAYERNORM(WTYPE, ITYPE, OTYPE, CTYPE, NAME) \ + REGISTER_NORM_BENCH(BenchNormType::LayerNorm, "LayerNorm", WTYPE, ITYPE, OTYPE, CTYPE, NAME) + +REGISTER_RMSNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32") +REGISTER_RMSNORM(half, half, half, float, "FP16_FP16_FP16_FP32") +REGISTER_RMSNORM(float, float, float, float, "FP32_FP32_FP32_FP32") + +REGISTER_LAYERNORM(hip_bfloat16, hip_bfloat16, hip_bfloat16, float, "BF16_BF16_BF16_FP32") +REGISTER_LAYERNORM(half, half, half, float, "FP16_FP16_FP16_FP32") +REGISTER_LAYERNORM(float, float, float, float, "FP32_FP32_FP32_FP32") + +BENCHMARK_MAIN(); diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index d6aa55b37..f189dd72c 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -50,6 +50,7 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { } #endif //#ifndef __HIP_PLATFORM_AMD__ +// Keep this bit layout in sync with the decode helpers in common.h. TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, @@ -67,6 +68,25 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } +namespace { + +[[maybe_unused]] const bool kNormKeyLayoutCheck = [] { + const uint64_t key = std::get<0>(get_key( + NVTE_Norm_Backend::Te, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, + DType::kFloat16, DType::kBFloat16, DType::kFloat8E4M3, DType::kFloat32, + 1, 1, false, false)); + + NVTE_CHECK(decode_itype(key) == DType::kBFloat16); + NVTE_CHECK(decode_otype(key) == DType::kFloat8E4M3); + NVTE_CHECK(decode_ctype(key) == DType::kFloat32); + NVTE_CHECK(decode_wtype(key) == DType::kFloat16); + NVTE_CHECK(decode_norm_type(key) == NVTE_Norm_Type::RMSNorm); + + return true; +}(); + +} // namespace + template TeNormalizationPlan::TeNormalizationPlan( NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, @@ -609,4 +629,3 @@ void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; } #endif - diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 70584fac3..bd06e2ea9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "../common.h" @@ -202,6 +203,60 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, bool training = true, bool gamma_in_weight_dtype = false); +// These decode helpers assume the same general_key bit layout used by get_key() +// in common.cpp. If get_key() changes, update these shifts/masks accordingly. +inline DType decode_itype(uint64_t general_key) { + return static_cast(general_key & 0x1F); +} + +inline DType decode_otype(uint64_t general_key) { + return static_cast((general_key >> 5) & 0x1F); +} + +inline DType decode_ctype(uint64_t general_key) { + return static_cast((general_key >> 10) & 0x1F); +} + +inline DType decode_wtype(uint64_t general_key) { + return static_cast((general_key >> 15) & 0x1F); +} + +inline NVTE_Norm_Type decode_norm_type(uint64_t general_key) { + return static_cast((general_key >> 20) & 0x3); +} + +inline const char* dtype_to_string(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return "fp32"; + case DType::kFloat16: + return "fp16"; + case DType::kBFloat16: + return "bf16"; + case DType::kFloat8E4M3: + return "fp8e4m3"; + case DType::kFloat8E5M2: + return "fp8e5m2"; + case DType::kByte: + return "byte"; + case DType::kInt32: + return "int32"; + default: + return "unknown"; + } +} + +inline const char* norm_type_to_string(NVTE_Norm_Type norm_type) { + switch (norm_type) { + case NVTE_Norm_Type::LayerNorm: + return "LayerNorm"; + case NVTE_Norm_Type::RMSNorm: + return "RMSNorm"; + default: + return "unknown"; + } +} + template class TeNormalizationRegistry { private: @@ -226,14 +281,31 @@ class TeNormalizationRegistry { getInstance().general_function_map[general_key].emplace(hidden_size, Function(func)); return 0; } - + static Function getKernel(TupleKeyType key) { auto& instance = getInstance(); auto [general_key, batch_size, hidden_size, is_tuned] = key; if (is_tuned) { auto it = instance.tuned_function_map.find(key); if (it != instance.tuned_function_map.end()) return it->second; - } + + static thread_local std::unordered_set warned_keys; + if (warned_keys.insert(key).second) { + NVTE_WARN("Falling back to general normalization kernel because no tuned kernel " + "is available for this config. norm_type=", + norm_type_to_string(decode_norm_type(general_key)), + ", hidden_size=", + hidden_size, + ", wtype=", + dtype_to_string(decode_wtype(general_key)), + ", itype=", + dtype_to_string(decode_itype(general_key)), + ", otype=", + dtype_to_string(decode_otype(general_key)), + ", ctype=", + dtype_to_string(decode_ctype(general_key))); + } + } if (instance.general_function_map.count(general_key) == 0) { NVTE_ERROR("Unavailable kernel for this normalization config."); } diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 8f0a8a14b..743a8f209 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -154,6 +154,12 @@ void launch_ln_bwd_general_(LaunchParams &launch_params, // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, fp16, fp32, fp16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 128, bf16, fp32, bf16, fp32, 1, 4, 1, 4, 4); + REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); @@ -214,6 +220,12 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp16, fp32, fp16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, bf16, fp32, bf16, fp32, 1, 1, 7, 16, 4); + REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index f76fcf582..c628fd547 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -142,6 +142,7 @@ void launch_ln_fwd_general_(LaunchParams &launch_params, // Create tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); @@ -152,6 +153,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); @@ -168,6 +170,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); @@ -178,6 +181,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); @@ -194,6 +198,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); @@ -204,6 +209,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e4m3, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); @@ -220,6 +226,12 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, bf16, fp32, 1, 4, 1, 4); + REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); @@ -280,6 +292,12 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp16, fp32, 1, 1, 7, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, bf16, fp32, 1, 1, 7, 16); + REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); @@ -423,7 +441,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32 #ifdef __HIP_PLATFORM_AMD__ // ROCM uses TE normalization for e5m2 - +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); @@ -434,6 +452,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e5m2, fp32, 2, 1, 4, 16); @@ -450,6 +469,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e5m2, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e5m2, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e5m2, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); @@ -460,6 +480,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp16, fp16, fp8e5m2, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e5m2, fp32, 2, 1, 4, 16); @@ -476,6 +497,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e5m2, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e5m2, fp32, 4, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e5m2, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 128, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); @@ -486,6 +508,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e5m2, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp8e5m2, fp32, 1, 1, 7, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e5m2, fp32, 2, 1, 4, 16); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index da940254c..fb3f2862f 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -155,6 +155,10 @@ void launch_rmsnorm_bwd_general_(LaunchParams &launch_para // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4); + REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); @@ -167,6 +171,10 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); @@ -175,6 +183,10 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 16, 4); + REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); @@ -217,6 +229,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4, 4, + true); + REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, @@ -238,6 +257,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32 REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, @@ -252,6 +278,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32 REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 7, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 7, 16, 4, + true); + REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true); REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 3a28ebf13..3218752f8 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -143,6 +143,13 @@ void launch_rmsnorm_fwd_general_(LaunchParams &launch_param // Create rmsnorm tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp32, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp16, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, bf16, fp32, 1, 4, 1, 4); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); @@ -164,6 +171,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); @@ -178,6 +192,13 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); @@ -238,6 +259,9 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, #ifdef __HIP_PLATFORM_AMD__ // ROCM uses TE normalization for e5m2 +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 128, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); @@ -251,6 +275,10 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, fp8e5m2, fp32, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1536, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e5m2, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e5m2, fp32, 1, 4, 1, 16); @@ -259,6 +287,10 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, fp8e5m2, fp32, REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 7168, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16); + REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e5m2, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e5m2, fp32, 1, 1, 4, 16);