From 47e55f9bdee7dc75128a45347e2dd689aa8b000b Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 26 Jan 2026 16:29:16 +0100 Subject: [PATCH 01/21] add fused te fused layernorm --- .../core/extensions/transformer_engine.py | 11 +++++++++-- .../core/transformer/transformer_layer.py | 19 ++++++++++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ef8527e9e5e..5a16423694b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -461,8 +461,16 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) zero_centered_gamma=config.layernorm_zero_centered_gamma, **_get_extra_te_kwargs(config), ) + elif config.normalization == "ResidualRMSNorm": + extra_te_kwargs = _get_extra_te_kwargs(config) + extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] + del extra_te_kwargs["params_dtype"] + instance = te.pytorch.ops.Sequential( + te.pytorch.ops.MakeExtraOutput(), + te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), + ) else: - raise Exception("Only LayerNorm and RMSNorm are curently supported") + raise Exception("Only LayerNorm, RMSNorm and ResidualRMSNorm are curently supported") return instance @@ -2187,7 +2195,6 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Option else: TEFusedMLP = None # type: ignore[assignment, misc] - class TEDelayedScaling(te.common.recipe.DelayedScaling): """ Wrapper for the Transformer-Engine's `DelayedScaling` layer. diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 12c24868473..d42922a3f4a 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -569,8 +569,6 @@ def _forward_attention( inference_context = deprecate_inference_params(inference_context, inference_params) - # Residual connection. - residual = hidden_states # Optional Input Layer norm if self.recompute_input_layernorm: @@ -583,6 +581,13 @@ def _forward_attention( with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = self.input_layernorm(hidden_states) + if isinstance(input_layernorm_output, tuple): + if len(input_layernorm_output) != 2: + raise ValueError(f"When the output of input_layernorm is a tuple, it is expected to have 2 elements (output, residual), but got {len(input_layernorm_output)}") + input_layernorm_output, residual = input_layernorm_output + else: + residual = hidden_states + using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication ) @@ -637,12 +642,16 @@ def _forward_attention( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) - # Residual connection. - residual = hidden_states - # Optional Layer norm after self-attention pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) + if isinstance(pre_cross_attn_layernorm_output, tuple): + if len(pre_cross_attn_layernorm_output) != 2: + raise ValueError(f"When the output of pre_cross_attn_layernorm_output is a tuple, it is expected to have 2 elements (output, residual), but got {len(pre_cross_attn_layernorm_output)}") + pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output + else: + residual = hidden_states + # Cross attention. attention_output_with_bias = self.cross_attention( pre_cross_attn_layernorm_output, From 4e8602604dcc99cfa8982744ae0da3af83aa1b8b Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 02/21] revert changes to normalization parameter, add fusion flag instead --- .../core/extensions/transformer_engine.py | 30 ++++++++++--------- .../core/transformer/transformer_config.py | 9 +++++- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 5a16423694b..970939b6a05 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -454,21 +454,23 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) - elif config.normalization == "ResidualRMSNorm": + extra_te_kwargs = _get_extra_te_kwargs(config) - extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] - del extra_te_kwargs["params_dtype"] - instance = te.pytorch.ops.Sequential( - te.pytorch.ops.MakeExtraOutput(), - te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), - ) + if config.fused_residual_rmsnorm: + extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] + del extra_te_kwargs["params_dtype"] + instance = te.pytorch.ops.Sequential( + te.pytorch.ops.MakeExtraOutput(), + te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), + ) + else: + instance = te.pytorch.RMSNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **extra_te_kwargs, + ) else: raise Exception("Only LayerNorm, RMSNorm and ResidualRMSNorm are curently supported") diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index dce438520aa..c7456c178bb 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -382,6 +382,9 @@ class TransformerConfig(ModelParallelConfig): fused_single_qkv_rope: bool = False """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads.""" + fused_residual_rmsnorm: bool = False + """If True, uses fuses residual connection and RMSNorm when TE is used.""" + #################### # activation recomputation #################### @@ -1540,7 +1543,11 @@ def __post_init__(self): "If you use bias in MLP FC1, we recommend setting bias_activation_fusion " "to True and use_te_activation_func to False." ) - + + if self.fused_residual_rmsnorm: + if self.normalization != "RMSNorm": + raise ValueError("fused_residual_rmsnorm is only supported when normalization is RMSNorm.") + if self.use_te_activation_func: if self.activation_func not in (F.gelu, F.silu, F.relu): raise ValueError( From 1146357c4f8a13d04cf993a12ae3788d7c8a84f9 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 03/21] Refactor TEFusedResidualRMSNorm properly wrapping it for compatibility with mcore --- .../core/extensions/transformer_engine.py | 198 +++++++++++++++++- 1 file changed, 191 insertions(+), 7 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 970939b6a05..ecc2edfe685 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -430,6 +430,183 @@ def __new__(cls, config: TransformerConfig): TEActivationOp = None +if HAVE_TE and is_te_min_version("1.13.0"): + + class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): + """ + RMSNorm with fused residual output for Megatron Core. + + Inherits from te.pytorch.RMSNorm to maintain all parameter management, + checkpoint compatibility, and Megatron-specific features. Creates a fused + implementation using TE's ops API that shares the base class parameters. + + The fused implementation uses: + - MakeExtraOutput: Forks the residual connection + - RMSNorm: Normalizes the main path + + Forward pass returns: (normalized_output, residual) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Fused implementation (stored in tuple to avoid submodule registration) + self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None + + def _make_fused_impl(self) -> te.pytorch.ops.Sequential: + """ + Construct fused ops pipeline that shares parameters with base RMSNorm. + + Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares + the weight parameter with self.weight from the base class. + """ + + fused_impl = te.pytorch.ops.Sequential() + + # Op 1: MakeExtraOutput - forks the residual + fused_impl.append(te.pytorch.ops.MakeExtraOutput()) + + # Op 2: RMSNorm - shares weight parameter with self + kwargs = { + "eps": self.eps, + "device": "meta", # Already initialized + "dtype": self.weight.dtype, + "zero_centered_gamma": self.zero_centered_gamma, + } + + # Add sm_margin if available (TE 2.5+) + if hasattr(self, '_sm_margins'): + kwargs["sm_margin"] = self._sm_margins + + rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) + + # CRITICAL: Share the weight parameter with base class + # This ensures checkpointing works through the base class + rmsnorm_op.weight = self.weight + + fused_impl.append(rmsnorm_op) + + # Transfer hooks from base module to fused implementation + # This is CRITICAL for DDP to work correctly + self._register_hooks_on_fused_impl(fused_impl) + + return fused_impl + + def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None: + """ + Transfer hooks from base RMSNorm to fused implementation. + + This is critical for distributed training - DDP registers hooks on the + base module that must be executed. Follows TEFusedMLP pattern. + + Note: Transformer Engine's op fuser does not expose intermediate tensors, + so hooks that modify tensors will not work correctly. + """ + + # Collect hooks from all submodules (including self) + forward_pre_hooks = [] + forward_post_hooks = [] + backward_pre_hooks = [] + backward_post_hooks = [] + + for submodule in self.modules(): + for hook in submodule._forward_pre_hooks.values(): + forward_pre_hooks.append((submodule, hook)) + for hook in submodule._forward_hooks.values(): + forward_post_hooks.append((submodule, hook)) + for hook in submodule._backward_pre_hooks.values(): + backward_pre_hooks.append((submodule, hook)) + for hook in submodule._backward_hooks.values(): + backward_post_hooks.append((submodule, hook)) + + # Pre-forward hooks + # Note: DDP pre-forward hooks are safe since they do not + # interact with input tensor. + if forward_pre_hooks: + from megatron.core.distributed import distributed_data_parallel + + if any( + inspect.getmodule(hook) != distributed_data_parallel + for _, hook in forward_pre_hooks + ): + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input tensor." + ) + + def forward_pre_hook(module, *_) -> None: + for submodule, hook in forward_pre_hooks: + # Assume that hook does not interact with input + ret = hook(submodule, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " + "submodule has pre-forward hook that modifies input tensor." + ) + + fused_impl.register_forward_pre_hook(forward_pre_hook) + + # Post-forward hooks + if forward_post_hooks: + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input or output tensors." + ) + + def forward_post_hook(module, *_) -> None: + for submodule, hook in forward_post_hooks: + # Assume that hook does not interact with input or output + ret = hook(submodule, None, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " + "submodule has post-forward hook that modifies output tensor." + ) + + fused_impl.register_forward_hook(forward_post_hook) + + # Backward hooks + if backward_pre_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support submodules with pre-backward hooks" + ) + if backward_post_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support submodules with post-backward hooks" + ) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass with fused residual output. + + Args: + hidden_states: Input tensor [s, b, h] + + Returns: + Tuple of (normalized_output, residual), both [s, b, h] + + Note: + Sequential.forward() automatically returns (output, extra_outputs...) + when MakeExtraOutput is present, so we don't need manual unpacking. + """ + + # Construct fused impl lazily on first forward + # (in case parameters are modified after __init__) + if self._fused_impl is None: + self._fused_impl = (self._make_fused_impl(),) + + # Apply fused implementation + # Sequential returns (normalized_output, residual) automatically + return self._fused_impl[0](hidden_states) + +else: + TEFusedResidualRMSNorm = None # type: ignore[assignment, misc] + + class TENorm: """A conditional wrapper to initialize an instance of Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.""" @@ -454,18 +631,25 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - + extra_te_kwargs = _get_extra_te_kwargs(config) + if config.fused_residual_rmsnorm: - extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] - del extra_te_kwargs["params_dtype"] - instance = te.pytorch.ops.Sequential( - te.pytorch.ops.MakeExtraOutput(), - te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), + # Use fused residual variant + assert TEFusedResidualRMSNorm is not None, ( + "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" + ) + instance = TEFusedResidualRMSNorm( + normalized_shape=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **extra_te_kwargs, ) else: + # Standard RMSNorm without fusion instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, + normalized_shape=hidden_size, eps=eps, sequence_parallel=config.sequence_parallel, zero_centered_gamma=config.layernorm_zero_centered_gamma, From 321eca650b48fe94ba0f232cf65a534fa5051505 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 04/21] add more spots where tuple outputs break mcore --- megatron/core/extensions/transformer_engine.py | 2 +- megatron/core/transformer/multi_token_prediction.py | 6 ++++++ megatron/core/transformer/transformer_block.py | 4 ++++ megatron/core/transformer/transformer_layer.py | 11 +++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ecc2edfe685..5f23192a63f 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -656,7 +656,7 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) **extra_te_kwargs, ) else: - raise Exception("Only LayerNorm, RMSNorm and ResidualRMSNorm are curently supported") + raise Exception("Only LayerNorm and RMSNorm are curently supported") return instance diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index b0476155ad9..d65e216ce52 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -728,8 +728,12 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T Concatenate the tokens before sending to transformer layer. """ decoder_input = self.enorm(decoder_input) + if isinstance(decoder_input, tuple): + decoder_input = decoder_input[0] decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) hidden_states = self.hnorm(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states # and the (i + K)-th token's embedding, and combine them with linear projection. @@ -813,6 +817,8 @@ def _postprocess(self, hidden_states: torch.Tensor): # Layer norm before shared head layer. hidden_states = self.final_layernorm(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index f222a2c3a6b..34904a3608c 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -787,6 +787,10 @@ def forward( # Final layer norm. if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) + # Handle fused residual normalization (returns tuple of (output, residual)) + # For final layernorm, we only need the normalized output, not the residual + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index d42922a3f4a..a5471e6f430 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -711,6 +711,15 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + # Handle fused residual normalization (returns tuple of (output, residual)) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " + f"2 elements (output, residual), but got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size should_chunk_mlp_for_prefill = ( @@ -1120,6 +1129,8 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if not self.is_moe_layer: return residual, None, None, None hidden_states = self.pre_mlp_layernorm(residual) + if isinstance(hidden_states, tuple): + hidden_states, residual = hidden_states shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) From ca374c64a053724bdce8102cc29f8f24c429c6bf Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 05/21] remove excessive comments --- megatron/core/extensions/transformer_engine.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 5f23192a63f..ef9f6d0534f 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -480,30 +480,16 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) - # CRITICAL: Share the weight parameter with base class - # This ensures checkpointing works through the base class rmsnorm_op.weight = self.weight fused_impl.append(rmsnorm_op) - # Transfer hooks from base module to fused implementation - # This is CRITICAL for DDP to work correctly self._register_hooks_on_fused_impl(fused_impl) return fused_impl def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None: - """ - Transfer hooks from base RMSNorm to fused implementation. - - This is critical for distributed training - DDP registers hooks on the - base module that must be executed. Follows TEFusedMLP pattern. - - Note: Transformer Engine's op fuser does not expose intermediate tensors, - so hooks that modify tensors will not work correctly. - """ - # Collect hooks from all submodules (including self) forward_pre_hooks = [] forward_post_hooks = [] backward_pre_hooks = [] From 0c567e6e2cd5a5ad4f0402333d2d9e9579970848 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 06/21] add quantization --- megatron/core/extensions/transformer_engine.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ef9f6d0534f..ab731f66837 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -447,8 +447,9 @@ class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): Forward pass returns: (normalized_output, residual) """ - def __init__(self, *args, **kwargs): + def __init__(self, quantize: bool, *args, **kwargs): super().__init__(*args, **kwargs) + self.quantize = quantize # Fused implementation (stored in tuple to avoid submodule registration) self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None @@ -477,6 +478,9 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: # Add sm_margin if available (TE 2.5+) if hasattr(self, '_sm_margins'): kwargs["sm_margin"] = self._sm_margins + + if self.quantize: + fused_impl.append(te.ops.Quantize(forward=False, backward=True)) rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) @@ -484,6 +488,9 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: fused_impl.append(rmsnorm_op) + if self.quantize: + fused_impl.append(te.ops.Quantize(forward=True, backward=False)) + self._register_hooks_on_fused_impl(fused_impl) return fused_impl @@ -630,6 +637,7 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) eps=eps, sequence_parallel=config.sequence_parallel, zero_centered_gamma=config.layernorm_zero_centered_gamma, + quantize=config.fp8 or config.fp4, **extra_te_kwargs, ) else: From b982040808eedf14456abf2ea22c9f57eda149f0 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 07/21] add rmsnorm residual fusion test --- .../fusions/test_rmsnorm_residual_fusion.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py new file mode 100644 index 00000000000..cd6a2dba4f4 --- /dev/null +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from megatron.core.extensions.transformer_engine import TEFusedResidualRMSNorm +from transformer_engine.pytorch import RMSNorm + +def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): + return x, rmsnorm(x) + +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("normalized_shape", [256, 256*2, 256*4]) +def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): + x_baseline = torch.randn(16, 32, 1024, dtype=input_dtype, device="cuda") + x_baseline.requires_grad = True + x_fused = x_baseline.detach() + x_fused.requires_grad = True + baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape).cuda() + fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, quantize=False).cuda() + + # baseline + baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) + baseline_loss = baseline_y.sum() + baseline_residual.sum() + baseline_loss.backward() + + # fused + fused_y, fused_residual = fused_rmsnorm(x_fused) + fused_loss = fused_y.sum() + fused_residual.sum() + fused_loss.backward() + + # Use tolerances appropriate for dtype (pattern from other tests) + tols = dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + + assert fused_y.dtype == baseline_y.dtype + assert torch.allclose(fused_y, baseline_y, **tols) + assert fused_residual.dtype == baseline_residual.dtype + assert torch.allclose(fused_residual, baseline_residual, **tols) + assert x_fused.grad.dtype == x_baseline.grad.dtype + assert torch.allclose(x_baseline.grad, x_fused.grad, **tols) From 68f5abeed188ecaf0720d4417388b819f4052013 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 08/21] fix tests --- tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index cd6a2dba4f4..59ad8212eb8 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -5,17 +5,17 @@ from transformer_engine.pytorch import RMSNorm def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): - return x, rmsnorm(x) + return rmsnorm(x), x @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("normalized_shape", [256, 256*2, 256*4]) def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): - x_baseline = torch.randn(16, 32, 1024, dtype=input_dtype, device="cuda") + x_baseline = torch.randn(16, 32, normalized_shape, dtype=input_dtype, device="cuda") x_baseline.requires_grad = True x_fused = x_baseline.detach() x_fused.requires_grad = True - baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape).cuda() - fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, quantize=False).cuda() + baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() + fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, dtype=input_dtype, quantize=False).cuda() # baseline baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) From 26a5f5c4eb03dfafb106e15c9575fe10352191b3 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 09/21] dont use residual_add when not necessary --- .../core/extensions/transformer_engine.py | 23 +++++++++++++++---- .../core/transformer/transformer_layer.py | 3 +++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ab731f66837..fec3e12a83e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -602,17 +602,32 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens class TENorm: """A conditional wrapper to initialize an instance of - Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.""" + Transformer-Engine's `LayerNorm` or `RMSNorm` based on input. + + Residual Fusion Design: + ---------------------- + Residual fusion is a two-level opt-in mechanism: + + 1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature) + 2. Local intent: has_residual=True must be passed at build site (declares this specific + norm is followed by a residual connection) + + Fusion only happens when BOTH conditions are met. + + """ # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? - def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): + def __new__( + cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5, has_residual: bool = False + ): if not HAVE_TE: raise ImportError( "Transformer Engine is not installed. " "Please install it with `pip install transformer-engine`." ) - if config.normalization == "LayerNorm": + if config.fused_residual_rmsnorm and has_residual: + raise ValueError("Fused residual RMSNorm is not supported for LayerNorm") instance = te.pytorch.LayerNorm( hidden_size=hidden_size, eps=eps, @@ -627,7 +642,7 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) extra_te_kwargs = _get_extra_te_kwargs(config) - if config.fused_residual_rmsnorm: + if config.fused_residual_rmsnorm and has_residual: # Use fused residual variant assert TEFusedResidualRMSNorm is not None, ( "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index a5471e6f430..3efb6272807 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -289,6 +289,7 @@ def __init__( config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + has_residual=True, # Followed by self-attention + residual add ) attention_optional_kwargs = {} @@ -317,6 +318,7 @@ def __init__( config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + has_residual=True, # Followed by cross-attention + residual add ) # [Module 5: CrossAttention] @@ -336,6 +338,7 @@ def __init__( config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + has_residual=True, # Followed by MLP + residual add ) # [Module 8: MLP block] additional_mlp_kwargs = {} From 41871c686154965d31f26ac440c57cfe703d5c66 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:40:53 +0100 Subject: [PATCH 10/21] Remove quantization for now --- megatron/core/extensions/transformer_engine.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index fec3e12a83e..1403409c1a1 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -447,10 +447,8 @@ class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): Forward pass returns: (normalized_output, residual) """ - def __init__(self, quantize: bool, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.quantize = quantize - # Fused implementation (stored in tuple to avoid submodule registration) self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None @@ -479,18 +477,12 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: if hasattr(self, '_sm_margins'): kwargs["sm_margin"] = self._sm_margins - if self.quantize: - fused_impl.append(te.ops.Quantize(forward=False, backward=True)) - rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) rmsnorm_op.weight = self.weight fused_impl.append(rmsnorm_op) - if self.quantize: - fused_impl.append(te.ops.Quantize(forward=True, backward=False)) - self._register_hooks_on_fused_impl(fused_impl) return fused_impl @@ -652,7 +644,6 @@ def __new__( eps=eps, sequence_parallel=config.sequence_parallel, zero_centered_gamma=config.layernorm_zero_centered_gamma, - quantize=config.fp8 or config.fp4, **extra_te_kwargs, ) else: From aac2b4f047566418addcd8f8b42ace2460b290f0 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:56:24 +0100 Subject: [PATCH 11/21] formatting changes --- .../core/extensions/transformer_engine.py | 31 ++++++++++++------- .../core/transformer/transformer_block.py | 1 + .../core/transformer/transformer_config.py | 8 +++-- .../core/transformer/transformer_layer.py | 16 +++++++--- .../fusions/test_rmsnorm_residual_fusion.py | 16 +++++++--- 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 1403409c1a1..c4edd8114c3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -476,7 +476,7 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: # Add sm_margin if available (TE 2.5+) if hasattr(self, '_sm_margins'): kwargs["sm_margin"] = self._sm_margins - + rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) rmsnorm_op.weight = self.weight @@ -527,8 +527,9 @@ def forward_pre_hook(module, *_) -> None: ret = hook(submodule, None) if ret is not None: raise RuntimeError( - "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " - "submodule has pre-forward hook that modifies input tensor." + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "pre-forward hook that modifies input tensor." ) fused_impl.register_forward_pre_hook(forward_pre_hook) @@ -548,8 +549,9 @@ def forward_post_hook(module, *_) -> None: ret = hook(submodule, None, None) if ret is not None: raise RuntimeError( - "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " - "submodule has post-forward hook that modifies output tensor." + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "post-forward hook that modifies output tensor." ) fused_impl.register_forward_hook(forward_post_hook) @@ -557,11 +559,13 @@ def forward_post_hook(module, *_) -> None: # Backward hooks if backward_pre_hooks: raise RuntimeError( - "TEFusedResidualRMSNorm module does not support submodules with pre-backward hooks" + "TEFusedResidualRMSNorm module does not support " + "submodules with pre-backward hooks" ) if backward_post_hooks: raise RuntimeError( - "TEFusedResidualRMSNorm module does not support submodules with post-backward hooks" + "TEFusedResidualRMSNorm module does not support " + "submodules with post-backward hooks" ) def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -610,7 +614,11 @@ class TENorm: # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? def __new__( - cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5, has_residual: bool = False + cls, + config: TransformerConfig, + hidden_size: int, + eps: float = 1e-5, + has_residual: bool = False, ): if not HAVE_TE: raise ImportError( @@ -636,9 +644,9 @@ def __new__( if config.fused_residual_rmsnorm and has_residual: # Use fused residual variant - assert TEFusedResidualRMSNorm is not None, ( - "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" - ) + assert ( + TEFusedResidualRMSNorm is not None + ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" instance = TEFusedResidualRMSNorm( normalized_shape=hidden_size, eps=eps, @@ -2381,6 +2389,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Option else: TEFusedMLP = None # type: ignore[assignment, misc] + class TEDelayedScaling(te.common.recipe.DelayedScaling): """ Wrapper for the Transformer-Engine's `DelayedScaling` layer. diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 34904a3608c..c60cf107e73 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -27,6 +27,7 @@ get_transformer_layer_offset, ) from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.typed_torch import not_none from megatron.core.utils import ( WrappedTensor, deprecate_inference_params, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index c7456c178bb..9487317d383 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1543,11 +1543,13 @@ def __post_init__(self): "If you use bias in MLP FC1, we recommend setting bias_activation_fusion " "to True and use_te_activation_func to False." ) - + if self.fused_residual_rmsnorm: if self.normalization != "RMSNorm": - raise ValueError("fused_residual_rmsnorm is only supported when normalization is RMSNorm.") - + raise ValueError( + "fused_residual_rmsnorm is only supported when normalization is RMSNorm." + ) + if self.use_te_activation_func: if self.activation_func not in (F.gelu, F.silu, F.relu): raise ValueError( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 3efb6272807..b64617818d7 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -572,7 +572,6 @@ def _forward_attention( inference_context = deprecate_inference_params(inference_context, inference_params) - # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() @@ -586,11 +585,15 @@ def _forward_attention( if isinstance(input_layernorm_output, tuple): if len(input_layernorm_output) != 2: - raise ValueError(f"When the output of input_layernorm is a tuple, it is expected to have 2 elements (output, residual), but got {len(input_layernorm_output)}") + raise ValueError( + f"When the output of input_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(input_layernorm_output)}" + ) input_layernorm_output, residual = input_layernorm_output else: residual = hidden_states - + using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication ) @@ -650,7 +653,12 @@ def _forward_attention( if isinstance(pre_cross_attn_layernorm_output, tuple): if len(pre_cross_attn_layernorm_output) != 2: - raise ValueError(f"When the output of pre_cross_attn_layernorm_output is a tuple, it is expected to have 2 elements (output, residual), but got {len(pre_cross_attn_layernorm_output)}") + raise ValueError( + f"When the output of pre_cross_attn_layernorm_output " + f"is a tuple, it is expected to have 2 elements " + f"(output, residual), but " + f"got {len(pre_cross_attn_layernorm_output)}" + ) pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output else: residual = hidden_states diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index 59ad8212eb8..a7504a6e8d0 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -1,22 +1,26 @@ import pytest import torch +from transformer_engine.pytorch import RMSNorm from megatron.core.extensions.transformer_engine import TEFusedResidualRMSNorm -from transformer_engine.pytorch import RMSNorm + def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): return rmsnorm(x), x + @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("normalized_shape", [256, 256*2, 256*4]) +@pytest.mark.parametrize("normalized_shape", [256, 256 * 2, 256 * 4]) def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): x_baseline = torch.randn(16, 32, normalized_shape, dtype=input_dtype, device="cuda") x_baseline.requires_grad = True x_fused = x_baseline.detach() x_fused.requires_grad = True baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() - fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, dtype=input_dtype, quantize=False).cuda() - + fused_rmsnorm = TEFusedResidualRMSNorm( + normalized_shape=normalized_shape, dtype=input_dtype, quantize=False + ).cuda() + # baseline baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) baseline_loss = baseline_y.sum() + baseline_residual.sum() @@ -28,7 +32,9 @@ def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): fused_loss.backward() # Use tolerances appropriate for dtype (pattern from other tests) - tols = dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + tols = ( + dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + ) assert fused_y.dtype == baseline_y.dtype assert torch.allclose(fused_y, baseline_y, **tols) From 1c43ae8c418ed52f623145969af54072891823bb Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 12:15:47 +0100 Subject: [PATCH 12/21] add check tuple has len 2 to pre_mlp_layernorm --- megatron/core/transformer/transformer_block.py | 1 - megatron/core/transformer/transformer_layer.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index c60cf107e73..34904a3608c 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -27,7 +27,6 @@ get_transformer_layer_offset, ) from megatron.core.transformer.utils import sharded_state_dict_default -from megatron.core.typed_torch import not_none from megatron.core.utils import ( WrappedTensor, deprecate_inference_params, diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index b64617818d7..f395cc7c6da 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1141,7 +1141,13 @@ def _te_cuda_graph_replay(self, *args, **kwargs): return residual, None, None, None hidden_states = self.pre_mlp_layernorm(residual) if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " + f"2 elements (output, residual), but got {len(hidden_states)}" + ) hidden_states, residual = hidden_states + shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) From d5aa69c8405f7bd7995210f1615a490f0f071fa5 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 12:18:34 +0100 Subject: [PATCH 13/21] fix formatting --- megatron/core/transformer/attention.py | 8 ++++++-- megatron/core/transformer/transformer_layer.py | 7 ++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index bc5e4e2ee0d..94c67e9f4f9 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -59,7 +59,9 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import _flash_attn_forward + from flash_attn_3.flash_attn_interface import ( + _flash_attn_forward, + ) from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -70,7 +72,9 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index f395cc7c6da..9d72d308e9e 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1143,11 +1143,12 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if isinstance(hidden_states, tuple): if len(hidden_states) != 2: raise ValueError( - f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " - f"2 elements (output, residual), but got {len(hidden_states)}" + f"When the output of pre_mlp_layernorm is a tuple,\ + it is expected to have 2 elements (output, residual),\ + but got {len(hidden_states)}" ) hidden_states, residual = hidden_states - + shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) From 68599c0ce48e31db4a92cce98f5477ccd691adde Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Fri, 13 Feb 2026 11:27:08 +0100 Subject: [PATCH 14/21] Add checks for tuple length in MultiTokenPredictionLayer and Transformer classes --- .../transformer/multi_token_prediction.py | 18 ++++++++++++ .../core/transformer/transformer_block.py | 9 ++++-- .../core/transformer/transformer_layer.py | 29 ++++++++++++------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index d65e216ce52..c637dd35800 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -729,10 +729,22 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T """ decoder_input = self.enorm(decoder_input) if isinstance(decoder_input, tuple): + if len(decoder_input) != 2: + raise ValueError( + f"When the output of enorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(decoder_input)}" + ) decoder_input = decoder_input[0] decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) hidden_states = self.hnorm(hidden_states) if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of hnorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) hidden_states = hidden_states[0] hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states @@ -818,6 +830,12 @@ def _postprocess(self, hidden_states: torch.Tensor): # Layer norm before shared head layer. hidden_states = self.final_layernorm(hidden_states) if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of final_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 34904a3608c..921b38bab21 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -787,9 +787,14 @@ def forward( # Final layer norm. if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) - # Handle fused residual normalization (returns tuple of (output, residual)) - # For final layernorm, we only need the normalized output, not the residual if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of final_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) + # For final layernorm, we only need the normalized output, not the residual hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 9d72d308e9e..a35db6599c1 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -716,20 +716,20 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) output (Tensor): Transformed hidden states of shape [s, b, h]. """ - # Residual connection. - residual = hidden_states - # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) - # Handle fused residual normalization (returns tuple of (output, residual)) if isinstance(pre_mlp_layernorm_output, tuple): if len(pre_mlp_layernorm_output) != 2: raise ValueError( - f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " - f"2 elements (output, residual), but got {len(pre_mlp_layernorm_output)}" + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" ) pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + # Residual connection. + residual = hidden_states nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size @@ -1143,9 +1143,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if isinstance(hidden_states, tuple): if len(hidden_states) != 2: raise ValueError( - f"When the output of pre_mlp_layernorm is a tuple,\ - it is expected to have 2 elements (output, residual),\ - but got {len(hidden_states)}" + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" ) hidden_states, residual = hidden_states @@ -1332,9 +1332,18 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None): This method is isolated so it can be captured by `cudagraph_manager_router`. """ - residual = hidden_states self.mlp.fwd_execution_map = "route" pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + residual = hidden_states router_outputs = self.mlp( pre_mlp_layernorm_output, intermediate_tensors=(), padding_mask=padding_mask ) From 4512b19b42e0d315d7a8a0c3de7633bb03cf0e68 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 17 Feb 2026 18:07:10 +0100 Subject: [PATCH 15/21] remove unnecessary unpacking --- .../transformer/multi_token_prediction.py | 24 ------------------- .../core/transformer/transformer_block.py | 9 ------- 2 files changed, 33 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index c637dd35800..b0476155ad9 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -728,24 +728,8 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T Concatenate the tokens before sending to transformer layer. """ decoder_input = self.enorm(decoder_input) - if isinstance(decoder_input, tuple): - if len(decoder_input) != 2: - raise ValueError( - f"When the output of enorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(decoder_input)}" - ) - decoder_input = decoder_input[0] decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) hidden_states = self.hnorm(hidden_states) - if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError( - f"When the output of hnorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(hidden_states)}" - ) - hidden_states = hidden_states[0] hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states # and the (i + K)-th token's embedding, and combine them with linear projection. @@ -829,14 +813,6 @@ def _postprocess(self, hidden_states: torch.Tensor): # Layer norm before shared head layer. hidden_states = self.final_layernorm(hidden_states) - if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError( - f"When the output of final_layernorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(hidden_states)}" - ) - hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 921b38bab21..f222a2c3a6b 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -787,15 +787,6 @@ def forward( # Final layer norm. if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) - if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError( - f"When the output of final_layernorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(hidden_states)}" - ) - # For final layernorm, we only need the normalized output, not the residual - hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. From c197e08c64000ee39316aa8f4b8682f61fbda42a Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 14:09:46 +0100 Subject: [PATCH 16/21] guard has_residual behind TENorm check --- .../core/transformer/transformer_layer.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index a35db6599c1..31560585264 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -14,6 +14,7 @@ from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import apply_prefix_mapping +from megatron.core.extensions.transformer_engine import TENorm from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import is_graph_capturing @@ -282,14 +283,20 @@ def __init__( ) self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout + def _build_layernorm(builder: Union[ModuleSpec, type], has_residual_connection: bool): + norm_kwargs: Dict[str, Any] = { + "config": self.config, + "hidden_size": self.config.hidden_size, + "eps": self.config.layernorm_epsilon, + } + if has_residual_connection and builder is TENorm: + norm_kwargs["has_residual"] = True + return build_module(builder, **norm_kwargs) + # [Module 1: Input Layernorm] Optional Layernorm on the input data # TODO: add pytorch only layernorm - self.input_layernorm = build_module( - submodules.input_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - has_residual=True, # Followed by self-attention + residual add + self.input_layernorm = _build_layernorm( + submodules.input_layernorm, has_residual_connection=True ) attention_optional_kwargs = {} @@ -313,12 +320,8 @@ def __init__( self.self_attn_bda = build_module(submodules.self_attn_bda) # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - has_residual=True, # Followed by cross-attention + residual add + self.pre_cross_attn_layernorm = _build_layernorm( + submodules.pre_cross_attn_layernorm, has_residual_connection=True ) # [Module 5: CrossAttention] @@ -333,12 +336,8 @@ def __init__( self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) # [Module 7: Pre MLP] Optional Layernorm before MLP - self.pre_mlp_layernorm = build_module( - submodules.pre_mlp_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - has_residual=True, # Followed by MLP + residual add + self.pre_mlp_layernorm = _build_layernorm( + submodules.pre_mlp_layernorm, has_residual_connection=True ) # [Module 8: MLP block] additional_mlp_kwargs = {} From 3b5b953eb724a8f916684a538290ce04d2233c5c Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 14:34:32 +0100 Subject: [PATCH 17/21] avoid circular import --- megatron/core/transformer/transformer_layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 31560585264..ad67a532316 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -14,7 +14,6 @@ from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import apply_prefix_mapping -from megatron.core.extensions.transformer_engine import TENorm from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import is_graph_capturing @@ -283,6 +282,8 @@ def __init__( ) self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout + # import here to avoid circular import + from megatron.core.extensions.transformer_engine import TENorm def _build_layernorm(builder: Union[ModuleSpec, type], has_residual_connection: bool): norm_kwargs: Dict[str, Any] = { "config": self.config, From c9c0420534a7147f8a2d17a955d31465ac57bbc4 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 15:01:28 +0100 Subject: [PATCH 18/21] add missing copyright header --- tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index a7504a6e8d0..324c162186d 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -1,3 +1,5 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + import pytest import torch from transformer_engine.pytorch import RMSNorm From cbee42281da326ce2bee100b74d2dd49ae2834cc Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 19:19:42 +0100 Subject: [PATCH 19/21] remove quantize arg from test --- tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index 324c162186d..6c03e0fa801 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -20,7 +20,7 @@ def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): x_fused.requires_grad = True baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() fused_rmsnorm = TEFusedResidualRMSNorm( - normalized_shape=normalized_shape, dtype=input_dtype, quantize=False + normalized_shape=normalized_shape, dtype=input_dtype ).cuda() # baseline From 0c819a3f187446dc273197cf2bb4e874c5286dd1 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 19:30:19 +0100 Subject: [PATCH 20/21] add arg to golden_dict --- tests/unit_tests/models/test_mamba_moe_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 2481649bc3f..d0ab295d6e9 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -114,6 +114,7 @@ "fp8_quantizer_factory": None, "fp8_recipe": "delayed", "fp8_wgrad": True, + "fused_residual_rmsnorm": False, "fused_single_qkv_rope": False, "gated_linear_unit": False, "glu_linear_offset": 0.0, From 5bce280e74d11b480500346f1d639f51c174e3ac Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 21:37:51 +0100 Subject: [PATCH 21/21] autoformat --- megatron/core/transformer/transformer_layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index ad67a532316..e460fba56a8 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -284,6 +284,7 @@ def __init__( # import here to avoid circular import from megatron.core.extensions.transformer_engine import TENorm + def _build_layernorm(builder: Union[ModuleSpec, type], has_residual_connection: bool): norm_kwargs: Dict[str, Any] = { "config": self.config,