RMS Norm Optimization #583

Merged
aris134 merged 19 commits into dev from amartin/rmsnorm
May 18, 2026

@aris134 aris134 commented May 12, 2026

Description

Fixes # (16527)

RMSNorm falls back to the general kernel implementation on several DeepSeek and Qwen shapes, causing poor performance. These shapes have been registered with the tuned kernel cache, and a performance benchmark for RMSNorm has been added.

Additionally, a fallback warning is printed the first time a tuned config is not found for a requested kernel. For example:

in function getKernel: Falling back to general normalization kernel because no tuned kernel is available for this config. hidden_size=128, wtype=bf16, itype=bf16, otype=bf16, ctype=fp32
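A minimal sketch of how such a warn-once fallback could be structured. All names here (`KernelRegistry`, `get_kernel`, `general_kernel`, the integer key) are hypothetical stand-ins rather than the library's actual API, and the sketch assumes the warning is deduplicated per config key, which the description above does not spell out.

```cpp
#include <cstdint>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical stand-ins for kernel launchers; not the library's real types.
using KernelFn = void (*)();
inline void general_kernel() {}
inline void tuned_kernel_h7168() {}

class KernelRegistry {
 public:
  void register_tuned(uint64_t key, KernelFn fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    tuned_[key] = fn;
  }

  // Returns the tuned kernel registered for `key`. If none exists, returns
  // the general kernel and prints a warning the first time `key` is seen.
  KernelFn get_kernel(uint64_t key, const std::string& config_desc) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = tuned_.find(key);
    if (it != tuned_.end()) return it->second;
    // insert().second is true only on the first insertion of this key.
    if (warned_.insert(key).second) {
      ++warnings_emitted_;
      std::cerr << "Falling back to general normalization kernel because no "
                   "tuned kernel is available for this config. "
                << config_desc << "\n";
    }
    return &general_kernel;
  }

  int warnings_emitted() const { return warnings_emitted_; }

 private:
  std::mutex mutex_;
  std::unordered_map<uint64_t, KernelFn> tuned_;
  std::unordered_set<uint64_t> warned_;
  int warnings_emitted_ = 0;
};
```

Keeping the set of already-warned keys next to the registry means repeated lookups of the same untuned shape stay quiet, while a different untuned shape still produces its own warning.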

E2E TFLOPS/s/GPU for proxy models (Previous -> Current with RMSNorm tuning):

Qwen:
bf16: 369.4 -> 374.7
fp8: 352.1 -> 358.2

DeepSeek:
bf16: 501.4 -> 529.4
fp8: 463.9 -> 511.4

Also added matching tuned configs for LayerNorm.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@aris134 aris134 requested a review from alextmagro May 12, 2026 12:13
@aris134 aris134 self-assigned this May 12, 2026
@aris134 aris134 marked this pull request as ready for review May 12, 2026 19:15
prop.multiProcessorCount, zero_centered_gamma, stream);
}

HIP_CHECK(hipStreamSynchronize(stream));
Collaborator

Is synchronization needed before warmup?

Contributor Author

Good point. These are in fact redundant, since the warmup already performs a device-wide sync. Removed in 4256e3c.

#include <typeindex>
#include <unordered_map>
#include <vector>
#include <unordered_set>
Collaborator

nit: move it after unordered_map

Contributor Author

Done in 2f9ff47

bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true, bool gamma_in_weight_dtype = false);

inline DType decode_itype(uint64_t general_key) {
Collaborator

This code is fragile because the encoding could change. At least put comments here and at the encoding block stating that they must match.

Contributor Author

Good point. I updated this in d548d54 to make the coupling between encoding/decoding explicit by introducing shared norm_key bit-layout constants and using them in both get_key() and the decode helpers. I also added comments documenting that the layouts must remain in sync, so future changes to the packed key format are less likely to silently diverge.
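The idea of shared bit-layout constants can be sketched as follows. This is an illustration under assumptions, not the code from d548d54: the `norm_key_layout` namespace, the 5-bit field width, the field order, and the reduced `DType` enum are all hypothetical, and the real key packs more fields.

```cpp
#include <cstdint>

// Hypothetical, reduced dtype enum for illustration only.
enum class DType : uint64_t {
  kFloat32 = 0,
  kFloat16 = 1,
  kBFloat16 = 2,
  kFloat8E4M3 = 3,
};

namespace norm_key_layout {
// Single source of truth: the encoder and every decoder use these constants,
// so changing a width or offset here changes both sides together.
constexpr unsigned kTypeBits = 5;  // assumed field width
constexpr uint64_t kTypeMask = (uint64_t{1} << kTypeBits) - 1;
constexpr unsigned kWTypeShift = 0;
constexpr unsigned kITypeShift = kWTypeShift + kTypeBits;
constexpr unsigned kOTypeShift = kITypeShift + kTypeBits;
constexpr unsigned kCTypeShift = kOTypeShift + kTypeBits;
}  // namespace norm_key_layout

// Packs the four dtypes into one integer key.
inline uint64_t encode_types(DType w, DType i, DType o, DType c) {
  using namespace norm_key_layout;
  return ((static_cast<uint64_t>(w) & kTypeMask) << kWTypeShift) |
         ((static_cast<uint64_t>(i) & kTypeMask) << kITypeShift) |
         ((static_cast<uint64_t>(o) & kTypeMask) << kOTypeShift) |
         ((static_cast<uint64_t>(c) & kTypeMask) << kCTypeShift);
}

// Extracts one dtype field located at `shift`.
inline DType decode_field(uint64_t key, unsigned shift) {
  return static_cast<DType>((key >> shift) & norm_key_layout::kTypeMask);
}

inline DType decode_wtype(uint64_t key) {
  return decode_field(key, norm_key_layout::kWTypeShift);
}
inline DType decode_itype(uint64_t key) {
  return decode_field(key, norm_key_layout::kITypeShift);
}
inline DType decode_otype(uint64_t key) {
  return decode_field(key, norm_key_layout::kOTypeShift);
}
inline DType decode_ctype(uint64_t key) {
  return decode_field(key, norm_key_layout::kCTypeShift);
}
```

The trade-off discussed later in this thread applies: the encoder has to be edited to use the shared constants, which increases divergence from upstream.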

@aris134 aris134 requested a review from ipanfilo May 15, 2026 16:40
Comment thread benchmarks/cpp/normalization/bench_normalization.cpp Outdated
Comment thread benchmarks/cpp/normalization/bench_normalization.cpp Outdated
Comment thread transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu Outdated
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
Contributor

For BWD you have 7 warps set, but here you have 4. Is this optimal?

Contributor Author

For BWD, 7 warps is indeed more performant than 4 warps across all DTypes tested. For FWD, 7 warps gives a performance boost to fp16/fp32, but regresses on bf16, so I kept 4 warps for bf16 for the h=7168 config. See 3d22d82

Contributor Author

I did not see the same performance boost for RMSNorm forward, however, so I left warps=4 there.

Comment thread transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu Outdated
(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
uint64_t general_key =
Contributor

I get the motivation behind this change, but this affects upstream code. I feel like we're more likely to miss a key change from upstream if we have diverged here.

Contributor Author

Good point. I think I overcorrected by refactoring the key layout into shared named constants, which does increase divergence from upstream and could make future key-layout changes easier to miss during syncs.

I've reverted the get_key() refactor back to the original upstream-style encoding layout and instead added explicit comments at both the encoding and decode sites documenting that the bit layouts must stay in sync. See 0949b9a

Contributor

Why not add a static assert at the decode so the code doesn't compile for us if upstream changes the encoding?

Contributor Author

I'm not sure how to accomplish this without modifying the get_key definition itself (e.g. introducing a shared layout definition/helper used by both get_key and the decode helpers). Otherwise the decode functions only have their own hardcoded assumptions about the bit layout and cannot independently detect drift in get_key.

Contributor

Right... We'd have to make the encoder constexpr and move it to a header. That would be ideal, but it ends up with more upstream changes anyway.
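For reference, a sketch of what the compile-time variant would look like if the encoder were constexpr and visible in a header. The two-field `make_key` below is hypothetical and far smaller than the real `get_key`; the bit offsets and the enums are illustrative assumptions.

```cpp
#include <cstdint>

// Hypothetical enums and a reduced key; not the library's real definitions.
enum class NormType : uint64_t { LayerNorm = 0, RMSNorm = 1 };
enum class NormStage : uint64_t { Forward = 0, Backward = 1 };

// Encoder: must be constexpr so it can be evaluated inside static_assert.
constexpr uint64_t make_key(NormType type, NormStage stage) {
  return (static_cast<uint64_t>(type) << 20) |
         (static_cast<uint64_t>(stage) << 22);
}

// Decoders carry their own hardcoded view of the layout.
constexpr NormType decode_norm_type(uint64_t key) {
  return static_cast<NormType>((key >> 20) & 0x3);
}
constexpr NormStage decode_norm_stage(uint64_t key) {
  return static_cast<NormStage>((key >> 22) & 0x3);
}

// Round-trip checks evaluated at compile time: if the encoder's layout
// changes without the decoders being updated, the build fails here.
static_assert(decode_norm_type(make_key(NormType::RMSNorm,
                                        NormStage::Backward)) ==
                  NormType::RMSNorm,
              "decode_norm_type is out of sync with make_key");
static_assert(decode_norm_stage(make_key(NormType::RMSNorm,
                                         NormStage::Backward)) ==
                  NormStage::Backward,
              "decode_norm_stage is out of sync with make_key");
```

Non-zero enum values are used in the assertions on purpose: a field that decodes from the wrong bits of an all-zero key would still compare equal to a zero-valued enumerator and hide the drift.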

Contributor

Actually, could we add a runtime check instead of a static assert that verifies the round trip is valid? Maybe something like this:

  namespace {
  [[maybe_unused]] const bool kNormKeyLayoutCheck = [] {
      auto [key, b, h, t] = get_key(
          NVTE_Norm_Backend::Te, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward,
          DType::kFloat16, DType::kBFloat16, DType::kFloat8E4M3, DType::kFloat32, 1, 1, false, false);
      NVTE_CHECK(decode_itype(key)     == DType::kBFloat16);
      NVTE_CHECK(decode_otype(key)     == DType::kFloat8E4M3);
      NVTE_CHECK(decode_ctype(key)     == DType::kFloat32);
      NVTE_CHECK(decode_wtype(key)     == DType::kFloat16);
      NVTE_CHECK(decode_norm_type(key) == NVTE_Norm_Type::RMSNorm);
      return true;
  }();
  }

Contributor Author

This is helpful, thanks for the suggestion. I've gone ahead and implemented this in db7b017

@aris134 aris134 requested a review from alextmagro May 18, 2026 16:57
Contributor

@alextmagro alextmagro left a comment

LGTM! After merge, please work with Sudharshan to run the E2E configs again and get updated performance numbers

@aris134 aris134 merged commit 5cb098b into dev May 18, 2026
2 checks passed
3 participants