diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ef8527e9e5e..c4edd8114c3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -430,19 +430,204 @@ def __new__(cls, config: TransformerConfig): TEActivationOp = None +if HAVE_TE and is_te_min_version("1.13.0"): + + class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): + """ + RMSNorm with fused residual output for Megatron Core. + + Inherits from te.pytorch.RMSNorm to maintain all parameter management, + checkpoint compatibility, and Megatron-specific features. Creates a fused + implementation using TE's ops API that shares the base class parameters. + + The fused implementation uses: + - MakeExtraOutput: Forks the residual connection + - RMSNorm: Normalizes the main path + + Forward pass returns: (normalized_output, residual) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Fused implementation (stored in tuple to avoid submodule registration) + self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None + + def _make_fused_impl(self) -> te.pytorch.ops.Sequential: + """ + Construct fused ops pipeline that shares parameters with base RMSNorm. + + Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares + the weight parameter with self.weight from the base class. + """ + + fused_impl = te.pytorch.ops.Sequential() + + # Op 1: MakeExtraOutput - forks the residual + fused_impl.append(te.pytorch.ops.MakeExtraOutput()) + + # Op 2: RMSNorm - shares weight parameter with self + kwargs = { + "eps": self.eps, + "device": "meta", # Already initialized + "dtype": self.weight.dtype, + "zero_centered_gamma": self.zero_centered_gamma, + } + + # Add sm_margin if available (TE 2.5+) + if hasattr(self, '_sm_margins'): + kwargs["sm_margin"] = self._sm_margins + + rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) + + rmsnorm_op.weight = self.weight + + fused_impl.append(rmsnorm_op) + + self._register_hooks_on_fused_impl(fused_impl) + + return fused_impl + + def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None: + + forward_pre_hooks = [] + forward_post_hooks = [] + backward_pre_hooks = [] + backward_post_hooks = [] + + for submodule in self.modules(): + for hook in submodule._forward_pre_hooks.values(): + forward_pre_hooks.append((submodule, hook)) + for hook in submodule._forward_hooks.values(): + forward_post_hooks.append((submodule, hook)) + for hook in submodule._backward_pre_hooks.values(): + backward_pre_hooks.append((submodule, hook)) + for hook in submodule._backward_hooks.values(): + backward_post_hooks.append((submodule, hook)) + + # Pre-forward hooks + # Note: DDP pre-forward hooks are safe since they do not + # interact with input tensor. + if forward_pre_hooks: + from megatron.core.distributed import distributed_data_parallel + + if any( + inspect.getmodule(hook) != distributed_data_parallel + for _, hook in forward_pre_hooks + ): + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input tensor." + ) + + def forward_pre_hook(module, *_) -> None: + for submodule, hook in forward_pre_hooks: + # Assume that hook does not interact with input + ret = hook(submodule, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "pre-forward hook that modifies input tensor." + ) + + fused_impl.register_forward_pre_hook(forward_pre_hook) + + # Post-forward hooks + if forward_post_hooks: + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input or output tensors." + ) + + def forward_post_hook(module, *_) -> None: + for submodule, hook in forward_post_hooks: + # Assume that hook does not interact with input or output + ret = hook(submodule, None, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "post-forward hook that modifies output tensor." + ) + + fused_impl.register_forward_hook(forward_post_hook) + + # Backward hooks + if backward_pre_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support " + "submodules with pre-backward hooks" + ) + if backward_post_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support " + "submodules with post-backward hooks" + ) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass with fused residual output. + + Args: + hidden_states: Input tensor [s, b, h] + + Returns: + Tuple of (normalized_output, residual), both [s, b, h] + + Note: + Sequential.forward() automatically returns (output, extra_outputs...) + when MakeExtraOutput is present, so we don't need manual unpacking. + """ + + # Construct fused impl lazily on first forward + # (in case parameters are modified after __init__) + if self._fused_impl is None: + self._fused_impl = (self._make_fused_impl(),) + + # Apply fused implementation + # Sequential returns (normalized_output, residual) automatically + return self._fused_impl[0](hidden_states) + +else: + TEFusedResidualRMSNorm = None # type: ignore[assignment, misc] + + class TENorm: """A conditional wrapper to initialize an instance of - Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.""" + Transformer-Engine's `LayerNorm` or `RMSNorm` based on input. + + Residual Fusion Design: + ---------------------- + Residual fusion is a two-level opt-in mechanism: + + 1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature) + 2. Local intent: has_residual=True must be passed at build site (declares this specific + norm is followed by a residual connection) + + Fusion only happens when BOTH conditions are met. + + """ # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? - def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): + def __new__( + cls, + config: TransformerConfig, + hidden_size: int, + eps: float = 1e-5, + has_residual: bool = False, + ): if not HAVE_TE: raise ImportError( "Transformer Engine is not installed. " "Please install it with `pip install transformer-engine`." ) - if config.normalization == "LayerNorm": + if config.fused_residual_rmsnorm and has_residual: + raise ValueError("Fused residual RMSNorm is not supported for LayerNorm") instance = te.pytorch.LayerNorm( hidden_size=hidden_size, eps=eps, @@ -454,13 +639,30 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5) assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) + + extra_te_kwargs = _get_extra_te_kwargs(config) + + if config.fused_residual_rmsnorm and has_residual: + # Use fused residual variant + assert ( + TEFusedResidualRMSNorm is not None + ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" + instance = TEFusedResidualRMSNorm( + normalized_shape=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **extra_te_kwargs, + ) + else: + # Standard RMSNorm without fusion + instance = te.pytorch.RMSNorm( + normalized_shape=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **extra_te_kwargs, + ) else: raise Exception("Only LayerNorm and RMSNorm are curently supported") diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index bc5e4e2ee0d..94c67e9f4f9 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -59,7 +59,9 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import _flash_attn_forward + from flash_attn_3.flash_attn_interface import ( + _flash_attn_forward, + ) from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -70,7 +72,9 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index dce438520aa..9487317d383 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -382,6 +382,9 @@ class TransformerConfig(ModelParallelConfig): fused_single_qkv_rope: bool = False """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads.""" + fused_residual_rmsnorm: bool = False + """If True, uses fuses residual connection and RMSNorm when TE is used.""" + #################### # activation recomputation #################### @@ -1541,6 +1544,12 @@ def __post_init__(self): "to True and use_te_activation_func to False." ) + if self.fused_residual_rmsnorm: + if self.normalization != "RMSNorm": + raise ValueError( + "fused_residual_rmsnorm is only supported when normalization is RMSNorm." + ) + if self.use_te_activation_func: if self.activation_func not in (F.gelu, F.silu, F.relu): raise ValueError( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 12c24868473..e460fba56a8 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -282,13 +282,23 @@ def __init__( ) self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout + # import here to avoid circular import + from megatron.core.extensions.transformer_engine import TENorm + + def _build_layernorm(builder: Union[ModuleSpec, type], has_residual_connection: bool): + norm_kwargs: Dict[str, Any] = { + "config": self.config, + "hidden_size": self.config.hidden_size, + "eps": self.config.layernorm_epsilon, + } + if has_residual_connection and builder is TENorm: + norm_kwargs["has_residual"] = True + return build_module(builder, **norm_kwargs) + # [Module 1: Input Layernorm] Optional Layernorm on the input data # TODO: add pytorch only layernorm - self.input_layernorm = build_module( - submodules.input_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, + self.input_layernorm = _build_layernorm( + submodules.input_layernorm, has_residual_connection=True ) attention_optional_kwargs = {} @@ -312,11 +322,8 @@ def __init__( self.self_attn_bda = build_module(submodules.self_attn_bda) # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, + self.pre_cross_attn_layernorm = _build_layernorm( + submodules.pre_cross_attn_layernorm, has_residual_connection=True ) # [Module 5: CrossAttention] @@ -331,11 +338,8 @@ def __init__( self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) # [Module 7: Pre MLP] Optional Layernorm before MLP - self.pre_mlp_layernorm = build_module( - submodules.pre_mlp_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, + self.pre_mlp_layernorm = _build_layernorm( + submodules.pre_mlp_layernorm, has_residual_connection=True ) # [Module 8: MLP block] additional_mlp_kwargs = {} @@ -569,9 +573,6 @@ def _forward_attention( inference_context = deprecate_inference_params(inference_context, inference_params) - # Residual connection. - residual = hidden_states - # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() @@ -583,6 +584,17 @@ def _forward_attention( with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = self.input_layernorm(hidden_states) + if isinstance(input_layernorm_output, tuple): + if len(input_layernorm_output) != 2: + raise ValueError( + f"When the output of input_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(input_layernorm_output)}" + ) + input_layernorm_output, residual = input_layernorm_output + else: + residual = hidden_states + using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication ) @@ -637,12 +649,21 @@ def _forward_attention( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) - # Residual connection. - residual = hidden_states - # Optional Layer norm after self-attention pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) + if isinstance(pre_cross_attn_layernorm_output, tuple): + if len(pre_cross_attn_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_cross_attn_layernorm_output " + f"is a tuple, it is expected to have 2 elements " + f"(output, residual), but " + f"got {len(pre_cross_attn_layernorm_output)}" + ) + pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output + else: + residual = hidden_states + # Cross attention. attention_output_with_bias = self.cross_attention( pre_cross_attn_layernorm_output, @@ -696,12 +717,21 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) output (Tensor): Transformed hidden states of shape [s, b, h]. """ - # Residual connection. - residual = hidden_states - # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + # Residual connection. + residual = hidden_states + nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size should_chunk_mlp_for_prefill = ( @@ -1111,6 +1141,15 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if not self.is_moe_layer: return residual, None, None, None hidden_states = self.pre_mlp_layernorm(residual) + if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) + hidden_states, residual = hidden_states + shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) @@ -1294,9 +1333,18 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None): This method is isolated so it can be captured by `cudagraph_manager_router`. """ - residual = hidden_states self.mlp.fwd_execution_map = "route" pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + residual = hidden_states router_outputs = self.mlp( pre_mlp_layernorm_output, intermediate_tensors=(), padding_mask=padding_mask ) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py new file mode 100644 index 00000000000..6c03e0fa801 --- /dev/null +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import pytest +import torch +from transformer_engine.pytorch import RMSNorm + +from megatron.core.extensions.transformer_engine import TEFusedResidualRMSNorm + + +def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): + return rmsnorm(x), x + + +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("normalized_shape", [256, 256 * 2, 256 * 4]) +def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): + x_baseline = torch.randn(16, 32, normalized_shape, dtype=input_dtype, device="cuda") + x_baseline.requires_grad = True + x_fused = x_baseline.detach() + x_fused.requires_grad = True + baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() + fused_rmsnorm = TEFusedResidualRMSNorm( + normalized_shape=normalized_shape, dtype=input_dtype + ).cuda() + + # baseline + baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) + baseline_loss = baseline_y.sum() + baseline_residual.sum() + baseline_loss.backward() + + # fused + fused_y, fused_residual = fused_rmsnorm(x_fused) + fused_loss = fused_y.sum() + fused_residual.sum() + fused_loss.backward() + + # Use tolerances appropriate for dtype (pattern from other tests) + tols = ( + dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + ) + + assert fused_y.dtype == baseline_y.dtype + assert torch.allclose(fused_y, baseline_y, **tols) + assert fused_residual.dtype == baseline_residual.dtype + assert torch.allclose(fused_residual, baseline_residual, **tols) + assert x_fused.grad.dtype == x_baseline.grad.dtype + assert torch.allclose(x_baseline.grad, x_fused.grad, **tols) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 2481649bc3f..d0ab295d6e9 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -114,6 +114,7 @@ "fp8_quantizer_factory": None, "fp8_recipe": "delayed", "fp8_wgrad": True, + "fused_residual_rmsnorm": False, "fused_single_qkv_rope": False, "gated_linear_unit": False, "glu_linear_offset": 0.0,