RMS Norm Optimization #583
… missing configs for layer norm
prop.multiProcessorCount, zero_centered_gamma, stream);
}

HIP_CHECK(hipStreamSynchronize(stream));
Is synchronization needed before warmup?
Good point. These are in fact redundant since the warmup already calls a device-wide sync anyway. Removed in 4256e3c
#include <typeindex>
#include <unordered_map>
#include <vector>
#include <unordered_set>
nit: move it after unordered_map
bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true, bool gamma_in_weight_dtype = false);

inline DType decode_itype(uint64_t general_key) {
This code is fragile because the encoding could change. At a minimum, add comments here and at the encoding block stating that the two must match.
Good point. I updated this in d548d54 to make the coupling between encoding/decoding explicit by introducing shared norm_key bit-layout constants and using them in both get_key() and the decode helpers. I also added comments documenting that the layouts must remain in sync, so future changes to the packed key format are less likely to silently diverge.
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
For BWD you have 7 warps set, but here you have 4. Is this optimal?
For BWD, 7 warps is indeed more performant than 4 warps across all DTypes tested. For FWD, 7 warps gives a performance boost to fp16/fp32, but regresses on bf16, so I kept 4 warps for bf16 for the h=7168 config. See 3d22d82
Did not see the same performance boost for RMSNorm forward however, so I left warps=4 there
(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
uint64_t general_key =
I get the motivation behind this change, but this affects upstream code. I feel like we're more likely to miss a key change from upstream if we have diverged here.
Good point — I think I overcorrected this by refactoring the key layout into shared named constants, which does increase divergence from upstream and could make future key-layout changes easier to miss during syncs.
I've reverted the get_key() refactor back to the original upstream-style encoding layout and instead added explicit comments at both the encoding and decode sites documenting that the bit layouts must stay in sync. See 0949b9a
Why not add a static assert at the decode so the code doesn't compile for us if upstream changes the encoding?
I'm not sure how to accomplish this without modifying the get_key definition itself (e.g. introducing a shared layout definition/helper used by both get_key and the decode helpers). Otherwise the decode functions only have their own hardcoded assumptions about the bit layout and cannot independently detect drift in get_key.
Right... We'd have to make the encoder constexpr and move it to a header. That would be ideal, but it ends up with more upstream changes anyway.
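For illustration only, here is a minimal sketch of what a constexpr encoder with a compile-time round-trip check might look like. Everything in the `sketch` namespace is hypothetical: the `DType` values, the 5-bit field width, the shift constants, and `encode_key` are stand-ins chosen for the example and do not reflect the actual `get_key()` layout.

```cpp
#include <cstdint>

namespace sketch {

// Hypothetical stand-in for the library's DType enum; values are illustrative.
enum class DType : uint64_t { kFloat32 = 0, kFloat16 = 1, kBFloat16 = 2, kFloat8E4M3 = 3 };

// Hypothetical bit layout shared by the encoder and every decoder.
constexpr uint64_t kFieldBits = 5;
constexpr uint64_t kFieldMask = (uint64_t{1} << kFieldBits) - 1;
constexpr uint64_t kItypeShift = 0;
constexpr uint64_t kOtypeShift = kItypeShift + kFieldBits;
constexpr uint64_t kCtypeShift = kOtypeShift + kFieldBits;
constexpr uint64_t kWtypeShift = kCtypeShift + kFieldBits;

// Simplified constexpr encoder: packs only the four dtype fields.
constexpr uint64_t encode_key(DType wtype, DType itype, DType otype, DType ctype) {
  return (static_cast<uint64_t>(itype) << kItypeShift) |
         (static_cast<uint64_t>(otype) << kOtypeShift) |
         (static_cast<uint64_t>(ctype) << kCtypeShift) |
         (static_cast<uint64_t>(wtype) << kWtypeShift);
}

// Decoders reuse the same shift constants, so they cannot drift from the encoder.
constexpr DType decode_field(uint64_t key, uint64_t shift) {
  return static_cast<DType>((key >> shift) & kFieldMask);
}
constexpr DType decode_itype(uint64_t key) { return decode_field(key, kItypeShift); }
constexpr DType decode_otype(uint64_t key) { return decode_field(key, kOtypeShift); }
constexpr DType decode_ctype(uint64_t key) { return decode_field(key, kCtypeShift); }
constexpr DType decode_wtype(uint64_t key) { return decode_field(key, kWtypeShift); }

// Compile-time round trip: the build fails if encoding and decoding disagree.
constexpr uint64_t kProbeKey =
    encode_key(DType::kFloat16, DType::kBFloat16, DType::kFloat8E4M3, DType::kFloat32);
static_assert(decode_wtype(kProbeKey) == DType::kFloat16, "wtype layout drifted");
static_assert(decode_itype(kProbeKey) == DType::kBFloat16, "itype layout drifted");
static_assert(decode_otype(kProbeKey) == DType::kFloat8E4M3, "otype layout drifted");
static_assert(decode_ctype(kProbeKey) == DType::kFloat32, "ctype layout drifted");

}  // namespace sketch
```

The trade-off raised in this thread still applies: a check like this only works if the encoder itself is constexpr and visible in a header, which means changing upstream code.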
Actually, could we add a runtime check instead of a static assert that verifies the round trip is valid? Maybe something like this:
namespace {
[[maybe_unused]] const bool kNormKeyLayoutCheck = [] {
  auto [key, b, h, t] = get_key(
      NVTE_Norm_Backend::Te, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward,
      DType::kFloat16, DType::kBFloat16, DType::kFloat8E4M3, DType::kFloat32, 1, 1, false, false);
  NVTE_CHECK(decode_itype(key) == DType::kBFloat16);
  NVTE_CHECK(decode_otype(key) == DType::kFloat8E4M3);
  NVTE_CHECK(decode_ctype(key) == DType::kFloat32);
  NVTE_CHECK(decode_wtype(key) == DType::kFloat16);
  NVTE_CHECK(decode_norm_type(key) == NVTE_Norm_Type::RMSNorm);
  return true;
}();
}  // namespace
This is helpful, thanks for the suggestion. I've gone ahead and implemented this in db7b017
alextmagro left a comment
LGTM! After merge, please work with Sudharshan to run the E2E configs again and get updated performance numbers
Description
Fixes # (16527)
RMSNorm falls back to the general kernel implementation on several DeepSeek and Qwen shapes, causing poor performance. These shapes have been registered with the tuned kernel cache, and a performance benchmark for RMSNorm has been added.
Additionally, a fallback warning is printed the first time a tuned config is not found for a requested kernel.
E2E TFLOP/s/GPU for proxy models (previous -> current with RMSNorm tuning):
Qwen:
bf16: 369.4 -> 374.7
fp8: 352.1 -> 358.2
DeepSeek:
bf16: 501.4 -> 529.4
fp8: 463.9 -> 511.4
Also added matching tuned configs for LayerNorm.