From 55d1e34f6ef92b602f7e8fb8fca6157378965587 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Thu, 30 Apr 2026 16:47:11 -0700 Subject: [PATCH 01/16] Add grouped HybridStack overlap support --- .../common/model_chunk_schedule_plan.py | 61 ++-- .../core/models/gpt/fine_grained_callables.py | 3 +- .../models/hybrid/fine_grained_callables.py | 288 ++++++++++++++++++ megatron/core/models/hybrid/hybrid_block.py | 150 +++++++-- .../models/hybrid/hybrid_layer_allocation.py | 263 ++++++++++++++-- megatron/core/models/hybrid/hybrid_model.py | 248 ++++++++++++++- .../core/pipeline_parallel/combined_1f1b.py | 8 +- pretrain_hybrid.py | 33 +- tests/unit_tests/models/test_hybrid_model.py | 26 ++ tests/unit_tests/ssm/test_hybrid_block.py | 168 +++++++++- .../ssm/test_hybrid_layer_allocation.py | 56 ++++ 11 files changed, 1204 insertions(+), 100 deletions(-) create mode 100644 megatron/core/models/hybrid/fine_grained_callables.py diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 9032d337e00..3d3591ea6dd 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -27,8 +27,8 @@ class ModelChunkState: pass -class TransformerLayerSchedulePlan: - """Schedule the executing plan of the nodes in a transformer/mtp layer. +class HybridStackSchedulePlan: + """Schedule the executing plan of nodes in a transformer, MTP, or hybrid layer. This class organizes the sub-modules of a transformer/mtp layer, including attention, post attention, MLP, dispatch, combine and @@ -55,7 +55,7 @@ class TransformerLayerSchedulePlan: mtp_post_process = None def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_args={}): - """Initializes a transformer layer schedule plan. + """Initializes a layer schedule plan. 
Args: layer (TransformerLayer): @@ -76,11 +76,12 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar self.layer_state = TransformerLayerState() self.chunk_state = chunk_state self.layer = layer + self.layer_type = extra_args.get("layer_type", None) self.event = event self.comp_stream = comp_stream self.comm_stream = comm_stream - # get callable nodes for transformer/mtp layer + # get callable nodes for transformer/mtp/hybrid layer self._build_callable_nodes(event, comp_stream, comm_stream, extra_args) def release_state(self): @@ -108,24 +109,35 @@ def release_state(self): def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): """ - Builds the callable nodes for the transformer/mtp layer: + Builds the callable nodes for the transformer/mtp/hybrid layer: attn, mlp, moe_dispatch and moe_combine, and mtp_post_process. """ from megatron.core.models.gpt.fine_grained_callables import ( TransformerLayerNode, build_layer_callables, ) + from megatron.core.models.hybrid.fine_grained_callables import ( + build_hybrid_stack_callables, + ) + from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer - # build the forward and backward callables for the transformer/mtp layer - fwd_callables, bwd_dw_callable_map = build_layer_callables(self.layer) + layer_type: LayerPatternItem = extra_args.get("layer_type", None) + if layer_type is None: + # build the forward and backward callables for the transformer/mtp layer + fwd_callables, bwd_dw_callable_map = build_layer_callables(self.layer) - # get flags for latter use - is_mtp = isinstance(self.layer, MultiTokenPredictionLayer) - transformer_layer = self.layer.mtp_model_layer if is_mtp else self.layer - is_moe = isinstance(transformer_layer.mlp, MoELayer) - num_local_experts = transformer_layer.mlp.num_local_experts 
if is_moe else None + # get flags for later use + is_mtp = isinstance(self.layer, MultiTokenPredictionLayer) + transformer_layer = self.layer.mtp_model_layer if is_mtp else self.layer + is_moe = isinstance(transformer_layer.mlp, MoELayer) + num_local_experts = transformer_layer.mlp.num_local_experts if is_moe else None + else: + fwd_callables, bwd_dw_callable_map, is_moe, num_local_experts = ( + build_hybrid_stack_callables(self.layer, layer_type=layer_type) + ) + is_mtp = False extra_args["config"] = self.layer.config extra_args["is_moe"] = is_moe @@ -136,6 +148,9 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): # wrapper to help create TransformerLayerNode def create_node(stream, module, name): bwd_dw_callables = bwd_dw_callable_map.get(name, None) + node_extra_args = dict(extra_args) + if bwd_dw_callables is None: + node_extra_args["delay_wgrad_compute"] = False return TransformerLayerNode( stream, event, @@ -144,7 +159,7 @@ def create_node(stream, module, name): module, name=name, bwd_dw_callables=bwd_dw_callables, - extra_args=extra_args, + extra_args=node_extra_args, ) ( @@ -180,6 +195,8 @@ def get_fp8_context(self): use_inner_fp8_context = ( self.layer.config.fp8 and self.layer.config.fp8_recipe != Fp8Recipe.delayed ) + if self.layer_type is not None or not hasattr(self.layer, "layer_number"): + return nullcontext() return ( get_fp8_context(self.layer.config, self.layer.layer_number - 1) if use_inner_fp8_context @@ -254,7 +271,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) return f_input, b_grad -class TransformerModelChunkSchedulePlan(AbstractSchedulePlan): +class HybridStackModelChunkSchedulePlan(AbstractSchedulePlan): """Schedule the executing plan of the sub-modules in a model chunk sub-modules. 
This class organizes the computation nodes for a model chunk, @@ -357,7 +374,10 @@ def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): "is_first_layer": layer_idx == 0, "is_last_layer": layer_idx == num_layers - 1, } - layer_plan = TransformerLayerSchedulePlan( + extra_args["layer_type"] = ( + module.layer_type_list[layer_idx] if hasattr(module, "layer_type_list") else None + ) + layer_plan = HybridStackSchedulePlan( module.layers[layer_idx], self.event, self.state, @@ -476,7 +496,7 @@ def run( b_layer = b_schedule_plan.pop_layer() nvtx_msg = f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b" nvtx_range_push(nvtx_msg) - f_input, b_grad = TransformerLayerSchedulePlan.run( + f_input, b_grad = HybridStackSchedulePlan.run( f_layer, b_layer, f_input=f_input, @@ -492,7 +512,7 @@ def run( b_layer = b_schedule_plan.pop_layer() nvtx_msg = f"layer_{b_schedule_plan.num_layers()}b" nvtx_range_push(nvtx_msg) - _, b_grad = TransformerLayerSchedulePlan.run( + _, b_grad = HybridStackSchedulePlan.run( None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1) ) if i < b_num_layers - 1: @@ -504,7 +524,7 @@ def run( f_layer = f_schedule_plan.get_layer(i) nvtx_msg = f"layer_{i}f" nvtx_range_push(nvtx_msg) - f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input) + f_input, _ = HybridStackSchedulePlan.run(f_layer, None, f_input=f_input) nvtx_range_pop(nvtx_msg) if f_schedule_plan is not None and post_forward is not None: @@ -542,3 +562,8 @@ def run( b_schedule_plan.release_state() return f_input + + +# Backward-compatible aliases for GPT callers and existing tests. 
+TransformerLayerSchedulePlan = HybridStackSchedulePlan +TransformerModelChunkSchedulePlan = HybridStackModelChunkSchedulePlan diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index fa2a2ec4934..03d963c481b 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -204,7 +204,8 @@ def forward_impl(self, hidden_states): """ empty_decoder = len(self.gpt_model.decoder.layers) == 0 - layer_norm = self.gpt_model.decoder.final_layernorm + layer_norm = getattr(self.gpt_model.decoder, "final_layernorm", None) + layer_norm = layer_norm or getattr(self.gpt_model.decoder, "final_norm", None) if not self.gpt_model.config.mtp_num_layers and empty_decoder and layer_norm: hidden_states = layer_norm(hidden_states) hidden_states = make_viewless_tensor( diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py new file mode 100644 index 00000000000..b94111e3f36 --- /dev/null +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -0,0 +1,288 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +from contextlib import nullcontext +from functools import partial +from typing import Optional + +import torch +from torch import Tensor + +from megatron.core.enums import Fp8Recipe +from megatron.core.fp4_utils import get_fp4_context +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.models.hybrid.hybrid_block import HybridStack +from megatron.core.models.hybrid.hybrid_layer_allocation import ( + LayerPatternItem, + Symbols as LayerSymbols, + is_layer_group, +) +from megatron.core.pipeline_parallel.utils import ScheduleNode +from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor + + +def _get_inner_quant_context(layer): + config = layer.config + if config.fp8 and config.fp8_recipe != Fp8Recipe.delayed: + return get_fp8_context(config, layer.layer_number - 1) + if config.fp4: + return get_fp4_context(config, layer.layer_number - 1) + return nullcontext() + + +def _as_hybrid_layers(layer, layer_type: Optional[LayerPatternItem]): + """Return ``(layer_type, layer)`` pairs for a hybrid logical layer.""" + if isinstance(layer, HybridStack): + return list(zip(layer.layer_type_list, layer.layers)) + assert layer_type is not None, "Hybrid layer scheduling requires the layer type symbol." 
+ return [(layer_type, layer)] + + +def _apply_attention_layer( + layer: TransformerLayer, + node: ScheduleNode, + hidden_states: Tensor, +): + hidden_states, _ = layer._forward_attention( + hidden_states=hidden_states, + attention_mask=node.chunk_state.attention_mask, + rotary_pos_emb=node.chunk_state.rotary_pos_emb, + rotary_pos_cos=node.chunk_state.rotary_pos_cos, + rotary_pos_sin=node.chunk_state.rotary_pos_sin, + packed_seq_params=node.chunk_state.packed_seq_params, + sequence_len_offset=node.chunk_state.sequence_len_offset, + ) + return hidden_states + + +def _apply_mamba_layer(layer, node: ScheduleNode, hidden_states: Tensor): + return layer( + hidden_states=hidden_states, + attention_mask=node.chunk_state.attention_mask, + inference_context=getattr(node.chunk_state, "inference_context", None), + packed_seq_params=node.chunk_state.packed_seq_params, + ) + + +def _maybe_apply_final_norm(node: ScheduleNode, hidden_states: Tensor): + final_norm = getattr(node.chunk_state.model.decoder, "final_norm", None) + final_norm = final_norm or getattr(node.chunk_state.model.decoder, "final_layernorm", None) + if not node.is_mtp and final_norm is not None and node.is_last_layer: + hidden_states = final_norm(hidden_states) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + return hidden_states + + +def _get_moe_padding_mask(node: ScheduleNode): + padding_mask = node.chunk_state.padding_mask + if padding_mask is not None: + # MoELayer.forward receives [batch, seq] and transposes before routing. 
+ padding_mask = padding_mask.transpose(0, 1).bool() + return padding_mask + + +class _SharedExpertBackwardDWWrapper: + """Backward weight-gradient wrapper for MoE-only hybrid terminal layers.""" + + def __init__(self, layer): + self.layer = layer + self.shared_expert_dw_callable = None + if layer.mlp.use_shared_expert: + self.shared_expert_dw_callable = partial( + layer.mlp.backward_dw, routed_experts=False, shared_experts=True + ) + + def backward_dw(self): + if self.shared_expert_dw_callable is not None: + self.shared_expert_dw_callable() + self.layer = None + self.shared_expert_dw_callable = None + + +def _run_moe_preprocess(layer, node: ScheduleNode, hidden_states: Tensor): + pre_mlp_layernorm_output = layer._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " + f"2 elements (output, residual), but got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + residual = hidden_states + + if layer.config.fp32_residual_connection: + residual = residual.float() + + shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output) + probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output, _get_moe_padding_mask(node)) + local_tokens, probs = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) + + node.layer_state.residual = node.detach(residual) + if layer.mlp.use_shared_expert and not layer.mlp.shared_expert_overlap: + node.layer_state.shared_expert_output = node.detach(shared_expert_output) + + return local_tokens, probs + + +def _run_moe_experts(layer, node: ScheduleNode, dispatched_tokens: Tensor): + dispatched_probs = node.layer_state.dispatched_probs + enable_hybridep = ( + layer.config.moe_token_dispatcher_type == "flex" + and layer.config.moe_flex_dispatcher_backend == "hybridep" + ) + 
enable_deepep = ( + layer.config.moe_token_dispatcher_type == "flex" + and layer.config.moe_flex_dispatcher_backend == "deepep" + ) + token_dispatcher = layer.mlp.token_dispatcher + if enable_deepep or enable_hybridep: + token_dispatcher._comm_manager.dispatched_probs = dispatched_probs + + expert_output, _ = layer.mlp.routed_experts_compute(dispatched_tokens, dispatched_probs) + + if enable_hybridep: + tokens_per_expert = token_dispatcher._comm_manager.get_number_of_tokens_per_expert() + node.layer_state.tokens_per_expert = tokens_per_expert + + if layer.recompute_pre_mlp_layernorm: + layer.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(expert_output) + + return expert_output + + +def _run_moe_combine(layer, node: ScheduleNode, output: Tensor): + residual = node.layer_state.residual + shared_expert_output = getattr(node.layer_state, 'shared_expert_output', None) + output = layer.mlp.combine(output) + output = layer.mlp.postprocess(output, shared_expert_output) + output = layer._forward_post_mlp((output, None), residual) + + node.layer_state.residual.record_stream(torch.cuda.current_stream()) + if shared_expert_output is not None: + shared_expert_output.record_stream(torch.cuda.current_stream()) + + node.layer_state.residual = None + node.layer_state.shared_expert_output = None + + return _maybe_apply_final_norm(node, output) + + +def build_hybrid_stack_callables(layer, layer_type: Optional[LayerPatternItem] = None): + """Create fine-grained callables for one logical HybridStack layer. + + A logical layer may be a bracketed nested ``HybridStack`` (for example ``[M*E]``) + or a single legacy hybrid layer symbol. The split is: + pre-dispatch compute -> dispatch -> MLP/experts -> combine. 
+ """ + layer_items = _as_hybrid_layers(layer, layer_type) + if any(is_layer_group(item_type) for item_type, _ in layer_items): + raise ValueError("Nested HybridStack groups are not supported in overlap scheduling.") + + terminal_idx = None + for idx, (item_type, _) in enumerate(layer_items): + if item_type in (LayerSymbols.MLP, LayerSymbols.MOE): + terminal_idx = idx + break + + if terminal_idx is not None and terminal_idx != len(layer_items) - 1: + raise ValueError("HybridStack overlap requires MLP/MoE to be the last layer in a group.") + + terminal_type = layer_items[terminal_idx][0] if terminal_idx is not None else None + terminal_layer = layer_items[terminal_idx][1] if terminal_idx is not None else None + pre_layers = layer_items[:terminal_idx] if terminal_idx is not None else layer_items + is_moe = terminal_type == LayerSymbols.MOE + num_local_experts = terminal_layer.mlp.num_local_experts if is_moe else None + + def pre_dispatch_computation(node: ScheduleNode, hidden_states: Tensor): + for item_type, item_layer in pre_layers: + with _get_inner_quant_context(item_layer): + if item_type == LayerSymbols.MAMBA: + hidden_states = _apply_mamba_layer(item_layer, node, hidden_states) + elif item_type in ( + LayerSymbols.ATTENTION, + LayerSymbols.DS_ATTENTION, + LayerSymbols.GDN, + ): + hidden_states = _apply_attention_layer(item_layer, node, hidden_states) + else: + raise ValueError( + f"HybridStack overlap does not support layer type '{item_type}' before " + "the terminal MLP/MoE layer." 
+ ) + + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + if terminal_type == LayerSymbols.MOE: + with _get_inner_quant_context(terminal_layer): + return _run_moe_preprocess(terminal_layer, node, hidden_states) + + if terminal_type is None: + return _maybe_apply_final_norm(node, hidden_states) + + return hidden_states + + def dispatch(node: ScheduleNode, local_tokens: Tensor, probs: Tensor): + enable_hybridep = ( + terminal_layer.config.moe_token_dispatcher_type == "flex" + and terminal_layer.config.moe_flex_dispatcher_backend == "hybridep" + ) + enable_deepep = ( + terminal_layer.config.moe_token_dispatcher_type == "flex" + and terminal_layer.config.moe_flex_dispatcher_backend == "deepep" + ) + token_dispatcher = terminal_layer.mlp.token_dispatcher + if enable_deepep or enable_hybridep: + token_dispatcher._comm_manager.token_probs = probs + with _get_inner_quant_context(terminal_layer): + dispatched_tokens, dispatched_probs = terminal_layer.mlp.dispatch(local_tokens, probs) + node.layer_state.dispatched_probs = node.detach(dispatched_probs) + return dispatched_tokens + + def mlp(node: ScheduleNode, hidden_states: Tensor): + if terminal_type == LayerSymbols.MLP: + with _get_inner_quant_context(terminal_layer): + hidden_states = terminal_layer._forward_mlp( + hidden_states, + padding_mask=node.chunk_state.padding_mask, + ) + return _maybe_apply_final_norm(node, hidden_states) + if terminal_type == LayerSymbols.MOE: + with _get_inner_quant_context(terminal_layer): + return _run_moe_experts(terminal_layer, node, hidden_states) + return hidden_states + + def combine(node: ScheduleNode, output: Tensor): + with _get_inner_quant_context(terminal_layer): + return _run_moe_combine(terminal_layer, node, output) + + def raise_not_implemented(*args): + raise NotImplementedError("This callable is not implemented for non-MoE hybrid layers.") + + backward_dw = {} + pre_bwd_dw = [] + for item_type, item_layer in pre_layers: + if item_type in 
(LayerSymbols.ATTENTION, LayerSymbols.DS_ATTENTION, LayerSymbols.GDN): + item_layer.init_backward_dw_wrapper() + pre_bwd_dw.append(item_layer.backward_dw_wrapper) + if is_moe: + shared_expert_dw = _SharedExpertBackwardDWWrapper(terminal_layer) + if shared_expert_dw.shared_expert_dw_callable is not None: + pre_bwd_dw.append(shared_expert_dw) + backward_dw["mlp"] = terminal_layer.mlp + elif terminal_type == LayerSymbols.MLP: + backward_dw["mlp"] = terminal_layer.mlp + + if pre_bwd_dw: + backward_dw["attn"] = pre_bwd_dw + + forward_funcs = [ + pre_dispatch_computation, + dispatch if is_moe else raise_not_implemented, + mlp, + combine if is_moe else raise_not_implemented, + None, + ] + return forward_funcs, backward_dw, is_moe, num_local_experts diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index 6d20bcdd6e5..d9045aed0b3 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -19,7 +19,12 @@ from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.models.hybrid.hybrid_layer_allocation import ( + LayerPatternItem, + Symbols as LayerSymbols, + get_layer_type_physical_count, + is_layer_group, +) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig @@ -77,8 +82,10 @@ def __init__( config: TransformerConfig, submodules: HybridStackSubmodules, pre_process: bool = True, - layer_type_list: Optional[list[str]] = None, + layer_type_list: Optional[list[LayerPatternItem]] = None, pp_layer_offset: int = 0, + logical_layer_offset: int = 0, + is_layer_group_stack: bool = False, post_layer_norm: bool = True, post_process: bool 
= True, device=None, @@ -91,6 +98,8 @@ def __init__( self.post_layer_norm = post_layer_norm self.post_process = post_process self.is_mtp_layer = is_mtp_layer + self.logical_layer_offset = logical_layer_offset + self.is_layer_group_stack = is_layer_group_stack assert pg_collection is not None, "pg_collection must be provided for HybridStack" @@ -109,16 +118,39 @@ def __init__( # Build layers from the pre-selected segment self.layers = nn.ModuleList() - for i, layer_type in enumerate(self.layer_type_list): - layer_number = i + 1 + pp_layer_offset - if self.config.fp8: - quant_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) + physical_layer_offset = pp_layer_offset + for layer_type in self.layer_type_list: + layer_number = physical_layer_offset + 1 + if is_layer_group(layer_type): + quant_init_context = nullcontext() + elif self.config.fp8: + quant_init_context = get_fp8_context( + self.config, physical_layer_offset, is_init=True + ) elif self.config.fp4: - quant_init_context = get_fp4_context(self.config, i + pp_layer_offset, is_init=True) + quant_init_context = get_fp4_context( + self.config, physical_layer_offset, is_init=True + ) else: quant_init_context = nullcontext() with quant_init_context: - if layer_type == LayerSymbols.MAMBA: + if is_layer_group(layer_type): + layer = HybridStack( + config=self.config, + submodules=submodules, + pre_process=True, + layer_type_list=list(layer_type), + pp_layer_offset=physical_layer_offset, + logical_layer_offset=logical_layer_offset + len(self.layers), + is_layer_group_stack=True, + post_layer_norm=False, + post_process=False, + device=device, + dtype=dtype, + pg_collection=pg_collection, + is_mtp_layer=is_mtp_layer, + ) + elif layer_type == LayerSymbols.MAMBA: layer = build_module( submodules.mamba_layer, config=self.config, @@ -174,6 +206,7 @@ def __init__( else: raise ValueError("unexpected layer_type") self.layers.append(layer) + physical_layer_offset += 
get_layer_type_physical_count(layer_type) # Required for activation recomputation self.num_layers_per_pipeline_rank = len(self.layers) @@ -202,7 +235,11 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int if this block contains Mamba layers (this may not be the case with PP > 1). """ for layer_type, layer in zip(self.layer_type_list, self.layers): - if layer_type == LayerSymbols.MAMBA: + if is_layer_group(layer_type): + shapes = layer.mamba_state_shapes_per_request() + if shapes is not None: + return shapes + elif layer_type == LayerSymbols.MAMBA: return layer.mamba_state_shapes_per_request() return None @@ -212,6 +249,10 @@ def forward( attention_mask: Tensor, inference_context: Optional[BaseInferenceContext] = None, rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + sequence_len_offset: Optional[Tensor] = None, *, inference_params: Optional[BaseInferenceContext] = None, packed_seq_params: Optional[PackedSeqParams] = None, @@ -252,7 +293,9 @@ def forward( inference_context.max_seqlen = inference_context.max_sequence_length inference_context.seqlen_offset = inference_context.sequence_len_offset - if ( + if sequence_len_offset is not None: + pass + elif ( ( ( self.config.cuda_graph_impl == "local" @@ -301,14 +344,35 @@ def get_inner_quant_context(config, layer_number): with outer_fp8_context: for layer in self.layers: # Layers have 1-indexed layer numbers attribute. 
- inner_quant_context = get_inner_quant_context(self.config, layer.layer_number - 1) + if isinstance(layer, HybridStack): + inner_quant_context = nullcontext() + else: + inner_quant_context = get_inner_quant_context( + self.config, layer.layer_number - 1 + ) with inner_quant_context: - if isinstance(layer, TransformerLayer): + if isinstance(layer, HybridStack): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + sequence_len_offset=sequence_len_offset, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + elif isinstance(layer, TransformerLayer): hidden_states, _ = layer( hidden_states=hidden_states, attention_mask=attention_mask, inference_context=inference_context, rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, sequence_len_offset=sequence_len_offset, packed_seq_params=packed_seq_params, padding_mask=padding_mask, @@ -361,17 +425,46 @@ def sharded_state_dict( dict: The sharded state dictionary for the current object. """ + return self._sharded_state_dict( + prefix=prefix, + sharded_offsets=sharded_offsets, + metadata=metadata, + sharded_layer_prefix=None, + ) + + def _sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Optional[tuple] = None, + metadata: Optional[dict] = None, + sharded_layer_prefix: Optional[str] = None, + ) -> ShardedStateDict: + sharded_offsets = sharded_offsets or () sharded_state_dict = {} layer_prefix = f'{prefix}layers.' - - for local_layer_idx, layer in enumerate(self.layers): - - global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 - state_dict_prefix = ( - f'{layer_prefix}{local_layer_idx}.' 
# module list index in HybridStack + if sharded_layer_prefix is None: + sharded_layer_prefix = layer_prefix + + for local_layer_idx, (layer_type, layer) in enumerate(zip(self.layer_type_list, self.layers)): + state_dict_prefix = f'{layer_prefix}{local_layer_idx}.' + logical_layer_idx = ( + self.logical_layer_offset + if self.is_layer_group_stack + else self.logical_layer_offset + local_layer_idx ) - sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + if is_layer_group(layer_type): + sharded_state_dict.update( + layer._sharded_state_dict( + state_dict_prefix, + sharded_offsets, + metadata, + sharded_layer_prefix=sharded_layer_prefix, + ) + ) + continue + + sharded_prefix = f'{sharded_layer_prefix}{logical_layer_idx}.' sharded_pp_offset = [] layer_sharded_state_dict = layer.sharded_state_dict( @@ -385,15 +478,20 @@ def sharded_state_dict( # Add modules other than self.layers for name, module in self.named_children(): if not module is self.layers: - sharded_state_dict.update( - sharded_state_dict_default( - module, + module_sharded_state_dict = sharded_state_dict_default( + module, + f'{prefix}{name}.', + sharded_offsets, + metadata, + tp_group=self.tp_group, + ) + if name == "final_norm": + replace_prefix_for_sharding( + module_sharded_state_dict, f'{prefix}{name}.', - sharded_offsets, - metadata, - tp_group=self.tp_group, + f'{prefix}final_layernorm.', ) - ) + sharded_state_dict.update(module_sharded_state_dict) return sharded_state_dict diff --git a/megatron/core/models/hybrid/hybrid_layer_allocation.py b/megatron/core/models/hybrid/hybrid_layer_allocation.py index f1ba94ef7fa..57326301cc5 100644 --- a/megatron/core/models/hybrid/hybrid_layer_allocation.py +++ b/megatron/core/models/hybrid/hybrid_layer_allocation.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch @@ -22,6 +22,8 @@ class Symbols: MOE = 'E' PIPE = '|' 
MTP_SEPARATOR = "/" + GROUP_START = "[" + GROUP_END = "]" VALID_LAYERS = {MAMBA, GDN, ATTENTION, DS_ATTENTION, MLP, MOE} @classmethod @@ -37,6 +39,57 @@ def name_sorted_valid_layer_symbols(cls) -> list[str]: return [value for (_, value) in valid_layer_attrs] +LayerPatternItem = Union[str, Tuple[str, ...]] + + +def is_layer_group(layer_type: LayerPatternItem) -> bool: + """Return whether a parsed layer item is a bracketed group.""" + return isinstance(layer_type, tuple) + + +def flatten_layer_type_list(layer_type_list: List[LayerPatternItem]) -> List[str]: + """Flatten bracketed layer groups into their physical layer symbols.""" + flattened = [] + for layer_type in layer_type_list: + if is_layer_group(layer_type): + flattened.extend(layer_type) + else: + flattened.append(layer_type) + return flattened + + +def get_layer_type_physical_count(layer_type: LayerPatternItem) -> int: + """Return the number of physical layers represented by a parsed layer item.""" + return len(layer_type) if is_layer_group(layer_type) else 1 + + +def get_layer_type_logical_count(layer_type: LayerPatternItem) -> int: + """Return the number of logical layers represented by a parsed layer item.""" + return 1 + + +def get_layer_type_list_physical_count(layer_type_list: List[LayerPatternItem]) -> int: + """Return the number of physical layers represented by a parsed layer list.""" + return sum(get_layer_type_physical_count(layer_type) for layer_type in layer_type_list) + + +def get_layer_type_list_logical_count(layer_type_list: List[LayerPatternItem]) -> int: + """Return the number of logical layers represented by a parsed layer list.""" + return sum(get_layer_type_logical_count(layer_type) for layer_type in layer_type_list) + + +def layer_type_item_to_str(layer_type: LayerPatternItem) -> str: + """Render one parsed layer item back to pattern syntax.""" + if is_layer_group(layer_type): + return f"{Symbols.GROUP_START}{''.join(layer_type)}{Symbols.GROUP_END}" + return layer_type + + +def 
layer_type_list_to_str(layer_type_list: List[LayerPatternItem]) -> str: + """Render a parsed layer list back to pattern syntax.""" + return ''.join(layer_type_item_to_str(layer_type) for layer_type in layer_type_list) + + @dataclass class ParsedHybridPattern: """Result of parsing a unified hybrid pattern string. @@ -138,7 +191,10 @@ def get_hybrid_total_layer_count(pattern: str) -> int: """ main_pattern = pattern.split(Symbols.MTP_SEPARATOR)[0] _validate_pattern(main_pattern, "main", allow_pipe=True) - return len(main_pattern.replace(Symbols.PIPE, '')) + return sum( + get_layer_type_list_physical_count(validate_segment_layers(segment)) + for segment in main_pattern.split(Symbols.PIPE) + ) def get_hybrid_total_pipeline_segment_count(pattern: str) -> int: @@ -183,15 +239,14 @@ def get_hybrid_layer_counts(pattern: str) -> Dict[str, int]: # Count main decoder layers (skip '|' pipe separators) if parsed.main_pattern: - for char in parsed.main_pattern: - if char in counts: + for segment in parsed.main_pattern.split(Symbols.PIPE): + for char in flatten_layer_type_list(validate_segment_layers(segment)): counts[char] += 1 # Count MTP layers (pattern repeated mtp_num_depths times) if parsed.mtp_pattern and parsed.mtp_num_depths > 0: - for char in parsed.mtp_pattern: - if char in counts: - counts[char] += parsed.mtp_num_depths + for char in flatten_layer_type_list(validate_segment_layers(parsed.mtp_pattern)): + counts[char] += parsed.mtp_num_depths return counts @@ -284,20 +339,90 @@ def _validate_pattern(pattern: str, pattern_name: str, allow_pipe: bool = False) Raises: ValueError: If pattern contains invalid symbols """ - valid_chars = Symbols.VALID_LAYERS | {Symbols.PIPE} if allow_pipe else Symbols.VALID_LAYERS - for char in pattern: - if char not in valid_chars: + valid_chars = ( + Symbols.VALID_LAYERS + | {Symbols.GROUP_START, Symbols.GROUP_END} + | ({Symbols.PIPE} if allow_pipe else set()) + ) + if not allow_pipe and Symbols.PIPE in pattern: + raise ValueError( + f"In 
{pattern_name} pattern, '{Symbols.PIPE}' is not a valid layer symbol. " + f"Valid symbols are: {valid_chars}" + ) + flat_layers = [] + for segment in pattern.split(Symbols.PIPE): + flat_layers.extend( + flatten_layer_type_list( + _parse_segment_layers(segment, pattern_name, valid_chars=valid_chars) + ) + ) + + # Disallow Attention + MLA/DSA hybridity. + if Symbols.ATTENTION in flat_layers and Symbols.DS_ATTENTION in flat_layers: + raise ValueError("Not supported to have both Attention and MLA/DSA in one model") + + +def _parse_segment_layers( + segment: str, pattern_name: str, valid_chars: Optional[set[str]] = None +) -> List[LayerPatternItem]: + """Parse a pipe-free pattern segment into symbols and bracketed groups.""" + if valid_chars is None: + valid_chars = Symbols.VALID_LAYERS | {Symbols.GROUP_START, Symbols.GROUP_END} + + layer_type_list: List[LayerPatternItem] = [] + flat_layers = [] + i = 0 + while i < len(segment): + layer_char = segment[i] + if layer_char == Symbols.GROUP_START: + group_end = segment.find(Symbols.GROUP_END, i + 1) + if group_end == -1: + raise ValueError( + f"In {pattern_name} pattern, '[' starts a layer group without a matching ']'." + ) + group = segment[i + 1 : group_end] + if group == "": + raise ValueError(f"In {pattern_name} pattern, layer groups cannot be empty.") + if Symbols.GROUP_START in group or Symbols.GROUP_END in group: + raise ValueError( + f"In {pattern_name} pattern, nested layer groups are not supported." + ) + for group_char in group: + if group_char not in Symbols.VALID_LAYERS: + raise ValueError( + f"In {pattern_name} pattern, '{group_char}' is not a valid layer symbol. " + f"Valid symbols are: {valid_chars}" + ) + if Symbols.MOE in group[:-1]: + raise ValueError( + f"In {pattern_name} pattern, MoE layer '{Symbols.MOE}' must be the last " + f"symbol inside a layer group." 
+ ) + group_tuple = tuple(group) + layer_type_list.append(group_tuple) + flat_layers.extend(group_tuple) + i = group_end + 1 + continue + if layer_char == Symbols.GROUP_END: + raise ValueError( + f"In {pattern_name} pattern, ']' closes a layer group that was not opened." + ) + if layer_char not in Symbols.VALID_LAYERS: raise ValueError( - f"In {pattern_name} pattern, '{char}' is not a valid layer symbol. " + f"In {pattern_name} pattern, '{layer_char}' is not a valid layer symbol. " f"Valid symbols are: {valid_chars}" ) + layer_type_list.append(layer_char) + flat_layers.append(layer_char) + i += 1 - # Disallow Attention + MLA/DSA hybridity. - if Symbols.ATTENTION in pattern and Symbols.DS_ATTENTION in pattern: + if Symbols.ATTENTION in flat_layers and Symbols.DS_ATTENTION in flat_layers: raise ValueError("Not supported to have both Attention and MLA/DSA in one model") + return layer_type_list + -def validate_segment_layers(segment: str) -> List[str]: +def validate_segment_layers(segment: str) -> List[LayerPatternItem]: """Validate and convert a single pipeline segment pattern to a layer type list. This is used after the main pattern has been split by '|' into segments. @@ -312,19 +437,91 @@ def validate_segment_layers(segment: str) -> List[str]: Raises: ValueError: If segment contains invalid layer symbols. 
""" - layer_type_list = list(segment) - for layer_char in layer_type_list: - if layer_char not in Symbols.VALID_LAYERS: + return _parse_segment_layers(segment, "hybrid layer pattern segment") + + +def _slice_layer_type_list_by_physical_range( + layer_type_list: List[LayerPatternItem], offset: int, count: int +) -> List[LayerPatternItem]: + """Slice parsed layer items by physical layer range without splitting groups.""" + selected = [] + cursor = 0 + end = offset + count + for layer_type in layer_type_list: + item_count = get_layer_type_physical_count(layer_type) + item_end = cursor + item_count + if item_end <= offset: + cursor = item_end + continue + if cursor >= end: + break + if cursor < offset or item_end > end: raise ValueError( - f"In hybrid layer pattern segment, '{layer_char}' is not " - f"one of {Symbols.VALID_LAYERS}" + "Pipeline splitting would split a bracketed hybrid layer group. " + "Add pipe ('|') separators around bracketed groups to define valid boundaries." ) + selected.append(layer_type) + cursor = item_end + return selected + + +def _get_logical_offset_from_physical_offset( + layer_type_list: List[LayerPatternItem], offset: int +) -> int: + """Return the logical item count before a physical-layer offset.""" + logical_offset = 0 + cursor = 0 + for layer_type in layer_type_list: + item_count = get_layer_type_physical_count(layer_type) + item_end = cursor + item_count + if item_end <= offset: + logical_offset += get_layer_type_logical_count(layer_type) + cursor = item_end + continue + if cursor == offset: + return logical_offset + raise ValueError( + "Pipeline splitting would split a bracketed hybrid layer group. " + "Add pipe ('|') separators around bracketed groups to define valid boundaries." + ) + if cursor == offset: + return logical_offset + raise ValueError(f"Physical layer offset {offset} is out of range for hybrid layer pattern.") - # Disallow Attention + MLA/DSA hybridity. 
- if Symbols.ATTENTION in segment and Symbols.DS_ATTENTION in segment: - raise ValueError("Not supported to have both Attention and MLA/DSA in one model") - return layer_type_list +def select_pipeline_segment_with_logical_offset( + main_pattern: str, + pp_group: Optional[torch.distributed.ProcessGroup], + vp_stage: Optional[int], + first_stage_layers: Optional[int] = None, + last_stage_layers: Optional[int] = None, +) -> Tuple[List[LayerPatternItem], int, int]: + """Select a pipeline segment and return physical and logical offsets.""" + layer_type_list, layer_offset = select_pipeline_segment( + main_pattern, + pp_group, + vp_stage, + first_stage_layers=first_stage_layers, + last_stage_layers=last_stage_layers, + ) + + segments = main_pattern.split(Symbols.PIPE) if main_pattern else [''] + if len(segments) == 1: + full_layer_type_list = validate_segment_layers(segments[0]) + logical_layer_offset = _get_logical_offset_from_physical_offset( + full_layer_type_list, layer_offset + ) + else: + pp_rank = torch.distributed.get_rank(pp_group) if pp_group is not None else 0 + pp_size = torch.distributed.get_world_size(pp_group) if pp_group is not None else 1 + vp_rel = vp_stage if vp_stage is not None else 0 + segment_index = vp_rel * pp_size + pp_rank + logical_layer_offset = sum( + get_layer_type_list_logical_count(validate_segment_layers(segments[i])) + for i in range(segment_index) + ) + + return layer_type_list, layer_offset, logical_layer_offset def select_pipeline_segment( @@ -395,7 +592,7 @@ def select_pipeline_segment( ) full_pattern = segments[0] layer_type_list = validate_segment_layers(full_pattern) - num_layers = len(layer_type_list) + num_layers = get_layer_type_list_physical_count(layer_type_list) if first_stage_layers is not None or last_stage_layers is not None: first = first_stage_layers or 0 @@ -438,12 +635,13 @@ def select_pipeline_segment( offset = pp_rank * layers_per_rank count = layers_per_rank - selected = layer_type_list[offset : offset + count] + 
selected = _slice_layer_type_list_by_physical_range(layer_type_list, offset, count) log_on_each_pipeline_stage( logger, logging.INFO, f"HybridModel: pp_rank={pp_rank}/{pp_size}, vp_stage={vp_stage}, " - f"layers='{''.join(selected)}' ({len(selected)} layers), " + f"layers='{layer_type_list_to_str(selected)}' " + f"({get_layer_type_list_physical_count(selected)} layers), " f"layer_offset={offset} (auto-split)", ) return selected, offset @@ -467,7 +665,10 @@ def select_pipeline_segment( f"the current PP/VPP configuration." ) - layer_offset = sum(len(segments[i]) for i in range(segment_index)) + layer_offset = sum( + get_layer_type_list_physical_count(validate_segment_layers(segments[i])) + for i in range(segment_index) + ) my_segment = segments[segment_index] layer_type_list = validate_segment_layers(my_segment) @@ -477,21 +678,23 @@ def select_pipeline_segment( logging.INFO, f"HybridModel: pp_rank={pp_rank}/{pp_size}, vp_stage={vp_rel}, " f"segment_index={segment_index}/{len(segments)}, " - f"layers='{my_segment}' ({len(layer_type_list)} layers), " + f"layers='{my_segment}' ({get_layer_type_list_physical_count(layer_type_list)} layers), " f"layer_offset={layer_offset}", ) return layer_type_list, layer_offset -def get_layer_maps_from_layer_type_list(layer_type_list: list[str]) -> dict[str, dict[int, int]]: +def get_layer_maps_from_layer_type_list( + layer_type_list: list[LayerPatternItem], +) -> dict[str, dict[int, int]]: """ Returns maps from global layer index to the corresponding layer index for each valid layer type (those in Symbols.VALID_LAYERS) given a layer type list. 
""" layer_types = [symbol for symbol in Symbols.name_sorted_valid_layer_symbols()] layer_maps = {layer_type: {} for layer_type in layer_types} - for global_layer_idx, layer_type in enumerate(layer_type_list): + for global_layer_idx, layer_type in enumerate(flatten_layer_type_list(layer_type_list)): layer_map = layer_maps[layer_type] local_layer_idx = len(layer_map) layer_map[global_layer_idx] = local_layer_idx diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 4399c6984a7..3473a7a43e1 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -1,12 +1,13 @@ # Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved. import logging -from typing import Literal, Optional +from typing import Dict, Literal, Optional from torch import Tensor from megatron.core import tensor_parallel from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding @@ -179,19 +180,21 @@ def __init__( # determine the pipeline segment for this model instance. 
from megatron.core.models.hybrid.hybrid_layer_allocation import ( parse_hybrid_pattern, - select_pipeline_segment, + select_pipeline_segment_with_logical_offset, ) parsed = parse_hybrid_pattern(self.hybrid_layer_pattern) self.mtp_pattern = parsed.mtp_pattern self.mtp_num_depths = parsed.mtp_num_depths - layer_type_list, layer_offset = select_pipeline_segment( - parsed.main_pattern or '', - self.pg_collection.pp, - vp_stage, - first_stage_layers=self.config.num_layers_in_first_pipeline_stage, - last_stage_layers=self.config.num_layers_in_last_pipeline_stage, + layer_type_list, layer_offset, logical_layer_offset = ( + select_pipeline_segment_with_logical_offset( + parsed.main_pattern or '', + self.pg_collection.pp, + vp_stage, + first_stage_layers=self.config.num_layers_in_first_pipeline_stage, + last_stage_layers=self.config.num_layers_in_last_pipeline_stage, + ) ) # Determine if MTP is needed (based on pattern parsing) @@ -256,6 +259,7 @@ def __init__( pre_process=self.pre_process, layer_type_list=layer_type_list, pp_layer_offset=layer_offset, + logical_layer_offset=logical_layer_offset, post_process=self.post_process, dtype=config.params_dtype, pg_collection=self.pg_collection, @@ -324,6 +328,77 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' self.decoder.set_input_tensor(input_tensor[0]) + def _preprocess( + self, + input_ids: Tensor, + position_ids: Tensor, + decoder_input: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + padding_mask: Optional[Tensor] = None, + ): + """Preprocess inputs for HybridStack or combined-1F1B scheduling.""" + in_inference_mode = inference_context is not None and not self.training + + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + + if ( + in_inference_mode + and 
inference_context.is_dynamic_batching() + and is_using_quantization_scales(self.config) + ): + decoder_input[inference_context.padding_slice] = 0.0 + else: + decoder_input = None + + rotary_pos_emb = None + if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', + ) + elif self.position_embedding_type == 'yarn': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb, _ = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', + ) + + if ( + in_inference_mode + and ( + ( + self.config.cuda_graph_impl == "local" + and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope + ) + or self.config.flash_decode + ) + and inference_context.is_static_batching() + ): + current_batch_size = input_ids.shape[0] + import torch + + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device='cuda', + ) + else: + sequence_len_offset = None + + if in_inference_mode: + decoder_input = WrappedTensor(decoder_input) + + return decoder_input, rotary_pos_emb, None, None, sequence_len_offset, padding_mask + def preprocess_for_fine_grained_offloading(self): """Preprocess for fine-grained activation offloading.""" off_interface.init_chunk_handler( @@ -342,6 +417,163 @@ def preprocess_for_fine_grained_offloading(self): off_interface.mark_not_offloadable(param) self.disable_param_offloading = False + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos=None, + 
rotary_pos_sin=None, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + is_spec_decode=None, + ): + """Postprocess HybridStack hidden states into logits or language-model loss.""" + in_inference_mode = inference_context is not None and not self.training + if in_inference_mode: + assert runtime_gather_output, "Inference must always gather TP logits" + + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if is_spec_decode is None: + is_spec_decode = ( + in_inference_mode + and inference_context.is_dynamic_batching() + and inference_context.num_speculative_tokens > 0 + ) + + if mtp_in_postprocess and not (in_inference_mode or is_spec_decode): + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + embedding=self.embedding, + ) + + if not self.post_process: + return hidden_states + + if self.config.mtp_num_layers is not None and self.mtp_process: + assert self.config.mtp_num_layers > 0 + if in_inference_mode or is_spec_decode: + self._decoder_hidden_states_cache = hidden_states + else: + hidden_states = process_mtp_loss( + hidden_states=hidden_states, + labels=labels, + loss_mask=loss_mask, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + is_training=self.training, + compute_language_model_loss=self.compute_language_model_loss, + config=self.config, + cp_group=self.pg_collection.cp, + packed_seq_params=packed_seq_params, + scale_logits_fn=self._scale_logits if self.config.use_mup else None, + ) + + sequence_parallel_override = False 
+ if in_inference_mode and inference_context.config.materialize_only_last_token_logits: + if inference_context.is_static_batching(): + hidden_states = hidden_states[-1:, :, :] + else: + if self.output_layer.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.pg_collection.tp + ) + self.output_layer.sequence_parallel = False + sequence_parallel_override = True + + reshaped = hidden_states.squeeze(1).unsqueeze(0) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) + + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + logits = self._scale_logits(logits) + + if sequence_parallel_override: + assert ( + in_inference_mode + and inference_context.is_dynamic_batching() + and inference_context.config.materialize_only_last_token_logits + ) + self.output_layer.sequence_parallel = True + + if labels is None: + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + return loss + + def build_schedule_plan( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + ): + """Build the HybridModel combined-1F1B schedule plan.""" + if self.config.fine_grained_activation_offloading: + self.preprocess_for_fine_grained_offloading() + + from ..common.model_chunk_schedule_plan import HybridStackModelChunkSchedulePlan + + return HybridStackModelChunkSchedulePlan( + self, + input_ids, + position_ids, + attention_mask, + decoder_input, + labels, + packed_seq_params, + extra_block_kwargs, + runtime_gather_output, 
+ loss_mask, + padding_mask, + ) + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Return a Transformer-compatible sharded state dict for HybridModel.""" + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + output_layer_extra_state_key = f'{prefix}output_layer._extra_state' + + # Match GPTModel checkpoint compatibility: old GPT checkpoints do not include + # output layer extra state, and the TE extra state should be empty. + output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None) + assert not ( + output_extra_state and output_extra_state.data + ), f'Expected output layer extra state to be empty, got: {output_extra_state}' + + return sharded_state_dict + def _should_call_local_cudagraph(self, *args, **kwargs): """ Check if we should call the local cudagraph path. diff --git a/megatron/core/pipeline_parallel/combined_1f1b.py b/megatron/core/pipeline_parallel/combined_1f1b.py index f4f222ad2a1..dea54dd8749 100644 --- a/megatron/core/pipeline_parallel/combined_1f1b.py +++ b/megatron/core/pipeline_parallel/combined_1f1b.py @@ -343,11 +343,9 @@ def forward_backward_step(): unwrapped_model = get_attr_wrapped_model( f_model, "build_schedule_plan", return_model_obj=True ) - from megatron.core.models.gpt.gpt_model import GPTModel - - assert isinstance(unwrapped_model, GPTModel), ( - "The final unwrapped model must be a GPTModel instance " - "since only GPTModel is supported for EP A2A overlapping." + assert hasattr(unwrapped_model, "build_schedule_plan"), ( + "The final unwrapped model must implement build_schedule_plan " + "to support EP A2A overlapping." 
) f_schedule_plan, loss_func = forward_step_func( data_iterator, unwrapped_model, return_schedule_plan=True diff --git a/pretrain_hybrid.py b/pretrain_hybrid.py index 9ab52ed11ab..4f3c16d6ed3 100644 --- a/pretrain_hybrid.py +++ b/pretrain_hybrid.py @@ -209,12 +209,13 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optio return loss, num_tokens, report -def forward_step(data_iterator, model: HybridModel): +def forward_step(data_iterator, model: HybridModel, return_schedule_plan: bool = False): """Forward training step. Args: data_iterator : Input data iterator model (HybridModel): The Model + return_schedule_plan (bool): Whether to return the schedule plan instead of output tensor. """ timers = get_timers() @@ -253,14 +254,28 @@ def forward_step(data_iterator, model: HybridModel): timers('batch-generator').stop() with stimer: - output_tensor = model( - tokens, - position_ids, - attention_mask, - labels=labels, - packed_seq_params=packed_seq_params, - loss_mask=loss_mask - ) + if return_schedule_plan: + args = get_args() + assert args.overlap_moe_expert_parallel_comm, ( + "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + ) + output_tensor = model.build_schedule_plan( + tokens, + position_ids, + attention_mask, + labels=labels, + packed_seq_params=packed_seq_params, + loss_mask=loss_mask, + ) + else: + output_tensor = model( + tokens, + position_ids, + attention_mask, + labels=labels, + packed_seq_params=packed_seq_params, + loss_mask=loss_mask + ) # [ModelOpt]: model is needed to access ModelOpt distillation losses return output_tensor, partial(loss_func, loss_mask, model=model) diff --git a/tests/unit_tests/models/test_hybrid_model.py b/tests/unit_tests/models/test_hybrid_model.py index 98a53da0314..902942f7141 100644 --- a/tests/unit_tests/models/test_hybrid_model.py +++ b/tests/unit_tests/models/test_hybrid_model.py @@ -198,6 +198,32 @@ def test_save_load(self, tmp_path): 
self.model.load_state_dict(torch.load(path)) + def test_grouped_sharded_state_dict_uses_transformer_checkpoint_keys(self): + """Grouped HybridModel checkpoints should be load-compatible with GPTModel keys.""" + model_config = TransformerConfig( + num_layers=2, + hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + ) + model = HybridModel( + config=model_config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=100, + max_sequence_length=4, + hybrid_layer_pattern="[*-]", + ) + + sharded_state_dict = model.sharded_state_dict() + sharded_keys = {value.key for value in sharded_state_dict.values() if hasattr(value, "key")} + + assert "decoder.layers.0.self_attention.linear_qkv.weight" in sharded_keys + assert "decoder.layers.0.mlp.linear_fc1.weight" in sharded_keys + assert "decoder.layers.1.mlp.linear_fc1.weight" not in sharded_keys + assert "decoder.final_layernorm.weight" in sharded_keys + assert "decoder.final_norm.weight" not in sharded_keys + assert "output_layer._extra_state" not in sharded_state_dict + def test_layer_numbers(self): """ The layer numbers should start at one (for the embedding # layer) and go up diff --git a/tests/unit_tests/ssm/test_hybrid_block.py b/tests/unit_tests/ssm/test_hybrid_block.py index 08bf7f2bc28..efbe59cf9f3 100644 --- a/tests/unit_tests/ssm/test_hybrid_block.py +++ b/tests/unit_tests/ssm/test_hybrid_block.py @@ -3,8 +3,13 @@ import pytest import torch +from megatron.core.models.hybrid.fine_grained_callables import build_hybrid_stack_callables from megatron.core.models.hybrid.hybrid_block import HybridStack -from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols, validate_segment_layers +from megatron.core.models.hybrid.hybrid_layer_allocation import ( + Symbols, + get_layer_type_list_physical_count, + validate_segment_layers, +) from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec from megatron.core.process_groups_config import ProcessGroupCollection from 
megatron.core.ssm.gated_delta_net import GatedDeltaNet @@ -27,8 +32,10 @@ def setup_method(self, method): Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - def get_pg_collection(self): - return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) + def get_pg_collection(self, required_pgs=None): + if required_pgs is None: + required_pgs = ['tp', 'pp', 'cp'] + return ProcessGroupCollection.use_mpu_process_groups(required_pgs=required_pgs) def get_mamba_block(self, layer_pattern): layer_type_list = validate_segment_layers(layer_pattern) @@ -81,6 +88,60 @@ def get_dsa_mamba_block(self, layer_pattern): pg_collection=self.get_pg_collection(), ) + def get_attention_mlp_block(self, layer_pattern): + layer_type_list = validate_segment_layers(layer_pattern) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + hidden_dropout=0.0, + attention_dropout=0.0, + use_cpu_initialization=True, + ) + return HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + + def get_attention_moe_block(self, layer_pattern): + layer_type_list = validate_segment_layers(layer_pattern) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + ffn_hidden_size=256, + num_moe_experts=8, + expert_model_parallel_size=1, + moe_router_topk=2, + moe_grouped_gemm=True, + moe_token_dispatcher_type="alltoall", + hidden_dropout=0.0, + attention_dropout=0.0, + use_cpu_initialization=True, + ) + return HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection( + required_pgs=[ + 'tp', + 'pp', + 'cp', + 'tp_cp', + 'tp_dp_cp', + 'ep', + 'expt_tp', + 
'tp_ep', + 'expt_dp', + ] + ), + ) + def teardown_method(self, method): Utils.destroy_model_parallel() @@ -118,6 +179,107 @@ def test_layer_types(self): assert isinstance(layers[2], TransformerLayer) assert isinstance(layers[2].mlp, MLP) + def test_group_layer_type_builds_nested_hybrid_stack(self): + """Bracketed groups build an inner HybridStack with physical layer numbering.""" + layer_type_list = validate_segment_layers("M[M*]-") + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + use_cpu_initialization=True, + ) + block = HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + assert isinstance(block.layers[0], MambaLayer) + assert isinstance(block.layers[1], HybridStack) + assert isinstance(block.layers[1].layers[0], MambaLayer) + assert isinstance(block.layers[1].layers[1], TransformerLayer) + assert isinstance(block.layers[2], TransformerLayer) + assert [layer.layer_number for layer in block.layers[1].layers] == [2, 3] + assert block.layers[2].layer_number == 4 + + def test_group_sharded_state_dict_uses_logical_layer_keys(self): + """Grouped attention+MLP layers share one Transformer-compatible checkpoint key.""" + layer_type_list = validate_segment_layers("[*-]") + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + use_cpu_initialization=True, + ) + block = HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + logical_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + + sharded_state_dict = block.sharded_state_dict(prefix="decoder.") + sharded_keys = {value.key for value in sharded_state_dict.values() if hasattr(value, "key")} + + assert 
"decoder.layers.0.self_attention.linear_qkv.weight" in sharded_keys + assert "decoder.layers.0.mlp.linear_fc1.weight" in sharded_keys + assert "decoder.layers.1.mlp.linear_fc1.weight" not in sharded_keys + assert "decoder.final_layernorm.weight" in sharded_keys + assert "decoder.final_norm.weight" not in sharded_keys + + def test_group_forward_matches_equivalent_flat_layers(self): + """A bracket group is only a scheduling/checkpoint boundary, not new math.""" + flat_block = self.get_attention_mlp_block("*-") + group_block = self.get_attention_mlp_block("[*-]") + + group_block.layers[0].layers[0].load_state_dict(flat_block.layers[0].state_dict()) + group_block.layers[0].layers[1].load_state_dict(flat_block.layers[1].state_dict()) + group_block.final_norm.load_state_dict(flat_block.final_norm.state_dict()) + + flat_block.cuda().eval() + group_block.cuda().eval() + sequence_length = 16 + micro_batch_size = 2 + hidden_states = torch.randn( + sequence_length, + micro_batch_size, + flat_block.config.hidden_size, + device="cuda", + ) + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), + dtype=bool, + device="cuda", + ) + + with torch.no_grad(): + flat_output = flat_block(hidden_states.clone(), attention_mask=attention_mask) + group_output = group_block(hidden_states.clone(), attention_mask=attention_mask) + + torch.testing.assert_close(group_output, flat_output, rtol=0, atol=0) + + def test_group_overlap_callables_keep_ep_moe_split_visible(self): + """EP-overlap scheduling still sees dispatch/experts/combine inside a group.""" + block = self.get_attention_moe_block("[*E]") + + forward_callables, bwd_dw_callable_map, is_moe, num_local_experts = ( + build_hybrid_stack_callables(block.layers[0], layer_type=block.layer_type_list[0]) + ) + + pre_dispatch, dispatch, experts, combine, mtp_post_process = forward_callables + assert callable(pre_dispatch) + assert callable(dispatch) + assert callable(experts) + assert callable(combine) + assert 
mtp_post_process is None + assert is_moe + assert num_local_experts == 8 + assert "attn" in bwd_dw_callable_map + assert "mlp" in bwd_dw_callable_map + def test_invalid_layer_types_cause_failure(self): invalid_symbol = '+' assert invalid_symbol not in Symbols.VALID_LAYERS # sanity check. diff --git a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py index fe0d7c2dc1e..896fb0bddb4 100644 --- a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py +++ b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py @@ -15,6 +15,7 @@ parse_hybrid_pattern, pattern_from_ratios, select_pipeline_segment, + select_pipeline_segment_with_logical_offset, validate_segment_layers, ) @@ -71,6 +72,8 @@ def test_valid_patterns(self): """Test that valid segment patterns produce the correct layer type lists.""" test_cases = [ ("M*-M*-M*-", ['M', '*', '-', 'M', '*', '-', 'M', '*', '-']), + ("M[M*]-", ['M', ('M', '*'), '-']), + ("[M*E]", [('M', '*', 'E')]), ("MMMMMMMMM", ['M'] * 9), ("MM*-MM*-", ['M', 'M', '*', '-', 'M', 'M', '*', '-']), ("E", ['E']), @@ -98,6 +101,10 @@ def test_invalid_symbols_cause_failure(self): validate_segment_layers("M|M") # pipe not valid in a segment with pytest.raises(ValueError): validate_segment_layers("M/M") # MTP separator not valid in a segment + with pytest.raises(ValueError): + validate_segment_layers("M[[M]]") # nested groups are not valid + with pytest.raises(ValueError): + validate_segment_layers("M[EM]") # MoE must be last in a group with pytest.raises(ValueError): # Not allowed to have both standard Attention and MLA/DSA validate_segment_layers("MDM*-") @@ -110,6 +117,8 @@ def test_simple_patterns(self): assert get_hybrid_total_layer_count("M*M*") == 4 assert get_hybrid_total_layer_count("MMMM") == 4 assert get_hybrid_total_layer_count("M") == 1 + assert get_hybrid_total_layer_count("[M*E]") == 3 + assert get_hybrid_total_layer_count("M[M*]-") == 4 def test_with_pipe_separators(self): assert 
get_hybrid_total_layer_count("M-M-|M-M*-") == 9 @@ -155,6 +164,8 @@ def test_main_pattern_only(self): """Test patterns without MTP (no / separator).""" test_cases = [ ("M*M*", "M*M*"), + ("[M*E]", "[M*E]"), + ("M[M*]-", "M[M*]-"), ("MMMM", "MMMM"), ("*M*M", "*M*M"), ("MM-*", "MM-*"), @@ -231,11 +242,22 @@ def test_invalid_symbols_in_main_pattern(self): "M*X*", # X is not valid "MaMM", # a is not valid "M*M*1", # 1 is not valid + "M[M*]X", # X is not valid after a group ] for pattern in invalid_patterns: with pytest.raises(ValueError, match="not a valid layer symbol"): parse_hybrid_pattern(pattern) + def test_invalid_group_syntax(self): + with pytest.raises(ValueError, match="without a matching"): + parse_hybrid_pattern("M[M*") + with pytest.raises(ValueError, match="not supported"): + parse_hybrid_pattern("M[M[*]]") + with pytest.raises(ValueError, match="cannot be empty"): + parse_hybrid_pattern("M[]") + with pytest.raises(ValueError, match="must be the last"): + parse_hybrid_pattern("M[EM]") + def test_invalid_symbols_in_mtp_pattern(self): """Test that invalid symbols in MTP pattern raise ValueError.""" # Single MTP depth with invalid symbol - should raise "not a valid layer symbol" @@ -350,6 +372,16 @@ def test_with_pipes_and_mtp(self): def test_moe_pattern(self): assert get_hybrid_layer_counts("MEME") == {'*': 0, 'D': 0, 'G': 0, 'M': 2, '-': 0, 'E': 2} + def test_group_pattern(self): + assert get_hybrid_layer_counts("M[M*]E") == { + '*': 1, + 'D': 0, + 'G': 0, + 'M': 2, + '-': 0, + 'E': 1, + } + def test_mtp_with_attention(self): # MTP pattern "*M" repeated 3 depths -> 3 attn + 3 mamba from MTP assert get_hybrid_layer_counts("MMMM/*M/*M/*M") == { @@ -414,6 +446,21 @@ def test_four_segments(self, mock_log): assert layer_types == expected_layers, f"Failed for vp_stage={vp_stage}" assert offset == expected_offset, f"Failed for vp_stage={vp_stage}" + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') + def 
test_group_segment_offsets(self, mock_log): + layer_types, offset = select_pipeline_segment("[M*E]|M-", pp_group=None, vp_stage=1) + assert layer_types == ['M', '-'] + assert offset == 3 + + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_group_segment_logical_offsets(self, mock_log): + layer_types, physical_offset, logical_offset = select_pipeline_segment_with_logical_offset( + "[*-][*-]|[*E][*E]", pp_group=None, vp_stage=1 + ) + assert layer_types == [('*', 'E'), ('*', 'E')] + assert physical_offset == 4 + assert logical_offset == 2 + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') def test_empty_segment(self, mock_log): """Empty segments are allowed for pipeline balancing.""" @@ -688,3 +735,12 @@ def test_all_mamba(self): assert mamba_map == {0: 0, 1: 1, 2: 2} assert mlp_map == {} assert moe_map == {} + + def test_grouped_layers_are_flattened(self): + maps = get_layer_maps_from_layer_type_list([("M", "*", "E"), "M"]) + attention_map, mamba_map, moe_map = operator.itemgetter( + Symbols.ATTENTION, Symbols.MAMBA, Symbols.MOE + )(maps) + assert attention_map == {1: 0} + assert mamba_map == {0: 0, 3: 1} + assert moe_map == {2: 0} From d2daf292593077f13ca8910f736ee03c8cfbc49e Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 6 May 2026 09:16:03 -0700 Subject: [PATCH 02/16] Cleanup hybrid EP-overlap callables and HybridStack final-norm path hybrid/fine_grained_callables.py: remove _SharedExpertBackwardDWWrapper. The wrapper only existed to call mlp.backward_dw(routed_experts=False, shared_experts=True) under a separate scheduling slot; the same wgrad work runs correctly when MoELayer.shared_experts is registered as a sibling callable in backward_dw["mlp"], which the schedule node already iterates. Skip the shared callable when shared_expert_overlap is enabled (its wgrad is folded into the dispatcher overlap path). 
gpt/fine_grained_callables.py + hybrid/hybrid_block.py: revert the final_layernorm/final_norm dual lookup in PostProcessNode; instead expose final_layernorm as a property on HybridStack that returns final_norm. GPT's PostProcessNode now finds the final norm under the same attribute name it always used, without GPT-side changes. The registered submodule stays final_norm so existing hybrid checkpoint keys are unchanged. hybrid/fine_grained_callables.py: inline _apply_mamba_layer and _apply_attention_layer at their single call sites in pre_dispatch_computation. The attention case still uses item_layer._forward_attention(...) rather than item_layer(...) because attention half-layers have mlp=IdentityOp and mlp_bda=IdentityFuncOp, so __call__ would route through _forward_mlp + mlp_bda and double-apply the post-attention residual. Comment added at the call site. hybrid/hybrid_layer_allocation.py: reject bracketed groups inside MTP patterns. Each MTP depth is itself a fused unit, so wrapping its symbols in '[...]' has no defined meaning and breaks downstream construction. Raises ValueError with a clear message. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/models/gpt/fine_grained_callables.py | 3 +- .../models/hybrid/fine_grained_callables.py | 92 +++++++------------ megatron/core/models/hybrid/hybrid_block.py | 15 ++- .../models/hybrid/hybrid_layer_allocation.py | 10 ++ 4 files changed, 59 insertions(+), 61 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 03d963c481b..fa2a2ec4934 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -204,8 +204,7 @@ def forward_impl(self, hidden_states): """ empty_decoder = len(self.gpt_model.decoder.layers) == 0 - layer_norm = getattr(self.gpt_model.decoder, "final_layernorm", None) - layer_norm = layer_norm or getattr(self.gpt_model.decoder, "final_norm", None) + layer_norm = self.gpt_model.decoder.final_layernorm if not self.gpt_model.config.mtp_num_layers and empty_decoder and layer_norm: hidden_states = layer_norm(hidden_states) hidden_states = make_viewless_tensor( diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index b94111e3f36..79c66d87fe1 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -1,7 +1,6 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
from contextlib import nullcontext -from functools import partial from typing import Optional import torch @@ -11,13 +10,11 @@ from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context from megatron.core.models.hybrid.hybrid_block import HybridStack -from megatron.core.models.hybrid.hybrid_layer_allocation import ( - LayerPatternItem, - Symbols as LayerSymbols, - is_layer_group, -) +from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem +from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.models.hybrid.hybrid_layer_allocation import is_layer_group from megatron.core.pipeline_parallel.utils import ScheduleNode -from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor +from megatron.core.transformer.transformer_layer import make_viewless_tensor def _get_inner_quant_context(layer): @@ -37,32 +34,6 @@ def _as_hybrid_layers(layer, layer_type: Optional[LayerPatternItem]): return [(layer_type, layer)] -def _apply_attention_layer( - layer: TransformerLayer, - node: ScheduleNode, - hidden_states: Tensor, -): - hidden_states, _ = layer._forward_attention( - hidden_states=hidden_states, - attention_mask=node.chunk_state.attention_mask, - rotary_pos_emb=node.chunk_state.rotary_pos_emb, - rotary_pos_cos=node.chunk_state.rotary_pos_cos, - rotary_pos_sin=node.chunk_state.rotary_pos_sin, - packed_seq_params=node.chunk_state.packed_seq_params, - sequence_len_offset=node.chunk_state.sequence_len_offset, - ) - return hidden_states - - -def _apply_mamba_layer(layer, node: ScheduleNode, hidden_states: Tensor): - return layer( - hidden_states=hidden_states, - attention_mask=node.chunk_state.attention_mask, - inference_context=getattr(node.chunk_state, "inference_context", None), - packed_seq_params=node.chunk_state.packed_seq_params, - ) - - def _maybe_apply_final_norm(node: ScheduleNode, hidden_states: Tensor): final_norm 
= getattr(node.chunk_state.model.decoder, "final_norm", None) final_norm = final_norm or getattr(node.chunk_state.model.decoder, "final_layernorm", None) @@ -82,24 +53,6 @@ def _get_moe_padding_mask(node: ScheduleNode): return padding_mask -class _SharedExpertBackwardDWWrapper: - """Backward weight-gradient wrapper for MoE-only hybrid terminal layers.""" - - def __init__(self, layer): - self.layer = layer - self.shared_expert_dw_callable = None - if layer.mlp.use_shared_expert: - self.shared_expert_dw_callable = partial( - layer.mlp.backward_dw, routed_experts=False, shared_experts=True - ) - - def backward_dw(self): - if self.shared_expert_dw_callable is not None: - self.shared_expert_dw_callable() - self.layer = None - self.shared_expert_dw_callable = None - - def _run_moe_preprocess(layer, node: ScheduleNode, hidden_states: Tensor): pre_mlp_layernorm_output = layer._forward_pre_mlp_layernorm(hidden_states) if isinstance(pre_mlp_layernorm_output, tuple): @@ -199,13 +152,30 @@ def pre_dispatch_computation(node: ScheduleNode, hidden_states: Tensor): for item_type, item_layer in pre_layers: with _get_inner_quant_context(item_layer): if item_type == LayerSymbols.MAMBA: - hidden_states = _apply_mamba_layer(item_layer, node, hidden_states) + hidden_states = item_layer( + hidden_states=hidden_states, + attention_mask=node.chunk_state.attention_mask, + inference_context=getattr(node.chunk_state, "inference_context", None), + packed_seq_params=node.chunk_state.packed_seq_params, + ) elif item_type in ( LayerSymbols.ATTENTION, LayerSymbols.DS_ATTENTION, LayerSymbols.GDN, ): - hidden_states = _apply_attention_layer(item_layer, node, hidden_states) + # Use _forward_attention rather than __call__: an attention half-layer has + # mlp=IdentityOp / mlp_bda=IdentityFuncOp by default, and TransformerLayer's + # __call__ would route through _forward_mlp + mlp_bda, double-applying the + # post-attention residual. 
+ hidden_states, _ = item_layer._forward_attention( + hidden_states=hidden_states, + attention_mask=node.chunk_state.attention_mask, + rotary_pos_emb=node.chunk_state.rotary_pos_emb, + rotary_pos_cos=node.chunk_state.rotary_pos_cos, + rotary_pos_sin=node.chunk_state.rotary_pos_sin, + packed_seq_params=node.chunk_state.packed_seq_params, + sequence_len_offset=node.chunk_state.sequence_len_offset, + ) else: raise ValueError( f"HybridStack overlap does not support layer type '{item_type}' before " @@ -268,10 +238,18 @@ def raise_not_implemented(*args): item_layer.init_backward_dw_wrapper() pre_bwd_dw.append(item_layer.backward_dw_wrapper) if is_moe: - shared_expert_dw = _SharedExpertBackwardDWWrapper(terminal_layer) - if shared_expert_dw.shared_expert_dw_callable is not None: - pre_bwd_dw.append(shared_expert_dw) - backward_dw["mlp"] = terminal_layer.mlp + # MoELayer.backward_dw default kwargs (routed_experts=True, shared_experts=False) handle + # the routed-experts wgrad. The shared-experts wgrad is registered as a sibling callable + # under "mlp" so the schedule node iterates both. Skip registering the shared-experts + # callable when shared_expert_overlap is enabled — in that case the shared-experts + # forward and backward are folded into the dispatcher's overlap handling. 
+ mlp_backward_callables = [terminal_layer.mlp] + if ( + terminal_layer.mlp.use_shared_expert + and not terminal_layer.mlp.shared_expert_overlap + ): + mlp_backward_callables.append(terminal_layer.mlp.shared_experts) + backward_dw["mlp"] = mlp_backward_callables elif terminal_type == LayerSymbols.MLP: backward_dw["mlp"] = terminal_layer.mlp diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index d9045aed0b3..1c0209da0c1 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -19,9 +19,9 @@ from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem +from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols from megatron.core.models.hybrid.hybrid_layer_allocation import ( - LayerPatternItem, - Symbols as LayerSymbols, get_layer_type_physical_count, is_layer_group, ) @@ -219,6 +219,17 @@ def __init__( eps=self.config.layernorm_epsilon, ) + @property + def final_layernorm(self): + """Alias for ``final_norm`` matching the attribute name on TransformerBlock. + + Lets generic decoder consumers (e.g. ``GPTModel.PostProcessNode``) discover the + final norm via the same attribute name they use for non-hybrid decoders, while + keeping ``final_norm`` as the registered submodule so existing hybrid checkpoint + keys are unchanged. + """ + return getattr(self, "final_norm", None) + def set_input_tensor(self, input_tensor: Tensor): """Set input tensor to be used instead of forward()'s input. 
diff --git a/megatron/core/models/hybrid/hybrid_layer_allocation.py b/megatron/core/models/hybrid/hybrid_layer_allocation.py index 57326301cc5..fdd51a94459 100644 --- a/megatron/core/models/hybrid/hybrid_layer_allocation.py +++ b/megatron/core/models/hybrid/hybrid_layer_allocation.py @@ -321,6 +321,16 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: _validate_pattern(mtp_pattern, "MTP", allow_pipe=False) + # MTP layers are themselves a fused unit (each MTP depth contains its own attention + # + MLP), so it does not make sense to wrap them in a HybridStack group. Reject + # bracketed groups inside MTP patterns to keep downstream construction simple. + if Symbols.GROUP_START in mtp_pattern or Symbols.GROUP_END in mtp_pattern: + raise ValueError( + f"In MTP pattern, layer groups '{Symbols.GROUP_START}...{Symbols.GROUP_END}' " + f"are not supported because each MTP depth is already a fused unit. " + f"Got MTP pattern: '{mtp_pattern}'." + ) + return ParsedHybridPattern( main_pattern=main_pattern if main_pattern else None, mtp_pattern=mtp_pattern, From f162e1683fc81e1f24bf4dfff155c32a498cefb1 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 6 May 2026 11:37:57 -0700 Subject: [PATCH 03/16] Move hybrid schedule plan into hybrid/, share base via subclassing Restore TransformerLayerSchedulePlan and TransformerModelChunkSchedulePlan in core/models/common/model_chunk_schedule_plan.py to their pre-PR shape (GPT/MTP only, no layer_type concept) and drop the backward-compat aliases. The hybrid plan now lives in core/models/hybrid/model_chunk_schedule_plan.py as HybridStackSchedulePlan / HybridStackModelChunkSchedulePlan, subclassed from the GPT classes. To make the subclass small, the GPT base classes grow three extension points: - LAYER_SCHEDULE_PLAN_CLASS picks the per-layer plan class. - PRE_PROCESS_NODE_CLASS / POST_PROCESS_NODE_CLASS pick the embedding / output-layer node classes. 
- _extra_args_for_layer is the hook subclasses override to thread per-layer metadata into the layer plan constructor. Defaults fall back to the GPT classes so non-hybrid callers see no behavior change. HybridModel.forward now delegates to the existing _preprocess / _postprocess methods instead of inlining the embedding / rotary / output-layer / loss code, so the eager forward and the EP-overlap PreProcessNode / PostProcessNode read the same chunk_state slots. HybridStackNode (in hybrid/fine_grained_callables.py) is the schedule node class used for hybrid layer plans. Subclassed from TransformerLayerNode so the runtime backbone is shared, but the free-input policy is now resolved through a method (_resolve_free_input) that subclasses override. The hybrid class currently delegates to should_free_input but exists so any hybrid-specific policy can be added surgically without touching GPT. HybridPreProcessNode and HybridPostProcessNode (in hybrid/model_chunk_schedule_plan.py) mirror the GPT counterparts but take a HybridModel rather than a GPTModel, so the hybrid schedule plan does not have to pass a HybridModel under a GPT-named attribute. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../common/model_chunk_schedule_plan.py | 104 ++++---- .../core/models/gpt/fine_grained_callables.py | 11 +- .../models/hybrid/fine_grained_callables.py | 32 +++ megatron/core/models/hybrid/hybrid_model.py | 174 +++----------- .../hybrid/model_chunk_schedule_plan.py | 224 ++++++++++++++++++ 5 files changed, 358 insertions(+), 187 deletions(-) create mode 100644 megatron/core/models/hybrid/model_chunk_schedule_plan.py diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 3d3591ea6dd..0430987fcc5 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -27,8 +27,8 @@ class ModelChunkState: pass -class HybridStackSchedulePlan: - """Schedule the executing plan of nodes in a transformer, MTP, or hybrid layer. +class TransformerLayerSchedulePlan: + """Schedule the executing plan of the nodes in a transformer/mtp layer. This class organizes the sub-modules of a transformer/mtp layer, including attention, post attention, MLP, dispatch, combine and @@ -55,7 +55,7 @@ class HybridStackSchedulePlan: mtp_post_process = None def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_args={}): - """Initializes a layer schedule plan. + """Initializes a transformer layer schedule plan. 
Args: layer (TransformerLayer): @@ -76,12 +76,11 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar self.layer_state = TransformerLayerState() self.chunk_state = chunk_state self.layer = layer - self.layer_type = extra_args.get("layer_type", None) self.event = event self.comp_stream = comp_stream self.comm_stream = comm_stream - # get callable nodes for transformer/mtp/hybrid layer + # get callable nodes for transformer/mtp layer self._build_callable_nodes(event, comp_stream, comm_stream, extra_args) def release_state(self): @@ -109,35 +108,24 @@ def release_state(self): def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): """ - Builds the callable nodes for the transformer/mtp/hybrid layer: + Builds the callable nodes for the transformer/mtp layer: attn, mlp, moe_dispatch and moe_combine, and mtp_post_process. """ from megatron.core.models.gpt.fine_grained_callables import ( TransformerLayerNode, build_layer_callables, ) - from megatron.core.models.hybrid.fine_grained_callables import ( - build_hybrid_stack_callables, - ) - from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer - layer_type: LayerPatternItem = extra_args.get("layer_type", None) - if layer_type is None: - # build the forward and backward callables for the transformer/mtp layer - fwd_callables, bwd_dw_callable_map = build_layer_callables(self.layer) + # build the forward and backward callables for the transformer/mtp layer + fwd_callables, bwd_dw_callable_map = build_layer_callables(self.layer) - # get flags for later use - is_mtp = isinstance(self.layer, MultiTokenPredictionLayer) - transformer_layer = self.layer.mtp_model_layer if is_mtp else self.layer - is_moe = isinstance(transformer_layer.mlp, MoELayer) - num_local_experts = transformer_layer.mlp.num_local_experts 
if is_moe else None - else: - fwd_callables, bwd_dw_callable_map, is_moe, num_local_experts = ( - build_hybrid_stack_callables(self.layer, layer_type=layer_type) - ) - is_mtp = False + # get flags for latter use + is_mtp = isinstance(self.layer, MultiTokenPredictionLayer) + transformer_layer = self.layer.mtp_model_layer if is_mtp else self.layer + is_moe = isinstance(transformer_layer.mlp, MoELayer) + num_local_experts = transformer_layer.mlp.num_local_experts if is_moe else None extra_args["config"] = self.layer.config extra_args["is_moe"] = is_moe @@ -148,9 +136,6 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): # wrapper to help create TransformerLayerNode def create_node(stream, module, name): bwd_dw_callables = bwd_dw_callable_map.get(name, None) - node_extra_args = dict(extra_args) - if bwd_dw_callables is None: - node_extra_args["delay_wgrad_compute"] = False return TransformerLayerNode( stream, event, @@ -159,7 +144,7 @@ def create_node(stream, module, name): module, name=name, bwd_dw_callables=bwd_dw_callables, - extra_args=node_extra_args, + extra_args=extra_args, ) ( @@ -195,8 +180,6 @@ def get_fp8_context(self): use_inner_fp8_context = ( self.layer.config.fp8 and self.layer.config.fp8_recipe != Fp8Recipe.delayed ) - if self.layer_type is not None or not hasattr(self.layer, "layer_number"): - return nullcontext() return ( get_fp8_context(self.layer.config, self.layer.layer_number - 1) if use_inner_fp8_context @@ -271,7 +254,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) return f_input, b_grad -class HybridStackModelChunkSchedulePlan(AbstractSchedulePlan): +class TransformerModelChunkSchedulePlan(AbstractSchedulePlan): """Schedule the executing plan of the sub-modules in a model chunk sub-modules. 
This class organizes the computation nodes for a model chunk, @@ -284,8 +267,27 @@ class HybridStackModelChunkSchedulePlan(AbstractSchedulePlan): │ ├── layer[1]: TransformerLayerSchedulePlan │ └── ... └── post_process: PostProcessNode + + Subclasses can swap the per-layer schedule plan by overriding the + ``LAYER_SCHEDULE_PLAN_CLASS`` class attribute (e.g. HybridStack uses a + layer plan that understands grouped/inferred layer types). They can also + swap the pre/post-process node classes via ``PRE_PROCESS_NODE_CLASS`` / + ``POST_PROCESS_NODE_CLASS`` so each model owns its own embedding / output + layer node implementations. """ + #: The TransformerLayerSchedulePlan-compatible class used to build per-layer + #: schedule plans. Subclasses override this to inject a layer-plan variant. + LAYER_SCHEDULE_PLAN_CLASS = None + + #: Pre/post-process node classes. Defaults below pull in the GPT-side + #: ``PreProcessNode`` / ``PostProcessNode`` (which call ``GPTModel._preprocess`` / + #: ``GPTModel._postprocess``). Subclasses set these to model-specific node + #: classes so the node calls the right model's ``_preprocess`` / + #: ``_postprocess`` methods. 
+ PRE_PROCESS_NODE_CLASS = None + POST_PROCESS_NODE_CLASS = None + def __init__( self, model, @@ -322,6 +324,9 @@ def __init__( """ from megatron.core.models.gpt.fine_grained_callables import PostProcessNode, PreProcessNode + pre_process_cls = self.PRE_PROCESS_NODE_CLASS or PreProcessNode + post_process_cls = self.POST_PROCESS_NODE_CLASS or PostProcessNode + self._model_chunk_state = ModelChunkState() self._transformer_layers = [] self._event = torch.cuda.Event() @@ -347,7 +352,7 @@ def __init__( self._model_chunk_state.attention_bias = None # build preprocess - self.pre_process = PreProcessNode( + self.pre_process = pre_process_cls( model, self._model_chunk_state, self._event, get_comp_stream ) @@ -361,23 +366,18 @@ def __init__( # build post process if model.post_process: - self.post_process = PostProcessNode( + self.post_process = post_process_cls( model, self._model_chunk_state, self._event, get_comp_stream ) def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): if module is None: return + plan_cls = self.LAYER_SCHEDULE_PLAN_CLASS or TransformerLayerSchedulePlan num_layers = len(module.layers) for layer_idx in range(num_layers): - extra_args = { - "is_first_layer": layer_idx == 0, - "is_last_layer": layer_idx == num_layers - 1, - } - extra_args["layer_type"] = ( - module.layer_type_list[layer_idx] if hasattr(module, "layer_type_list") else None - ) - layer_plan = HybridStackSchedulePlan( + extra_args = self._extra_args_for_layer(module, layer_idx, num_layers) + layer_plan = plan_cls( module.layers[layer_idx], self.event, self.state, @@ -387,6 +387,17 @@ def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): ) self._transformer_layers.append(layer_plan) + def _extra_args_for_layer(self, module, layer_idx, num_layers): + """Per-layer ``extra_args`` dict passed to the layer plan constructor. + + Subclasses extend this hook to thread additional metadata (e.g. 
hybrid + layer-type symbols) without overriding ``_build_layer_schedule_plan``. + """ + return { + "is_first_layer": layer_idx == 0, + "is_last_layer": layer_idx == num_layers - 1, + } + @property def event(self): """Gets the CUDA event for synchronization.""" @@ -496,7 +507,7 @@ def run( b_layer = b_schedule_plan.pop_layer() nvtx_msg = f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b" nvtx_range_push(nvtx_msg) - f_input, b_grad = HybridStackSchedulePlan.run( + f_input, b_grad = TransformerLayerSchedulePlan.run( f_layer, b_layer, f_input=f_input, @@ -512,7 +523,7 @@ def run( b_layer = b_schedule_plan.pop_layer() nvtx_msg = f"layer_{b_schedule_plan.num_layers()}b" nvtx_range_push(nvtx_msg) - _, b_grad = HybridStackSchedulePlan.run( + _, b_grad = TransformerLayerSchedulePlan.run( None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1) ) if i < b_num_layers - 1: @@ -524,7 +535,7 @@ def run( f_layer = f_schedule_plan.get_layer(i) nvtx_msg = f"layer_{i}f" nvtx_range_push(nvtx_msg) - f_input, _ = HybridStackSchedulePlan.run(f_layer, None, f_input=f_input) + f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input) nvtx_range_pop(nvtx_msg) if f_schedule_plan is not None and post_forward is not None: @@ -562,8 +573,3 @@ def run( b_schedule_plan.release_state() return f_input - - -# Backward-compatible aliases for GPT callers and existing tests. -TransformerLayerSchedulePlan = HybridStackSchedulePlan -TransformerModelChunkSchedulePlan = HybridStackModelChunkSchedulePlan diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index fa2a2ec4934..d96743228d3 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -270,7 +270,7 @@ def __init__( assert config is not None, "model config must be passed to TransformerLayerNode." 
is_moe = extra_args.get("is_moe", False) num_local_experts = extra_args.get("num_local_experts", None) - free_input = should_free_input(name, is_moe, config, num_local_experts) + free_input = self._resolve_free_input(name, is_moe, config, num_local_experts) self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False) super().__init__( @@ -299,6 +299,15 @@ def __init__( bwd_dw_callables if isinstance(bwd_dw_callables, list) else [bwd_dw_callables] ) + @staticmethod + def _resolve_free_input(name, is_moe, config, num_local_experts): + """Free-input policy hook. Subclasses override to specialize. + + Default delegates to module-level ``should_free_input`` (the GPT MoE + EP-overlap policy). + """ + return should_free_input(name, is_moe, config, num_local_experts) + def detach(self, t): """Detaches a tensor and stores it for backward computation.""" detached = make_viewless(t).detach() diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index 79c66d87fe1..c08a0956056 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -9,6 +9,7 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context +from megatron.core.models.gpt.fine_grained_callables import TransformerLayerNode, should_free_input from megatron.core.models.hybrid.hybrid_block import HybridStack from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols @@ -17,6 +18,37 @@ from megatron.core.transformer.transformer_layer import make_viewless_tensor +class HybridStackNode(TransformerLayerNode): + """Schedule node for HybridStack-built fine-grained callables. 
+ + Subclassed from ``TransformerLayerNode`` so the runtime backbone (forward / + backward / backward_dw plumbing, detach bookkeeping, output-grad release) + is shared. The hybrid path keeps a separate node class so its free-input + policy can diverge from the GPT defaults — for example, the ``attn`` slot + here covers the whole pre-dispatch loop (mamba + attention + …) rather than + a single attention block, and group-level decisions about whether the input + is needed in backward may differ from ``should_free_input`` in + ``gpt/fine_grained_callables.py``. Keep this override thin until a hybrid + counter-example forces it to diverge; the explicit subclass exists so the + divergence can be made surgically without touching the GPT class. + """ + + @staticmethod + def _resolve_free_input(name, is_moe, config, num_local_experts): + """Hybrid free-input policy. + + Currently mirrors the GPT default: dense layers always retain their + input for backward; MoE-only "moe_dispatch", "mlp", and "moe_combine" + slots can free, subject to the dispatcher / cuda-graph constraints + encoded in ``should_free_input``. Hybrid groups have an "attn" slot + whose semantics differ, but its policy resolves to ``False`` in + ``should_free_input``, which is correct: pre-layer outputs are needed + for backward through the loop. Override here when a hybrid-specific + rule is needed. 
+ """ + return should_free_input(name, is_moe, config, num_local_experts) + + def _get_inner_quant_context(layer): config = layer.config if config.fp8 and config.fp8_recipe != Fp8Recipe.delayed: diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 3473a7a43e1..21deb431c59 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -542,7 +542,7 @@ def build_schedule_plan( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() - from ..common.model_chunk_schedule_plan import HybridStackModelChunkSchedulePlan + from .model_chunk_schedule_plan import HybridStackModelChunkSchedulePlan return HybridStackModelChunkSchedulePlan( self, @@ -632,62 +632,34 @@ def forward( It either returns the Loss values if labels are given or the final hidden units """ - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() inference_context = deprecate_inference_params(inference_context, inference_params) in_inference_mode = inference_context is not None and not self.training - if in_inference_mode: assert runtime_gather_output, "Inference must always gather TP logits" - # Decoder embedding. 
- if decoder_input is not None: - pass - elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - - # Clear the outputs for padding tokens when using dynamic batching with - # quantization scales to avoid corrupting amax calculations - if ( - in_inference_mode - and inference_context.is_dynamic_batching() - and is_using_quantization_scales(self.config) - ): - decoder_input[inference_context.padding_slice] = 0.0 - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - - rotary_pos_emb = None - if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_context, self.decoder, decoder_input, self.config, packed_seq_params - ) - rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, - packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', - ) - elif self.position_embedding_type == 'yarn': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_context, self.decoder, decoder_input, self.config, packed_seq_params - ) - # YarnRotaryEmbedding.forward returns (emb, mscale); discard mscale here - rotary_pos_emb, _ = self.rotary_pos_emb( - rotary_seq_len, - packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', - ) - - # Wrap decoder_input to allow the decoder (HybridStack) to delete the - # reference held by this caller function, enabling early garbage collection - # for inference. - if in_inference_mode: - decoder_input = WrappedTensor(decoder_input) + # Mirror GPTModel.forward: delegate the embedding / rotary computation and + # the output-layer / MTP / loss computation to the same hooks the + # EP-overlap PreProcessNode / PostProcessNode call. Keeps the eager and + # combined-1F1B paths on the same code. 
+ ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = self._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) # The following assert will currently fail when running inference. # Commented out for now. @@ -709,93 +681,21 @@ def forward( padding_mask=padding_mask, ) - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - - # Check if speculative decoding is active. When it is, MTP must be - # computed *after* verification so that it is conditioned on verified - # tokens rather than stale speculative tokens from the previous step. - if is_spec_decode is None: - is_spec_decode = ( - in_inference_mode - and inference_context.is_dynamic_batching() - and inference_context.num_speculative_tokens > 0 - ) - - mtp_forward_ran = self.mtp_process and not (in_inference_mode or is_spec_decode) - if mtp_forward_ran: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, - hidden_states=hidden_states, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=packed_seq_params, - embedding=self.embedding, - ) - - if not self.post_process: - return hidden_states - - if self.config.mtp_num_layers is not None and self.mtp_process: - assert self.config.mtp_num_layers > 0 - if in_inference_mode or is_spec_decode: - self._decoder_hidden_states_cache = hidden_states - else: - hidden_states = process_mtp_loss( - hidden_states=hidden_states, - labels=labels, - loss_mask=loss_mask, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - is_training=self.training, - compute_language_model_loss=self.compute_language_model_loss, - config=self.config, - 
cp_group=self.pg_collection.cp, - packed_seq_params=packed_seq_params, - scale_logits_fn=self._scale_logits if self.config.use_mup else None, - ) - sequence_parallel_override = False - if in_inference_mode and inference_context.config.materialize_only_last_token_logits: - if inference_context.is_static_batching(): - hidden_states = hidden_states[-1:, :, :] - else: - if self.output_layer.sequence_parallel: - # Perform the sequence parallel gather here instead of after the output layer - # because we need to slice the last token logits from the full view of the - # packed logits across all requests. - hidden_states = gather_from_sequence_parallel_region( - hidden_states, group=self.pg_collection.tp - ) - self.output_layer.sequence_parallel = False - sequence_parallel_override = True - - # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, - # then back to [S', B, H] for the output layer. - reshaped = hidden_states.squeeze(1).unsqueeze(0) - hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) - - logits, _ = self.output_layer( - hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=True, + loss_mask=loss_mask, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + inference_context=inference_context, + is_spec_decode=is_spec_decode, ) - logits = self._scale_logits(logits) - - # Restore sequence parallel execution to the output layer if necessary. 
- if sequence_parallel_override: - assert ( - in_inference_mode - and inference_context.is_dynamic_batching() - and inference_context.config.materialize_only_last_token_logits - ) - self.output_layer.sequence_parallel = True - - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(labels, logits) - - return loss diff --git a/megatron/core/models/hybrid/model_chunk_schedule_plan.py b/megatron/core/models/hybrid/model_chunk_schedule_plan.py new file mode 100644 index 00000000000..add48332e7a --- /dev/null +++ b/megatron/core/models/hybrid/model_chunk_schedule_plan.py @@ -0,0 +1,224 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Schedule-plan classes for HybridStack-based decoders. + +These extend the GPT-side ``TransformerLayerSchedulePlan`` / +``TransformerModelChunkSchedulePlan`` with the per-layer ``layer_type`` symbol +that HybridStack assigns to each entry of its ``layer_type_list`` (including +bracketed groups like ``[*-]``). The base classes remain GPT-only; this module +adds the hybrid-specific dispatch into ``build_hybrid_stack_callables`` and uses +``HybridStackNode`` so the schedule node's free-input policy can diverge from +the GPT default. +""" + +from contextlib import nullcontext + +from megatron.core.models.common.model_chunk_schedule_plan import ( + TransformerLayerSchedulePlan, + TransformerModelChunkSchedulePlan, +) +from megatron.core.models.gpt.fine_grained_callables import ( + PostProcessNode, + PreProcessNode, + weak_method, +) +from megatron.core.transformer.module import float16_to_fp32 +from megatron.core.transformer.transformer_layer import make_viewless_tensor + + +class HybridPreProcessNode(PreProcessNode): + """``PreProcessNode`` that calls ``HybridModel._preprocess``. 
+ + Mirrors the GPT counterpart but takes a HybridModel rather than a GPTModel + so the EP-overlap schedule plan does not cross-import a GPT-named class + when scheduling a hybrid model. Behavior matches: ``_preprocess`` returns + the same 6-tuple shape ``(decoder_input, rotary_pos_emb, rotary_pos_cos, + rotary_pos_sin, sequence_len_offset, padding_mask)`` and the chunk_state + fields populated here line up with the slots downstream layer nodes read. + """ + + def __init__(self, hybrid_model, chunk_state, event, stream): + # Bypass ``PreProcessNode.__init__`` to avoid binding to a + # ``gpt_model``-named attribute; reuse the underlying ScheduleNode. + super(PreProcessNode, self).__init__( + weak_method(self.forward_impl), stream, event, name="pre_process" + ) + self.hybrid_model = hybrid_model + self.chunk_state = chunk_state + + def forward_impl(self): + if not self.hybrid_model.pre_process: + self.chunk_state.decoder_input = self.hybrid_model.decoder.input_tensor + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = self.hybrid_model._preprocess( + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + decoder_input=self.chunk_state.decoder_input, + packed_seq_params=self.chunk_state.packed_seq_params, + padding_mask=self.chunk_state.padding_mask, + ) + + self.chunk_state.decoder_input = decoder_input + self.chunk_state.rotary_pos_emb = rotary_pos_emb + self.chunk_state.rotary_pos_cos = rotary_pos_cos + self.chunk_state.rotary_pos_sin = rotary_pos_sin + self.chunk_state.sequence_len_offset = sequence_len_offset + self.chunk_state.padding_mask = padding_mask + return decoder_input + + +class HybridPostProcessNode(PostProcessNode): + """``PostProcessNode`` that calls ``HybridModel._postprocess``. + + Mirrors the GPT counterpart. 
Skips MTP inside ``_postprocess`` (sets + ``mtp_in_postprocess=False``) because the EP-overlap schedule plan handles + MTP as separate layer nodes in the same chunk plan. + """ + + def __init__(self, hybrid_model, chunk_state, event, stream): + super(PostProcessNode, self).__init__( + weak_method(self.forward_impl), stream, event, name="post_process" + ) + self.hybrid_model = hybrid_model + self.chunk_state = chunk_state + + def forward_impl(self, hidden_states): + empty_decoder = len(self.hybrid_model.decoder.layers) == 0 + layer_norm = getattr(self.hybrid_model.decoder, "final_layernorm", None) or getattr( + self.hybrid_model.decoder, "final_norm", None + ) + if not self.hybrid_model.config.mtp_num_layers and empty_decoder and layer_norm: + hidden_states = layer_norm(hidden_states) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + loss = self.hybrid_model._postprocess( + hidden_states=hidden_states, + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + labels=self.chunk_state.labels, + decoder_input=self.chunk_state.decoder_input, + rotary_pos_emb=self.chunk_state.rotary_pos_emb, + rotary_pos_cos=self.chunk_state.rotary_pos_cos, + rotary_pos_sin=self.chunk_state.rotary_pos_sin, + mtp_in_postprocess=False, + loss_mask=self.chunk_state.loss_mask, + attention_mask=self.chunk_state.attention_mask, + packed_seq_params=self.chunk_state.packed_seq_params, + sequence_len_offset=self.chunk_state.sequence_len_offset, + runtime_gather_output=self.chunk_state.runtime_gather_output, + extra_block_kwargs=self.chunk_state.extra_block_kwargs, + ) + return float16_to_fp32(loss) + + +class HybridStackSchedulePlan(TransformerLayerSchedulePlan): + """Per-layer schedule plan for HybridStack decoders. + + Adds the ``layer_type`` extra-arg propagation; routes through + ``build_hybrid_stack_callables`` when ``layer_type`` is set (i.e. 
the layer + is a HybridStack entry, possibly a bracketed group); falls back to the GPT + path for plain TransformerLayer / MTP layers when ``layer_type`` is None. + """ + + def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_args=None): + if extra_args is None: + extra_args = {} + self.layer_type = extra_args.get("layer_type", None) + super().__init__(layer, event, chunk_state, comp_stream, comm_stream, extra_args) + + def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): + if self.layer_type is None: + return super()._build_callable_nodes(event, comp_stream, comm_stream, extra_args) + + # Hybrid grouped path. Imports are local because hybrid pulls in TE / SSM + # extensions that we don't want to load when only the GPT path is used. + from megatron.core.models.hybrid.fine_grained_callables import ( + HybridStackNode, + build_hybrid_stack_callables, + ) + from megatron.core.pipeline_parallel.utils import NoopScheduleNode + + fwd_callables, bwd_dw_callable_map, is_moe, num_local_experts = ( + build_hybrid_stack_callables(self.layer, layer_type=self.layer_type) + ) + + extra_args["config"] = self.layer.config + extra_args["is_moe"] = is_moe + extra_args["num_local_experts"] = num_local_experts + extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute + extra_args["is_mtp"] = False + + def create_node(stream, module, name): + bwd_dw_callables = bwd_dw_callable_map.get(name, None) + node_extra_args = dict(extra_args) + if bwd_dw_callables is None: + node_extra_args["delay_wgrad_compute"] = False + return HybridStackNode( + stream, + event, + self.layer_state, + self.chunk_state, + module, + name=name, + bwd_dw_callables=bwd_dw_callables, + extra_args=node_extra_args, + ) + + ( + attn_module, + moe_dispatch_module, + mlp_module, + moe_combine_module, + mtp_post_process_module, + ) = fwd_callables + + self.attn = create_node(comp_stream, attn_module, "attn") + self.mlp = create_node(comp_stream, 
mlp_module, "mlp") + if is_moe: + self.moe_dispatch = create_node(comm_stream, moe_dispatch_module, "moe_dispatch") + self.moe_combine = create_node(comm_stream, moe_combine_module, "moe_combine") + else: + self.moe_dispatch = NoopScheduleNode() + self.moe_combine = NoopScheduleNode() + + # HybridStack groups never carry an MTP terminal, so mtp_post_process is + # always a no-op here. + self.mtp_post_process = NoopScheduleNode() + + def get_fp8_context(self): + # Grouped hybrid layers (and inferred-layer-type entries that point at + # a HybridStack rather than a plain TransformerLayer) don't have a + # ``layer_number`` we can hand to ``get_fp8_context``; the inner layers + # manage their own per-layer fp8 context inside the hybrid callables. + if self.layer_type is not None or not hasattr(self.layer, "layer_number"): + return nullcontext() + return super().get_fp8_context() + + +class HybridStackModelChunkSchedulePlan(TransformerModelChunkSchedulePlan): + """Model-chunk schedule plan that builds ``HybridStackSchedulePlan`` layer plans. + + Threads HybridStack's ``layer_type_list[layer_idx]`` symbol into each + layer plan's ``extra_args`` so the per-layer plan can dispatch grouped + layers correctly. Ordinary GPT/MTP layers (no ``layer_type_list``) + default to ``layer_type=None`` and follow the GPT path. 
+ """ + + LAYER_SCHEDULE_PLAN_CLASS = HybridStackSchedulePlan + PRE_PROCESS_NODE_CLASS = HybridPreProcessNode + POST_PROCESS_NODE_CLASS = HybridPostProcessNode + + def _extra_args_for_layer(self, module, layer_idx, num_layers): + extra_args = super()._extra_args_for_layer(module, layer_idx, num_layers) + extra_args["layer_type"] = ( + module.layer_type_list[layer_idx] if hasattr(module, "layer_type_list") else None + ) + return extra_args From 2be1c29bb8abfcbf628fbdceaa6f11948290d12c Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 6 May 2026 13:05:27 -0700 Subject: [PATCH 04/16] HybridModel._postprocess: guard self.mtp on self.mtp_process The MTP forward block in _postprocess was only gated on the kwarg mtp_in_postprocess and the inference / spec-decode flags; the eager forward path (which now delegates to _postprocess with mtp_in_postprocess=True) hits AttributeError: 'HybridModel' object has no attribute 'mtp' on models built without an MTP block. The eager forward had this guard inline before its body was lifted into _postprocess; reinstate it on the method itself so both the eager and EP-overlap paths skip the call when no MTP block is configured. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- megatron/core/models/hybrid/hybrid_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 21deb431c59..e2d5b0806bb 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -454,7 +454,11 @@ def _postprocess( and inference_context.num_speculative_tokens > 0 ) - if mtp_in_postprocess and not (in_inference_mode or is_spec_decode): + if ( + mtp_in_postprocess + and self.mtp_process + and not (in_inference_mode or is_spec_decode) + ): hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, From da436ac47f61b2f0e94ee39083c13f72ddec609f Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 6 May 2026 13:26:55 -0700 Subject: [PATCH 05/16] Extract common schedule-plan helpers to core/models/common/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the model-agnostic schedule-plan pieces — weak_method, should_free_input, TransformerLayerState, TransformerLayerNode, _BackwardDWWrapper, PreProcessNode, PostProcessNode — out of core/models/gpt/fine_grained_callables.py and into a new core/models/common/utils.py. The common chunk-schedule-plan module and the hybrid path now import from common/utils.py instead of crossing into gpt/. The GPT-specific TransformerLayer / MTPLayer callable builders stay in gpt/fine_grained_callables.py and re-export the moved names so existing imports keep working. Pre/PostProcessNode now use a generic ``model`` attribute (was ``gpt_model``); they call ``model._preprocess`` / ``model._postprocess`` which works for any model that exposes those methods. 
With the rename HybridModel is just another consumer of the same nodes — the dedicated HybridPreProcessNode / HybridPostProcessNode subclasses are no longer needed and are deleted, along with the PRE_PROCESS_NODE_CLASS / POST_PROCESS_NODE_CLASS overrides on HybridStackModelChunkSchedulePlan. transformer/module.py also imports _BackwardDWWrapper from common/utils.py now; the previous import path crossed from megatron/core/transformer into megatron/core/models/gpt/, which is the cross-cut the reviewer flagged. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../common/model_chunk_schedule_plan.py | 10 +- megatron/core/models/common/utils.py | 366 +++++++++++++++ .../core/models/gpt/fine_grained_callables.py | 422 ++---------------- .../models/hybrid/fine_grained_callables.py | 2 +- .../hybrid/model_chunk_schedule_plan.py | 115 +---- megatron/core/transformer/module.py | 2 +- 6 files changed, 409 insertions(+), 508 deletions(-) create mode 100644 megatron/core/models/common/utils.py diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 0430987fcc5..ec08598743e 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -70,7 +70,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar The event and chunk_state are binded to the TransformerModelChunkSchedulePlan and shared across all layers in the model chunk. """ - from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState + from megatron.core.models.common.utils import TransformerLayerState self.config = layer.config self.layer_state = TransformerLayerState() @@ -111,10 +111,8 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): Builds the callable nodes for the transformer/mtp layer: attn, mlp, moe_dispatch and moe_combine, and mtp_post_process. 
""" - from megatron.core.models.gpt.fine_grained_callables import ( - TransformerLayerNode, - build_layer_callables, - ) + from megatron.core.models.common.utils import TransformerLayerNode + from megatron.core.models.gpt.fine_grained_callables import build_layer_callables from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer @@ -322,7 +320,7 @@ def __init__( Returns: The model chunk schedule plan. """ - from megatron.core.models.gpt.fine_grained_callables import PostProcessNode, PreProcessNode + from megatron.core.models.common.utils import PostProcessNode, PreProcessNode pre_process_cls = self.PRE_PROCESS_NODE_CLASS or PreProcessNode post_process_cls = self.POST_PROCESS_NODE_CLASS or PostProcessNode diff --git a/megatron/core/models/common/utils.py b/megatron/core/models/common/utils.py new file mode 100644 index 00000000000..531ff879599 --- /dev/null +++ b/megatron/core/models/common/utils.py @@ -0,0 +1,366 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Schedule-plan helpers shared by GPTModel and HybridModel. + +These pieces used to live in ``core/models/gpt/fine_grained_callables.py`` and +were imported by ``core/models/common/model_chunk_schedule_plan.py`` and the +hybrid schedule plan via that path. They are model-agnostic in practice — the +``Pre/PostProcessNode`` classes call the model's ``_preprocess`` / +``_postprocess`` methods and don't otherwise care which model implements +them — so they live here and the GPT module re-exports them for backward +compatibility with existing imports. 
+""" + +import weakref +from functools import partial +from typing import Callable + +import torch + +from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless +from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.module import GraphableMegatronModule, float16_to_fp32 +from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor +from megatron.core.utils import internal_api, nvtx_range_pop, nvtx_range_push + + +def weak_method(method): + """Wrap ``method`` in a weakref-keyed dispatcher to break refcycles. + + ``ScheduleNode`` keeps a reference to the bound forward / backward functions + of every node in the plan; using a strong reference would keep the layer + plan (and the model chunk through it) alive after the iteration completes. + The ``weakref.WeakMethod`` lets the schedule plan be torn down between + iterations without manual ``del`` chains. + """ + method_ref = weakref.WeakMethod(method) + del method + + def wrapped_func(*args, **kwarg): + return method_ref()(*args, **kwarg) + + return wrapped_func + + +@internal_api +def should_free_input(name, is_moe, config, num_local_experts): + """Whether the schedule node named ``name`` can free its input after forward. + + The schedule decomposes a transformer layer into ``attn``, ``moe_dispatch``, + ``mlp``, and ``moe_combine`` nodes; the inputs to some of those nodes are + not needed in backward and can be released early to lower peak activation + memory. Dense layers and the ``attn`` node always need their input retained + (the attention residual flows through the post-MLP BDA). + + Args: + name: Schedule node name. + is_moe: True for MoE layers; dense layers always retain inputs. + config: ``TransformerConfig`` for the layer. + num_local_experts: Local expert count on this rank (None for dense). + + Returns: + True iff the named node may free its input after forward. 
+ """ + # For dense layers [attn, fake, mlp, fake], the input is needed during backward pass + if not is_moe: + return False + enable_deepep = ( + config.moe_token_dispatcher_type == "flex" + and config.moe_flex_dispatcher_backend == "deepep" + ) + enable_hybridep = ( + config.moe_token_dispatcher_type == "flex" + and config.moe_flex_dispatcher_backend == "hybridep" + ) + # Define which nodes should free input memory. + # Since we split the computing graph into multiple nodes, we can manually control + # when and how to free the input memory. + # The input and output of A2A are not needed anymore after the forward pass, + # so we can free the input memory after the forward pass. + + # When low precision fp8/4 is enabled, the casted tensors are saved and the + # original bf16 tensors are safe to be freed. + free_mlp = config.fp8 is not None or config.fp4 is not None + if not free_mlp: + # AlltoAll dispatcher with local_num_experts=1 and HybridEP both use identity + # operation for `dispatch_postprocess`, hence the mlp inputs will be directly + # passed to GroupedGemm and should be saved for backward pass. + free_mlp = num_local_experts > 1 or config.moe_token_dispatcher_type != "alltoall" + free_mlp = free_mlp and not enable_hybridep + + free_input_nodes = { + "mlp": free_mlp, + "moe_combine": True, + # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched + # tokens and probs before dispatch A2A and it's not needed anymore after the + # forward pass. For DeepEP and HybridEP dispatcher mode, they are both needed in + # backward pass and cannot be freed. + # If moe_preprocess is in cuda graph scope, tokens and probs are fixed size + # tensors, so they cannot be freed. 
+ "moe_dispatch": not (enable_deepep or enable_hybridep) + and (CudaGraphScope.moe_preprocess not in config.cuda_graph_scope), + } + + return free_input_nodes.get(name, False) + + +class TransformerLayerState: + """State shared between the schedule nodes that come from one logical layer. + + Empty placeholder; nodes attach their own attributes (residual, dispatched + probs, shared-expert outputs) for downstream nodes in the same layer to + consume. Kept as a real class so weakrefs work uniformly. + """ + + pass + + +class PreProcessNode(ScheduleNode): + """Run the model's ``_preprocess`` (embedding + rotary + padding mask). + + The schedule plan wraps a model that exposes a ``_preprocess`` method + returning the canonical 6-tuple ``(decoder_input, rotary_pos_emb, + rotary_pos_cos, rotary_pos_sin, sequence_len_offset, padding_mask)`` + (slots a given model doesn't use are returned as ``None``). The chunk + state is mutated in-place so layer nodes can read the same fields by + name. + """ + + def __init__(self, model, chunk_state, event, stream): + super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process") + self.model = model + self.chunk_state = chunk_state + + def forward_impl(self): + if not self.model.pre_process: + self.chunk_state.decoder_input = self.model.decoder.input_tensor + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = self.model._preprocess( + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + decoder_input=self.chunk_state.decoder_input, + packed_seq_params=self.chunk_state.packed_seq_params, + padding_mask=self.chunk_state.padding_mask, + ) + + self.chunk_state.decoder_input = decoder_input + self.chunk_state.rotary_pos_emb = rotary_pos_emb + self.chunk_state.rotary_pos_cos = rotary_pos_cos + self.chunk_state.rotary_pos_sin = rotary_pos_sin + self.chunk_state.sequence_len_offset = sequence_len_offset + 
self.chunk_state.padding_mask = padding_mask + return decoder_input + + +class PostProcessNode(ScheduleNode): + """Run the model's ``_postprocess`` (final norm, output layer, loss). + + Calls ``_postprocess`` with ``mtp_in_postprocess=False`` because the + schedule plan handles MTP layers as sibling layer nodes inside the same + chunk; the model's MTP block is not invoked here. The optional final + layernorm — applied only when this rank holds an empty decoder shard + (early stage of pipeline parallel) — is handled here so the chunk plan + does not need a separate node for it. + """ + + def __init__(self, model, chunk_state, event, stream): + super().__init__(weak_method(self.forward_impl), stream, event, name="post_process") + self.model = model + self.chunk_state = chunk_state + + def forward_impl(self, hidden_states): + empty_decoder = len(self.model.decoder.layers) == 0 + layer_norm = self.model.decoder.final_layernorm + if not self.model.config.mtp_num_layers and empty_decoder and layer_norm: + hidden_states = layer_norm(hidden_states) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + loss = self.model._postprocess( + hidden_states=hidden_states, + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + labels=self.chunk_state.labels, + decoder_input=self.chunk_state.decoder_input, + rotary_pos_emb=self.chunk_state.rotary_pos_emb, + rotary_pos_cos=self.chunk_state.rotary_pos_cos, + rotary_pos_sin=self.chunk_state.rotary_pos_sin, + mtp_in_postprocess=False, + loss_mask=self.chunk_state.loss_mask, + attention_mask=self.chunk_state.attention_mask, + packed_seq_params=self.chunk_state.packed_seq_params, + sequence_len_offset=self.chunk_state.sequence_len_offset, + runtime_gather_output=self.chunk_state.runtime_gather_output, + extra_block_kwargs=self.chunk_state.extra_block_kwargs, + ) + + # combined-1F1B currently expects fp32 loss output. 
+ return float16_to_fp32(loss) + + +class TransformerLayerNode(ScheduleNode): + """Schedule node for one slot of a fine-grained transformer layer plan. + + Each transformer layer is decomposed into ``attn``, ``moe_dispatch``, + ``mlp``, and ``moe_combine`` slots; this class is the scheduler-side + handle for one slot. It owns the slot's stream / event, the per-slot + ``free_input`` policy, and the optional delayed weight-gradient hook. + Subclasses override ``_resolve_free_input`` to specialize the policy + (HybridStackNode does this for grouped layers). + """ + + def __init__( + self, + stream, + event, + layer_state, + chunk_state, + submodule, + name="default", + bwd_dw_callables=None, + extra_args={}, + ): + config = extra_args.get("config", None) + assert config is not None, "model config must be passed to TransformerLayerNode." + is_moe = extra_args.get("is_moe", False) + num_local_experts = extra_args.get("num_local_experts", None) + free_input = self._resolve_free_input(name, is_moe, config, num_local_experts) + self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False) + + super().__init__( + weak_method(self.forward_impl), + stream, + event, + weak_method(self.backward_impl), + free_input=free_input, + name=name, + ) + self.layer_state = layer_state + self.chunk_state = chunk_state + self.submodule = submodule + self.detached = tuple() + self.before_detached = tuple() + self.is_mtp = extra_args.get("is_mtp", False) + + self.is_first_layer = extra_args.get("is_first_layer", False) + self.is_last_layer = extra_args.get("is_last_layer", False) + + self.bwd_dw_callables = [] + if bwd_dw_callables is not None: + self.bwd_dw_callables = ( + bwd_dw_callables if isinstance(bwd_dw_callables, list) else [bwd_dw_callables] + ) + + @staticmethod + def _resolve_free_input(name, is_moe, config, num_local_experts): + """Free-input policy hook. 
Subclasses override to specialize.""" + return should_free_input(name, is_moe, config, num_local_experts) + + def detach(self, t): + """Detach a tensor and remember it for backward through the schedule node.""" + detached = make_viewless(t).detach() + detached.requires_grad = t.requires_grad + self.before_detached = self.before_detached + (t,) + self.detached = self.detached + (detached,) + return detached + + def forward_impl(self, *args): + """Invoke the slot's submodule forward.""" + return self.submodule(self, *args) + + def backward_impl(self, outputs, output_grad): + """Run the slot's backward, holding output_grads when wgrad is delayed.""" + detached_grad = tuple([e.grad for e in self.detached]) + grads = output_grad + detached_grad + self.default_backward_func(outputs + self.before_detached, grads) + # Release the output grad memory after backward finishes, except when + # delay_wgrad_compute is enabled — then the grads are kept until every + # registered ``backward_dw`` callable has run. + if self.delay_wgrad_compute: + self.output_grads = grads + self.delay_grads_release = len(self.bwd_dw_callables) > 0 + + return grads + + def backward_dw(self): + """Run the slot's delayed weight-gradient callables on the slot's stream.""" + if not self.delay_wgrad_compute: + return + if isinstance(self.stream, Callable): + self.stream = self.stream() + with torch.cuda.stream(self.stream): + nvtx_msg = f"{self.name} wgrad" + nvtx_range_push(nvtx_msg) + for module in self.bwd_dw_callables: + module.backward_dw() + nvtx_range_pop(nvtx_msg) + + # The output grad memory is last used in wgrad compute; safe to release now. + assert self.delay_grads_release, "output grad memory should be valid before wgrad." + if self.manual_release_grads: + for tensor in self.output_grads: + tensor.untyped_storage().resize_(0) + self.output_grads = None + + self.bwd_dw_callables = None + + def __del__(self): + # Release references early to help avoid leaks across iterations. 
+ self.before_detached = None + self.detached = None + self.layer_state = None + self.chunk_state = None + self.submodule = None + + +class _BackwardDWWrapper: + """Backward weight-gradient wrapper for the ``attn`` slot of a transformer layer. + + Runs the layer's ``self_attention.backward_dw`` plus, on MoE layers, the + shared-expert ``backward_dw``; coordinates with the cuda-graph wgrad + capture (``set_graphed_backward_dw_callable``) so that scopes covered by + the graph are not re-run eagerly. Used when + ``overlap_moe_expert_parallel_comm`` and ``delay_wgrad_compute`` are both + enabled. + """ + + def __init__(self, layer): + assert isinstance( + layer, GraphableMegatronModule + ), "cuda graphed ep overlap only supports GraphableMegatronModule." + assert isinstance( + layer, TransformerLayer + ), "cuda graphed ep overlap only supports TransformerLayer for now." + self.layer = layer + self.graphed_backward_dw_callable = None + self.attn_dw_callable = layer.self_attention.backward_dw + if layer.is_moe_layer: + self.shared_expert_dw_callable = partial( + layer.mlp.backward_dw, routed_experts=False, shared_experts=True + ) + else: + self.shared_expert_dw_callable = None + self.cuda_graph_scope = layer.config.cuda_graph_scope + + def backward_dw(self): + is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs + if self.shared_expert_dw_callable is not None and ( + not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope + ): + self.shared_expert_dw_callable() + if not is_replay or CudaGraphScope.attn not in self.cuda_graph_scope: + self.attn_dw_callable() + if is_replay and self.graphed_backward_dw_callable is not None: + self.graphed_backward_dw_callable() + self.layer = None + + def set_graphed_backward_dw_callable(self, graphed_backward_dw_callable): + """Plug the cuda-graph backward wgrad replay callable.""" + self.graphed_backward_dw_callable = graphed_backward_dw_callable diff --git 
a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index d96743228d3..bd51ce698e6 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -1,21 +1,28 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -import weakref from contextlib import nullcontext from functools import partial -from typing import Callable, Optional +from typing import Optional import torch from torch import Tensor from megatron.core import tensor_parallel +from megatron.core.models.common.utils import ( + PostProcessNode, + PreProcessNode, + TransformerLayerNode, + TransformerLayerState, + _BackwardDWWrapper, + should_free_input, + weak_method, +) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless -from megatron.core.transformer.enums import CudaGraphScope -from megatron.core.transformer.module import GraphableMegatronModule, float16_to_fp32 +from megatron.core.pipeline_parallel.utils import ScheduleNode +from megatron.core.transformer.module import GraphableMegatronModule from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionLayer, @@ -23,396 +30,23 @@ ) from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor from megatron.core.typed_torch import apply_module, copy_signature -from megatron.core.utils import internal_api, nvtx_range_pop, nvtx_range_push - - -def weak_method(method): - """Creates a weak reference to a method to prevent circular references. - - This function creates a weak reference to a method and returns a wrapper function - that calls the method when invoked. 
This helps prevent memory leaks from circular - references. - """ - method_ref = weakref.WeakMethod(method) - del method - - def wrapped_func(*args, **kwarg): - # nonlocal object_ref - return method_ref()(*args, **kwarg) - - return wrapped_func - - -@internal_api -def should_free_input(name, is_moe, config, num_local_experts): - """Determine if the node should free its input memory. - - Args: - name: Node name - is_moe: Whether it's a MoE model - config: TransformerConfig object - num_local_experts: Number of local experts in MoE module - - Returns: - bool: Whether to free input memory - """ - # For dense layers [attn, fake, mlp, fake], the input is needed during backward pass - if not is_moe: - return False - enable_deepep = ( - config.moe_token_dispatcher_type == "flex" - and config.moe_flex_dispatcher_backend == "deepep" - ) - enable_hybridep = ( - config.moe_token_dispatcher_type == "flex" - and config.moe_flex_dispatcher_backend == "hybridep" - ) - # Define which nodes should free input memory - # Since we split the computing graph into multiple nodes, we can manually control - # when and how to free the input memory. - # The input and output of A2A are not needed anymore after the forward pass, - # so we can free the input memory after the forward pass. - - # When low precision fp8/4 is enabled, the casted tensors are saved and the - # original bf16 tensors are safe to be freed. - free_mlp = config.fp8 is not None or config.fp4 is not None - if not free_mlp: - # AlltoAll dispatcher with local_num_experts=1 and HybridEP both use identity - # operation for `dispatch_postprocess`, hence the mlp inputs will be directly - # passed to GroupedGemm and should be saved for backward pass. 
- free_mlp = num_local_experts > 1 or config.moe_token_dispatcher_type != "alltoall" - free_mlp = free_mlp and not enable_hybridep - - free_input_nodes = { - "mlp": free_mlp, - "moe_combine": True, - # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens - # and probs before dispatch A2A and it's not needed anymore after the forward pass - # For DeepEP and HybridEP dispatcher mode, they are both needed in backward pass - # and cannot be freed. - # If moe_preprocess is in cuda graph scope, tokens and probs are fixed size tensors, - # so they cannot be freed. - "moe_dispatch": not (enable_deepep or enable_hybridep) - and (CudaGraphScope.moe_preprocess not in config.cuda_graph_scope), - } - - return free_input_nodes.get(name, False) - - -class TransformerLayerState: - """State shared within a transformer layer. - - This class holds state that is shared between different nodes - within a transformer layer. - """ - - pass - - -class PreProcessNode(ScheduleNode): - """Node responsible for preprocessing operations in the model. - - This node handles embedding and rotary positional embedding computations - before the main transformer layers. - """ - - def __init__(self, gpt_model, chunk_state, event, stream): - """Initializes a preprocessing node. - - Args: - gpt_model: The GPT model instance. - chunk_state (TransformerChunkState): State shared within a chunk - event: CUDA event for synchronization. - stream: CUDA stream for execution. - """ - super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process") - self.gpt_model = gpt_model - self.chunk_state = chunk_state - - def forward_impl(self): - """forward pass for pre-processing. - - This method handles: - 1. Decoder embedding computation - 2. Rotary positional embedding computation - 3. Sequence length offset computation for flash decoding - - Returns: - The processed decoder input tensor. 
- """ - # Get decoder input - if not self.gpt_model.pre_process: - self.chunk_state.decoder_input = self.gpt_model.decoder.input_tensor - # Run GPTModel._preprocess - ( - decoder_input, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - padding_mask, - ) = self.gpt_model._preprocess( - input_ids=self.chunk_state.input_ids, - position_ids=self.chunk_state.position_ids, - decoder_input=self.chunk_state.decoder_input, - packed_seq_params=self.chunk_state.packed_seq_params, - padding_mask=self.chunk_state.padding_mask, - ) - - # Saved for later use - self.chunk_state.decoder_input = decoder_input - self.chunk_state.rotary_pos_emb = rotary_pos_emb - self.chunk_state.rotary_pos_cos = rotary_pos_cos - self.chunk_state.rotary_pos_sin = rotary_pos_sin - self.chunk_state.sequence_len_offset = sequence_len_offset - self.chunk_state.padding_mask = padding_mask - return decoder_input - -class PostProcessNode(ScheduleNode): - """Node responsible for postprocessing operations in the model. - - This node handles final layer normalization and output layer computation - after the main transformer layers. - """ - - def __init__(self, gpt_model, chunk_state, event, stream): - """Initializes a postprocessing node. - - Args: - gpt_model: The GPT model instance. - chunk_state (TransformerChunkState): State shared within a chunk - event: CUDA event for synchronization. - stream: CUDA stream for execution. - """ - super().__init__(weak_method(self.forward_impl), stream, event, name="post_process") - self.gpt_model = gpt_model - self.chunk_state = chunk_state - - def forward_impl(self, hidden_states): - """Implements the forward pass for postprocessing. - - This method handles: - 1. Output layer computation - 2. Loss computation if labels are provided - - Args: - hidden_states: The hidden states from the transformer layers. - - Returns: - The logits or loss depending on whether labels are provided. 
- """ - - empty_decoder = len(self.gpt_model.decoder.layers) == 0 - layer_norm = self.gpt_model.decoder.final_layernorm - if not self.gpt_model.config.mtp_num_layers and empty_decoder and layer_norm: - hidden_states = layer_norm(hidden_states) - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True - ) - - # Run GPTModel._postprocess - loss = self.gpt_model._postprocess( - hidden_states=hidden_states, - input_ids=self.chunk_state.input_ids, - position_ids=self.chunk_state.position_ids, - labels=self.chunk_state.labels, - decoder_input=self.chunk_state.decoder_input, - rotary_pos_emb=self.chunk_state.rotary_pos_emb, - rotary_pos_cos=self.chunk_state.rotary_pos_cos, - rotary_pos_sin=self.chunk_state.rotary_pos_sin, - mtp_in_postprocess=False, - loss_mask=self.chunk_state.loss_mask, - attention_mask=self.chunk_state.attention_mask, - packed_seq_params=self.chunk_state.packed_seq_params, - sequence_len_offset=self.chunk_state.sequence_len_offset, - runtime_gather_output=self.chunk_state.runtime_gather_output, - extra_block_kwargs=self.chunk_state.extra_block_kwargs, - ) - - # For now, 1f1b only supports fp16 module - return float16_to_fp32(loss) - - -class TransformerLayerNode(ScheduleNode): - """Base class for transformer layer computation nodes. - - This class provides common functionality for different types of - transformer layer nodes (attention, MLP, etc.) - """ - - def __init__( - self, - stream, - event, - layer_state, - chunk_state, - submodule, - name="default", - bwd_dw_callables=None, - extra_args={}, - ): - """Initialize a transformer layer node. - - Args: - stream (torch.cuda.Stream): CUDA stream for execution - event (torch.cuda.Event): Synchronization event - layer_state (TransformerLayerState): State shared within a layer - chunk_state (TransformerChunkState): State shared within a chunk - submodule (function): The submodule contain forward and dw function - it's the per_batch_state_context, o.w. 
nullcontext - name (str): Node name, also used to determine memory strategy - bwd_dw_callables (list): List of weight gradient functions for the layer. - extra_args (dict): Extra arguments for the node: is_moe, config. - """ - # determine whether to free input memory - config = extra_args.get("config", None) - assert config is not None, "model config must be passed to TransformerLayerNode." - is_moe = extra_args.get("is_moe", False) - num_local_experts = extra_args.get("num_local_experts", None) - free_input = self._resolve_free_input(name, is_moe, config, num_local_experts) - self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False) - - super().__init__( - weak_method(self.forward_impl), - stream, - event, - weak_method(self.backward_impl), - free_input=free_input, - name=name, - ) - self.layer_state = layer_state - self.chunk_state = chunk_state - self.submodule = submodule - self.detached = tuple() - self.before_detached = tuple() - self.is_mtp = extra_args.get("is_mtp", False) - - # Create flags to indicate first and last layer - self.is_first_layer = extra_args.get("is_first_layer", False) - self.is_last_layer = extra_args.get("is_last_layer", False) - - # Initialize list to store registered dw callables - self.bwd_dw_callables = [] - if bwd_dw_callables is not None: - self.bwd_dw_callables = ( - bwd_dw_callables if isinstance(bwd_dw_callables, list) else [bwd_dw_callables] - ) - - @staticmethod - def _resolve_free_input(name, is_moe, config, num_local_experts): - """Free-input policy hook. Subclasses override to specialize. - - Default delegates to module-level ``should_free_input`` (the GPT MoE - EP-overlap policy). 
- """ - return should_free_input(name, is_moe, config, num_local_experts) - - def detach(self, t): - """Detaches a tensor and stores it for backward computation.""" - detached = make_viewless(t).detach() - detached.requires_grad = t.requires_grad - self.before_detached = self.before_detached + (t,) - self.detached = self.detached + (detached,) - return detached - - def forward_impl(self, *args): - """Calls the submodule as the forward pass.""" - return self.submodule(self, *args) - - def backward_impl(self, outputs, output_grad): - """Implements the backward pass for the transformer layer node.""" - detached_grad = tuple([e.grad for e in self.detached]) - grads = output_grad + detached_grad - self.default_backward_func(outputs + self.before_detached, grads) - # release the output grad memory after backward finishes, - # except when delay_wgrad_comptue is enabled, the grad should be - # kept until all modules' backward_dw has been invoked. - if self.delay_wgrad_compute: - self.output_grads = grads - self.delay_grads_release = len(self.bwd_dw_callables) > 0 - - # return grads for record stream - return grads - - def backward_dw(self): - """Computes the weight gradients for the transformer layer node.""" - if not self.delay_wgrad_compute: - return - if isinstance(self.stream, Callable): - self.stream = self.stream() - with torch.cuda.stream(self.stream): - nvtx_msg = f"{self.name} wgrad" - nvtx_range_push(nvtx_msg) - for module in self.bwd_dw_callables: - module.backward_dw() - nvtx_range_pop(nvtx_msg) - - # the output grad memory is last used in wgrad compute, should be safe to release. - assert self.delay_grads_release, "output grad memory should be valid before wgrad." - if self.manual_release_grads: - for tensor in self.output_grads: - tensor.untyped_storage().resize_(0) - self.output_grads = None - - self.bwd_dw_callables = None - - def __del__(self): - # Release reference as early as possible, this helps avoid memory leak. 
- self.before_detached = None - self.detached = None - self.layer_state = None - self.chunk_state = None - self.submodule = None - - -class _BackwardDWWrapper: - """Wrapper for managing backward weight gradient computation of attn module. - - This class handles the execution of weight gradient computations for transformer layers, - coordinating between CUDA graphed and non-graphed components. It is used when - overlap_moe_expert_parallel_comm and delay_wgrad_compute are enabled to manage - the delayed weight gradient computation in MoE models. - - The wrapper stores references to the attention and shared expert backward weight gradient - callables, and determines which components should be executed based on whether CUDA graphs - are being replayed and which scopes are covered by the graphs. - """ - - def __init__(self, layer): - assert isinstance( - layer, GraphableMegatronModule - ), "cuda graphed ep overlap only supports GraphableMegatronModule." - assert isinstance( - layer, TransformerLayer - ), "cuda graphed ep overlap only supports TransformerLayer for now." 
- self.layer = layer - self.graphed_backward_dw_callable = None - self.attn_dw_callable = layer.self_attention.backward_dw - if layer.is_moe_layer: - self.shared_expert_dw_callable = partial( - layer.mlp.backward_dw, routed_experts=False, shared_experts=True - ) - else: - self.shared_expert_dw_callable = None - self.cuda_graph_scope = layer.config.cuda_graph_scope - - def backward_dw(self): - """Execute weight gradients, skipping CUDA graphed components during replay.""" - is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs - if self.shared_expert_dw_callable is not None and ( - not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope - ): - self.shared_expert_dw_callable() - if not is_replay or CudaGraphScope.attn not in self.cuda_graph_scope: - self.attn_dw_callable() - if is_replay and self.graphed_backward_dw_callable is not None: - self.graphed_backward_dw_callable() - self.layer = None - - def set_graphed_backward_dw_callable(self, graphed_backward_dw_callable): - """Store the CUDA graphed backward weight gradient callable.""" - self.graphed_backward_dw_callable = graphed_backward_dw_callable +# Re-export the model-agnostic schedule-plan helpers so existing imports of +# ``from megatron.core.models.gpt.fine_grained_callables import ...`` keep +# working. The implementations live in ``megatron.core.models.common.utils``; +# only the GPT-specific ``build_*_layer_callables`` builders below stay here. 
+__all__ = [ + "PostProcessNode", + "PreProcessNode", + "TransformerLayerNode", + "TransformerLayerState", + "_BackwardDWWrapper", + "should_free_input", + "weak_method", + "build_layer_callables", + "build_mtp_layer_callables", + "build_transformer_layer_callables", +] def build_transformer_layer_callables(layer: TransformerLayer): diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index c08a0956056..2c079cb05a0 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -9,7 +9,7 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context -from megatron.core.models.gpt.fine_grained_callables import TransformerLayerNode, should_free_input +from megatron.core.models.common.utils import TransformerLayerNode, should_free_input from megatron.core.models.hybrid.hybrid_block import HybridStack from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols diff --git a/megatron/core/models/hybrid/model_chunk_schedule_plan.py b/megatron/core/models/hybrid/model_chunk_schedule_plan.py index add48332e7a..7c34da6219a 100644 --- a/megatron/core/models/hybrid/model_chunk_schedule_plan.py +++ b/megatron/core/models/hybrid/model_chunk_schedule_plan.py @@ -6,9 +6,11 @@ ``TransformerModelChunkSchedulePlan`` with the per-layer ``layer_type`` symbol that HybridStack assigns to each entry of its ``layer_type_list`` (including bracketed groups like ``[*-]``). The base classes remain GPT-only; this module -adds the hybrid-specific dispatch into ``build_hybrid_stack_callables`` and uses -``HybridStackNode`` so the schedule node's free-input policy can diverge from -the GPT default. 
+adds the hybrid-specific dispatch into ``build_hybrid_stack_callables`` and +uses ``HybridStackNode`` so the schedule node's free-input policy can diverge +from the GPT default. The pre/post-process nodes from +``core.models.common.utils`` are reused as-is — they already call +``model._preprocess`` / ``model._postprocess`` which work on a HybridModel. """ from contextlib import nullcontext @@ -17,106 +19,6 @@ TransformerLayerSchedulePlan, TransformerModelChunkSchedulePlan, ) -from megatron.core.models.gpt.fine_grained_callables import ( - PostProcessNode, - PreProcessNode, - weak_method, -) -from megatron.core.transformer.module import float16_to_fp32 -from megatron.core.transformer.transformer_layer import make_viewless_tensor - - -class HybridPreProcessNode(PreProcessNode): - """``PreProcessNode`` that calls ``HybridModel._preprocess``. - - Mirrors the GPT counterpart but takes a HybridModel rather than a GPTModel - so the EP-overlap schedule plan does not cross-import a GPT-named class - when scheduling a hybrid model. Behavior matches: ``_preprocess`` returns - the same 6-tuple shape ``(decoder_input, rotary_pos_emb, rotary_pos_cos, - rotary_pos_sin, sequence_len_offset, padding_mask)`` and the chunk_state - fields populated here line up with the slots downstream layer nodes read. - """ - - def __init__(self, hybrid_model, chunk_state, event, stream): - # Bypass ``PreProcessNode.__init__`` to avoid binding to a - # ``gpt_model``-named attribute; reuse the underlying ScheduleNode. 
- super(PreProcessNode, self).__init__( - weak_method(self.forward_impl), stream, event, name="pre_process" - ) - self.hybrid_model = hybrid_model - self.chunk_state = chunk_state - - def forward_impl(self): - if not self.hybrid_model.pre_process: - self.chunk_state.decoder_input = self.hybrid_model.decoder.input_tensor - ( - decoder_input, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - padding_mask, - ) = self.hybrid_model._preprocess( - input_ids=self.chunk_state.input_ids, - position_ids=self.chunk_state.position_ids, - decoder_input=self.chunk_state.decoder_input, - packed_seq_params=self.chunk_state.packed_seq_params, - padding_mask=self.chunk_state.padding_mask, - ) - - self.chunk_state.decoder_input = decoder_input - self.chunk_state.rotary_pos_emb = rotary_pos_emb - self.chunk_state.rotary_pos_cos = rotary_pos_cos - self.chunk_state.rotary_pos_sin = rotary_pos_sin - self.chunk_state.sequence_len_offset = sequence_len_offset - self.chunk_state.padding_mask = padding_mask - return decoder_input - - -class HybridPostProcessNode(PostProcessNode): - """``PostProcessNode`` that calls ``HybridModel._postprocess``. - - Mirrors the GPT counterpart. Skips MTP inside ``_postprocess`` (sets - ``mtp_in_postprocess=False``) because the EP-overlap schedule plan handles - MTP as separate layer nodes in the same chunk plan. 
- """ - - def __init__(self, hybrid_model, chunk_state, event, stream): - super(PostProcessNode, self).__init__( - weak_method(self.forward_impl), stream, event, name="post_process" - ) - self.hybrid_model = hybrid_model - self.chunk_state = chunk_state - - def forward_impl(self, hidden_states): - empty_decoder = len(self.hybrid_model.decoder.layers) == 0 - layer_norm = getattr(self.hybrid_model.decoder, "final_layernorm", None) or getattr( - self.hybrid_model.decoder, "final_norm", None - ) - if not self.hybrid_model.config.mtp_num_layers and empty_decoder and layer_norm: - hidden_states = layer_norm(hidden_states) - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True - ) - - loss = self.hybrid_model._postprocess( - hidden_states=hidden_states, - input_ids=self.chunk_state.input_ids, - position_ids=self.chunk_state.position_ids, - labels=self.chunk_state.labels, - decoder_input=self.chunk_state.decoder_input, - rotary_pos_emb=self.chunk_state.rotary_pos_emb, - rotary_pos_cos=self.chunk_state.rotary_pos_cos, - rotary_pos_sin=self.chunk_state.rotary_pos_sin, - mtp_in_postprocess=False, - loss_mask=self.chunk_state.loss_mask, - attention_mask=self.chunk_state.attention_mask, - packed_seq_params=self.chunk_state.packed_seq_params, - sequence_len_offset=self.chunk_state.sequence_len_offset, - runtime_gather_output=self.chunk_state.runtime_gather_output, - extra_block_kwargs=self.chunk_state.extra_block_kwargs, - ) - return float16_to_fp32(loss) class HybridStackSchedulePlan(TransformerLayerSchedulePlan): @@ -209,12 +111,13 @@ class HybridStackModelChunkSchedulePlan(TransformerModelChunkSchedulePlan): Threads HybridStack's ``layer_type_list[layer_idx]`` symbol into each layer plan's ``extra_args`` so the per-layer plan can dispatch grouped layers correctly. Ordinary GPT/MTP layers (no ``layer_type_list``) - default to ``layer_type=None`` and follow the GPT path. + default to ``layer_type=None`` and follow the GPT path. 
The pre/post + process nodes inherit from the GPT base class — they already dispatch + on ``model._preprocess`` / ``model._postprocess`` which a HybridModel + implements. """ LAYER_SCHEDULE_PLAN_CLASS = HybridStackSchedulePlan - PRE_PROCESS_NODE_CLASS = HybridPreProcessNode - POST_PROCESS_NODE_CLASS = HybridPostProcessNode def _extra_args_for_layer(self, module, layer_idx, num_layers): extra_args = super()._extra_args_for_layer(module, layer_idx, num_layers) diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 6539ee36105..ae2cb08d035 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -200,7 +200,7 @@ def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None): def init_backward_dw_wrapper(self): """Initialize the backward_dw_wrapper.""" - from megatron.core.models.gpt.fine_grained_callables import _BackwardDWWrapper + from megatron.core.models.common.utils import _BackwardDWWrapper config = getattr(self, 'config', None) assert config is not None, ( From b3139322d3299cc987cf18e64daef8e9afc511ec Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 8 May 2026 09:54:57 -0700 Subject: [PATCH 06/16] Rename TransformerLayerState to LayerState; move MTP callables to common MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three follow-up cleanups on the schedule-plan plumbing: * Rename the per-layer placeholder class TransformerLayerState to LayerState in core/models/common/utils.py (the name lives in common now and the Transformer prefix was misleading). * Move build_mtp_layer_callables and the build_layer_callables dispatcher out of core/models/gpt/fine_grained_callables.py and into the new core/models/common/fine_grained_callables.py. MTP is shared between GPTModel and HybridModel; the dispatcher's job is to dispatch on layer type and is naturally common too. 
Only build_transformer_layer_callables stays in gpt/ since it depends on GPT's MoE wiring. * Drop the gpt/ re-export block (the __all__ list and the bulk import from common.utils). Callers now import the moved names directly from core/models/common/utils.py and core/models/common/fine_grained_callables.py. gpt/fine_grained_callables.py is GPT-only again. Restore the explanatory comments inside HybridModel._preprocess and _postprocess that the earlier forward-refactor lost — they were originally inline in forward() and got dropped when the body moved into the helper methods. The new comments explain decoder-input handling, rotary cos/sin discard, sequence-parallel gather rationale, and the speculative-decoding ordering constraint. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../models/common/fine_grained_callables.py | 129 +++++++++++++++++ .../common/model_chunk_schedule_plan.py | 6 +- megatron/core/models/common/utils.py | 2 +- .../core/models/gpt/fine_grained_callables.py | 134 ------------------ megatron/core/models/hybrid/hybrid_model.py | 45 +++++- .../transformer/test_submodule_callables.py | 2 +- 6 files changed, 177 insertions(+), 141 deletions(-) create mode 100644 megatron/core/models/common/fine_grained_callables.py diff --git a/megatron/core/models/common/fine_grained_callables.py b/megatron/core/models/common/fine_grained_callables.py new file mode 100644 index 00000000000..c2435e76f3a --- /dev/null +++ b/megatron/core/models/common/fine_grained_callables.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Layer-callable builders for the combined-1F1B fine-grained schedule plan. + +These build_* functions assemble the per-layer ``(forward_funcs, backward_dw)`` +tuple that the schedule plan plugs into ``TransformerLayerNode``.
+ +The TransformerLayer-specific builder lives in ``gpt/fine_grained_callables.py`` +because it depends on GPT's MoE wiring; the MTP builder and the dispatcher +``build_layer_callables`` are model-agnostic — both GPTModel and HybridModel +schedule MTP layers identically — so they live here. +""" + +from contextlib import nullcontext +from functools import partial + +import torch + +from megatron.core import tensor_parallel +from megatron.core.models.gpt.fine_grained_callables import build_transformer_layer_callables +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.multi_token_prediction import ( + MultiTokenPredictionLayer, + get_mtp_layer_offset, +) +from megatron.core.transformer.transformer_layer import TransformerLayer + + +def build_mtp_layer_callables(layer): + """Callables for multi-token prediction layer nodes. + + This class contains the callable functions for different types of + multi-token prediction layer nodes (attention, MLP, etc.) + """ + + forward_funcs, backward_dw = build_transformer_layer_callables(layer.mtp_model_layer) + attn_forward, dispatch_forward, mlp_forward, combine_forward, _ = forward_funcs + is_moe = isinstance(layer.mtp_model_layer.mlp, MoELayer) + assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now." 
+ + def submodule_mtp_attn_forward(node, hidden_states): + # MTP Block Preprocess + if node.is_first_layer: + offset = get_mtp_layer_offset(layer.config, node.chunk_state.model.vp_stage) + node.chunk_state.mtp_hidden_states = list(torch.chunk(hidden_states, 1 + offset, dim=0)) + hidden_states = node.chunk_state.mtp_hidden_states[offset] + + input_ids, position_ids, decoder_input, hidden_states = layer._get_embeddings( + input_ids=node.chunk_state.input_ids, + position_ids=node.chunk_state.position_ids, + embedding=node.chunk_state.model.embedding, + hidden_states=hidden_states, + ) + node.chunk_state.input_ids = input_ids + node.chunk_state.position_ids = position_ids + + # MTP Layer Preprocess + # norm, linear projection and transformer + assert ( + node.chunk_state.context is None + ), f"multi token prediction + cross attention is not yet supported." + assert ( + node.chunk_state.packed_seq_params is None + ), f"multi token prediction + sequence packing is not yet supported." + + if layer.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # fp8 context is added in 1f1b schedule, so we don't need to add it here + with rng_context: + hidden_states = layer._concat_embeddings(hidden_states, decoder_input) + return attn_forward(node, hidden_states) + + def submodule_mtp_postprocess_forward(node, hidden_states): + hidden_states = layer._postprocess(hidden_states) + node.chunk_state.mtp_hidden_states.append(hidden_states) + if node.is_last_layer: + hidden_states = torch.cat(node.chunk_state.mtp_hidden_states, dim=0) + node.chunk_state.mtp_hidden_states = None + return hidden_states + + def rng_context_wrapper(func, *args, **kwargs): + """ + Wrapper to add rng context to submodule callables + """ + if layer.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + with rng_context: + return func(*args, **kwargs) + + # Build 
forward and backward callable functions + # attn_forward already has rng context, no need to wrap + attn_func = submodule_mtp_attn_forward + dispatch_func = partial(rng_context_wrapper, dispatch_forward) + mlp_func = partial(rng_context_wrapper, mlp_forward) + combine_func = partial(rng_context_wrapper, combine_forward) + mtp_post_process_func = submodule_mtp_postprocess_forward + + forward_funcs = [attn_func, dispatch_func, mlp_func, combine_func, mtp_post_process_func] + if isinstance(backward_dw["attn"], list): + backward_dw["attn"].append(layer.eh_proj) + else: + backward_dw["attn"] = [backward_dw["attn"], layer.eh_proj] + + return forward_funcs, backward_dw + + +def build_layer_callables(layer): + """ + Builds the callable functions(forward and dw) for the given layer. + For now, 1f1b overlap only support TransformerLayer and MultiTokenPredictionLayer. + + Args: + layer: The layer to build callables for. + + Returns: + forward_funcs: list of callable functions for the layer. + backward_dw: dict of weight gradient functions for the layer. + """ + if isinstance(layer, TransformerLayer): + return build_transformer_layer_callables(layer) + elif isinstance(layer, MultiTokenPredictionLayer): + return build_mtp_layer_callables(layer) + + raise ValueError(f"Unsupported layer type: {type(layer)}") diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index ec08598743e..0f34111eb84 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -70,10 +70,10 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar The event and chunk_state are binded to the TransformerModelChunkSchedulePlan and shared across all layers in the model chunk. 
""" - from megatron.core.models.common.utils import TransformerLayerState + from megatron.core.models.common.utils import LayerState self.config = layer.config - self.layer_state = TransformerLayerState() + self.layer_state = LayerState() self.chunk_state = chunk_state self.layer = layer self.event = event @@ -111,8 +111,8 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): Builds the callable nodes for the transformer/mtp layer: attn, mlp, moe_dispatch and moe_combine, and mtp_post_process. """ + from megatron.core.models.common.fine_grained_callables import build_layer_callables from megatron.core.models.common.utils import TransformerLayerNode - from megatron.core.models.gpt.fine_grained_callables import build_layer_callables from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer diff --git a/megatron/core/models/common/utils.py b/megatron/core/models/common/utils.py index 531ff879599..ec1dd09230d 100644 --- a/megatron/core/models/common/utils.py +++ b/megatron/core/models/common/utils.py @@ -104,7 +104,7 @@ def should_free_input(name, is_moe, config, num_local_experts): return free_input_nodes.get(name, False) -class TransformerLayerState: +class LayerState: """State shared between the schedule nodes that come from one logical layer. Empty placeholder; nodes attach their own attributes (residual, dispatched diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index bd51ce698e6..454e0416700 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -1,22 +1,11 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-from contextlib import nullcontext -from functools import partial from typing import Optional import torch from torch import Tensor from megatron.core import tensor_parallel -from megatron.core.models.common.utils import ( - PostProcessNode, - PreProcessNode, - TransformerLayerNode, - TransformerLayerState, - _BackwardDWWrapper, - should_free_input, - weak_method, -) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, @@ -24,30 +13,9 @@ from megatron.core.pipeline_parallel.utils import ScheduleNode from megatron.core.transformer.module import GraphableMegatronModule from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.multi_token_prediction import ( - MultiTokenPredictionLayer, - get_mtp_layer_offset, -) from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor from megatron.core.typed_torch import apply_module, copy_signature -# Re-export the model-agnostic schedule-plan helpers so existing imports of -# ``from megatron.core.models.gpt.fine_grained_callables import ...`` keep -# working. The implementations live in ``megatron.core.models.common.utils``; -# only the GPT-specific ``build_*_layer_callables`` builders below stay here. -__all__ = [ - "PostProcessNode", - "PreProcessNode", - "TransformerLayerNode", - "TransformerLayerState", - "_BackwardDWWrapper", - "should_free_input", - "weak_method", - "build_layer_callables", - "build_mtp_layer_callables", - "build_transformer_layer_callables", -] - def build_transformer_layer_callables(layer: TransformerLayer): """Create callables for transformer layer nodes. 
@@ -292,105 +260,3 @@ def raise_not_implemented(*args): backward_dw = {"attn": layer.backward_dw_wrapper, "mlp": layer.mlp} return forward_funcs, backward_dw - -def build_mtp_layer_callables(layer): - """Callables for multi-token prediction layer nodes. - - This class contains the callable functions for different types of - multi-token prediction layer nodes (attention, MLP, etc.) - """ - - forward_funcs, backward_dw = build_transformer_layer_callables(layer.mtp_model_layer) - attn_forward, dispatch_forward, mlp_forward, combine_forward, _ = forward_funcs - is_moe = isinstance(layer.mtp_model_layer.mlp, MoELayer) - assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now." - - def submodule_mtp_attn_forward(node, hidden_states): - # MTP Block Preprocess - if node.is_first_layer: - offset = get_mtp_layer_offset(layer.config, node.chunk_state.model.vp_stage) - node.chunk_state.mtp_hidden_states = list(torch.chunk(hidden_states, 1 + offset, dim=0)) - hidden_states = node.chunk_state.mtp_hidden_states[offset] - - input_ids, position_ids, decoder_input, hidden_states = layer._get_embeddings( - input_ids=node.chunk_state.input_ids, - position_ids=node.chunk_state.position_ids, - embedding=node.chunk_state.model.embedding, - hidden_states=hidden_states, - ) - node.chunk_state.input_ids = input_ids - node.chunk_state.position_ids = position_ids - - # MTP Layer Preprocess - # norm, linear projection and transformer - assert ( - node.chunk_state.context is None - ), f"multi token prediction + cross attention is not yet supported." - assert ( - node.chunk_state.packed_seq_params is None - ), f"multi token prediction + sequence packing is not yet supported." 
- - if layer.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # fp8 context is added in 1f1b schedule, so we don't need to add it here - with rng_context: - hidden_states = layer._concat_embeddings(hidden_states, decoder_input) - return attn_forward(node, hidden_states) - - def submodule_mtp_postprocess_forward(node, hidden_states): - hidden_states = layer._postprocess(hidden_states) - node.chunk_state.mtp_hidden_states.append(hidden_states) - if node.is_last_layer: - hidden_states = torch.cat(node.chunk_state.mtp_hidden_states, dim=0) - node.chunk_state.mtp_hidden_states = None - return hidden_states - - def rng_context_wrapper(func, *args, **kwargs): - """ - Wrapper to add rng context to submodule callables - """ - if layer.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - with rng_context: - return func(*args, **kwargs) - - # Build forward and backward callable functions - # attn_forward already has rng context, no need to wrap - attn_func = submodule_mtp_attn_forward - dispatch_func = partial(rng_context_wrapper, dispatch_forward) - mlp_func = partial(rng_context_wrapper, mlp_forward) - combine_func = partial(rng_context_wrapper, combine_forward) - mtp_post_process_func = submodule_mtp_postprocess_forward - - forward_funcs = [attn_func, dispatch_func, mlp_func, combine_func, mtp_post_process_func] - if isinstance(backward_dw["attn"], list): - backward_dw["attn"].append(layer.eh_proj) - else: - backward_dw["attn"] = [backward_dw["attn"], layer.eh_proj] - - return forward_funcs, backward_dw - - -def build_layer_callables(layer): - """ - Builds the callable functions(forward and dw) for the given layer. - For now, 1f1b overlap only support TransformerLayer and MultiTokenPredictionLayer. - - Args: - layer: The layer to build callables for. 
- - Returns: - forward_funcs: list of callable functions for the layer. - backward_dw: dict of weight gradient functions for the layer. - """ - if isinstance(layer, TransformerLayer): - return build_transformer_layer_callables(layer) - elif isinstance(layer, MultiTokenPredictionLayer): - return build_mtp_layer_callables(layer) - - raise ValueError(f"Unsupported layer type: {type(layer)}") diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index e2d5b0806bb..33415b2a502 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -337,14 +337,27 @@ def _preprocess( packed_seq_params: PackedSeqParams = None, padding_mask: Optional[Tensor] = None, ): - """Preprocess inputs for HybridStack or combined-1F1B scheduling.""" + """Preprocess inputs for HybridStack or combined-1F1B scheduling. + + Mirrors ``GPTModel._preprocess`` so the eager forward and the + EP-overlap ``PreProcessNode`` see the same embedding / rotary / + padding-mask code. Returns the canonical 6-tuple ``(decoder_input, + rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset, + padding_mask)`` — slots HybridModel does not compute (rotary cos/sin) + come back as ``None``. + """ in_inference_mode = inference_context is not None and not self.training + # If decoder_input is provided, input_ids and position_ids are ignored; + # otherwise apply the embedding layer to get decoder_input. if decoder_input is not None: pass elif self.pre_process: + # Decoder embedding. decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + # Clear the outputs for padding tokens when using dynamic batching + # with quantization scales to avoid corrupting amax calculations. 
if ( in_inference_mode and inference_context.is_dynamic_batching() @@ -352,6 +365,8 @@ def _preprocess( ): decoder_input[inference_context.padding_slice] = 0.0 else: + # Intermediate stage of pipeline parallelism — the decoder will get + # hidden_states from encoder.input_tensor. decoder_input = None rotary_pos_emb = None @@ -367,11 +382,14 @@ def _preprocess( rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) + # YarnRotaryEmbedding.forward returns (emb, mscale); discard mscale here. rotary_pos_emb, _ = self.rotary_pos_emb( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) + # ``sequence_len_offset`` is only needed for flash-decode / local-cudagraph + # static-batching inference; otherwise leave it as ``None``. if ( in_inference_mode and ( @@ -394,6 +412,8 @@ def _preprocess( else: sequence_len_offset = None + # Wrap decoder_input so the decoder (HybridStack) can drop its caller's + # reference for early garbage collection during inference. if in_inference_mode: decoder_input = WrappedTensor(decoder_input) @@ -438,7 +458,14 @@ def _postprocess( inference_context=None, is_spec_decode=None, ): - """Postprocess HybridStack hidden states into logits or language-model loss.""" + """Postprocess HybridStack hidden states into logits or language-model loss. + + Mirrors ``GPTModel._postprocess`` so the eager forward and the EP-overlap + ``PostProcessNode`` produce the same logits / loss / MTP outputs. + ``mtp_in_postprocess`` lets the EP-overlap path skip the inline MTP block + (it schedules MTP as separate layer nodes); the eager forward leaves it + ``True`` so the regular MTP forward runs here. 
+ """ in_inference_mode = inference_context is not None and not self.training if in_inference_mode: assert runtime_gather_output, "Inference must always gather TP logits" @@ -447,6 +474,8 @@ def _postprocess( if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() + # Speculative decoding: when active, MTP must run *after* verification so + # it conditions on verified tokens rather than stale speculative ones. if is_spec_decode is None: is_spec_decode = ( in_inference_mode @@ -454,6 +483,10 @@ def _postprocess( and inference_context.num_speculative_tokens > 0 ) + # MTP forward inline (skipped when the EP-overlap plan schedules MTP + # separately, when running inference, or when speculative decoding is + # active). ``self.mtp_process`` guards against models built without an + # MTP block. if ( mtp_in_postprocess and self.mtp_process @@ -476,8 +509,11 @@ def _postprocess( if self.config.mtp_num_layers is not None and self.mtp_process: assert self.config.mtp_num_layers > 0 if in_inference_mode or is_spec_decode: + # Cache decoder hidden states for serial MTP computation after + # speculative token verification. self._decoder_hidden_states_cache = hidden_states else: + # In training/eval, fold MTP loss into hidden_states. hidden_states = process_mtp_loss( hidden_states=hidden_states, labels=labels, @@ -499,12 +535,17 @@ def _postprocess( hidden_states = hidden_states[-1:, :, :] else: if self.output_layer.sequence_parallel: + # Perform the sequence-parallel gather here instead of after + # the output layer so we can slice the last-token logits from + # the full view of the packed logits across all requests. hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.pg_collection.tp ) self.output_layer.sequence_parallel = False sequence_parallel_override = True + # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, + # then back to [S', B, H] for the output layer. 
reshaped = hidden_states.squeeze(1).unsqueeze(0) hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) diff --git a/tests/unit_tests/transformer/test_submodule_callables.py b/tests/unit_tests/transformer/test_submodule_callables.py index 7b41b3ca197..67d60c91b46 100644 --- a/tests/unit_tests/transformer/test_submodule_callables.py +++ b/tests/unit_tests/transformer/test_submodule_callables.py @@ -2,7 +2,7 @@ import pytest import torch -from megatron.core.models.gpt.fine_grained_callables import build_layer_callables +from megatron.core.models.common.fine_grained_callables import build_layer_callables from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_with_transformer_engine_submodules, ) From 363f4bd59776678677fc034ab571ca6c62b1aaa4 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 8 May 2026 10:28:05 -0700 Subject: [PATCH 07/16] Rename schedule-plan slot 'attn' to 'pre_dispatch_computation' The slot covers more than attention: in the GPT path it runs attention + pre-MLP layernorm + router + dispatch preprocess; in the hybrid grouped path it loops over Mamba / attention / GDN sub-layers and ends with the MoE dispatch preprocess. The 'attn' name is misleading on both sides and diverges from the hybrid forward callable that has long been called pre_dispatch_computation. Rename the slot consistently: * TransformerLayerSchedulePlan: class attribute attn -> pre_dispatch_ computation, release_state, and run() updated. backward_dw key 'attn' -> 'pre_dispatch_computation' so it matches the slot name the schedule node passes when looking up the bwd_dw callables map. * HybridStackSchedulePlan: same self.pre_dispatch_computation slot. * GPTModel build_transformer_layer_callables and the MTP builder: forward_funcs first entry is now pre_dispatch_func; backward_dw uses the renamed key. * Hybrid build_hybrid_stack_callables: backward_dw uses the renamed key. 
No behavior change; the slot still receives the same callables and runs on the same stream. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../models/common/fine_grained_callables.py | 22 +++++++--- .../common/model_chunk_schedule_plan.py | 44 +++++++++++-------- .../core/models/gpt/fine_grained_callables.py | 9 ++-- .../models/hybrid/fine_grained_callables.py | 8 ++-- .../hybrid/model_chunk_schedule_plan.py | 6 ++- 5 files changed, 56 insertions(+), 33 deletions(-) diff --git a/megatron/core/models/common/fine_grained_callables.py b/megatron/core/models/common/fine_grained_callables.py index c2435e76f3a..4d99bab3211 100644 --- a/megatron/core/models/common/fine_grained_callables.py +++ b/megatron/core/models/common/fine_grained_callables.py @@ -92,19 +92,27 @@ def rng_context_wrapper(func, *args, **kwargs): with rng_context: return func(*args, **kwargs) - # Build forward and backward callable functions - # attn_forward already has rng context, no need to wrap - attn_func = submodule_mtp_attn_forward + # Build forward and backward callable functions. + # pre_dispatch_func already has rng context (rolled into submodule_mtp_attn_forward), + # so it does not need to be wrapped. 
+ pre_dispatch_func = submodule_mtp_attn_forward dispatch_func = partial(rng_context_wrapper, dispatch_forward) mlp_func = partial(rng_context_wrapper, mlp_forward) combine_func = partial(rng_context_wrapper, combine_forward) mtp_post_process_func = submodule_mtp_postprocess_forward - forward_funcs = [attn_func, dispatch_func, mlp_func, combine_func, mtp_post_process_func] - if isinstance(backward_dw["attn"], list): - backward_dw["attn"].append(layer.eh_proj) + forward_funcs = [ + pre_dispatch_func, + dispatch_func, + mlp_func, + combine_func, + mtp_post_process_func, + ] + pre_dispatch_bwd = backward_dw["pre_dispatch_computation"] + if isinstance(pre_dispatch_bwd, list): + pre_dispatch_bwd.append(layer.eh_proj) else: - backward_dw["attn"] = [backward_dw["attn"], layer.eh_proj] + backward_dw["pre_dispatch_computation"] = [pre_dispatch_bwd, layer.eh_proj] return forward_funcs, backward_dw diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 0f34111eb84..9af52f05cc0 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -35,20 +35,23 @@ class TransformerLayerSchedulePlan: mtp post process nodes. 
layer (TransformerLayerSchedulePlan) - ├── attn (TransformerLayerNode): attention -> layernorm -> router -> dispatch preprocess + ├── pre_dispatch_computation (TransformerLayerNode): + │ attention -> layernorm -> router -> dispatch preprocess ├── moe_dispatch (TransformerLayerNode): dispatch All2All ├── mlp (TransformerLayerNode): mlp module ├── moe_combine (TransformerLayerNode): combine All2All └── mtp_post_process (PostProcessNode): mtp post process Note that MTP layer has the same operation and execution order with TransformerLayer regarding - moe_dispatch, mlp, moe_combine, but contains extra operations in attn and mtp_post_process: - * mtp.attn wraps around transformer_layer.attn with extra norm, proj and embedding operations. + moe_dispatch, mlp, moe_combine, but contains extra operations in + pre_dispatch_computation and mtp_post_process: + * mtp.pre_dispatch_computation wraps around transformer_layer.pre_dispatch_computation with + extra norm, proj and embedding operations. * mtp.mtp_post_process contains output_layer, mtp loss operations, whereas transformer_layer.mtp_post_process is empty. 
""" - attn = None + pre_dispatch_computation = None moe_dispatch = None mlp = None moe_combine = None @@ -85,9 +88,12 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar def release_state(self): """Release reference, this helps avoid memory leak.""" - if hasattr(self, 'attn') and self.attn is not None: - del self.attn - self.attn = None + if ( + hasattr(self, 'pre_dispatch_computation') + and self.pre_dispatch_computation is not None + ): + del self.pre_dispatch_computation + self.pre_dispatch_computation = None if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None: del self.moe_dispatch self.moe_dispatch = None @@ -146,7 +152,7 @@ def create_node(stream, module, name): ) ( - attn_module, + pre_dispatch_module, moe_dispatch_module, mlp_module, moe_combine_module, @@ -155,7 +161,9 @@ def create_node(stream, module, name): # Create nodes for different operations in the layer # Each node type has a predefined name that determines its memory strategy - self.attn = create_node(comp_stream, attn_module, "attn") + self.pre_dispatch_computation = create_node( + comp_stream, pre_dispatch_module, "pre_dispatch_computation" + ) self.mlp = create_node(comp_stream, mlp_module, "mlp") if is_moe: self.moe_dispatch = create_node(comm_stream, moe_dispatch_module, "moe_dispatch") @@ -216,7 +224,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) if f_layer is not None: with f_layer.get_fp8_context(): - f_input = f_layer.attn.forward(f_input) + f_input = f_layer.pre_dispatch_computation.forward(f_input) if b_layer is not None: b_grad = b_layer.mlp.backward(b_grad) @@ -230,7 +238,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) b_grad = b_layer.moe_dispatch.backward(b_grad) if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release: - b_grad = b_layer.attn.backward(b_grad) + b_grad = b_layer.pre_dispatch_computation.backward(b_grad) if f_layer is not 
None: with f_layer.get_fp8_context(): @@ -242,12 +250,12 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) f_input = f_layer.mtp_post_process.forward(f_input) if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release: - b_grad = b_layer.attn.backward(b_grad) + b_grad = b_layer.pre_dispatch_computation.backward(b_grad) - # Delay the last attn_dw in backward pass (attn_dw of the first layer) - # for overlapping with the p2p comm + # Delay the last pre_dispatch_computation wgrad in backward pass (wgrad + # of the first layer) for overlapping with the p2p comm. if b_layer is not None and not is_last_layer_in_bwd: - b_layer.attn.backward_dw() + b_layer.pre_dispatch_computation.backward_dw() return f_input, b_grad @@ -549,11 +557,11 @@ def run( b_schedule_plan.wait_current_stream() post_backward(b_grad, b_schedule_plan.vp_stage) - # Delay the last attn_dw in backward pass (attn_dw of the first layer) - # for overlapping with the p2p comm + # Delay the last pre_dispatch_computation wgrad in backward pass (wgrad + # of the first layer) for overlapping with the p2p comm. 
if b_num_layers > 0: assert b_layer is not None - b_layer.attn.backward_dw() + b_layer.pre_dispatch_computation.backward_dw() b_layer.release_state() # post process forward diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 454e0416700..0b3a17ea12e 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -249,14 +249,17 @@ def raise_not_implemented(*args): raise NotImplementedError("This callable is not implemented for Dense layer.") # Build forward and backward callable functions - attn_func = submodule_attn_forward + pre_dispatch_func = submodule_attn_forward dispatch_func = submodule_dispatch_forward if is_moe else raise_not_implemented mlp_func = submodule_moe_forward if is_moe else mlp_wrapper combine_func = submodule_combine_forward if is_moe else raise_not_implemented layer.init_backward_dw_wrapper() - forward_funcs = [attn_func, dispatch_func, mlp_func, combine_func, None] - backward_dw = {"attn": layer.backward_dw_wrapper, "mlp": layer.mlp} + forward_funcs = [pre_dispatch_func, dispatch_func, mlp_func, combine_func, None] + backward_dw = { + "pre_dispatch_computation": layer.backward_dw_wrapper, + "mlp": layer.mlp, + } return forward_funcs, backward_dw diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index 2c079cb05a0..0221e6f6af2 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -40,8 +40,10 @@ def _resolve_free_input(name, is_moe, config, num_local_experts): Currently mirrors the GPT default: dense layers always retain their input for backward; MoE-only "moe_dispatch", "mlp", and "moe_combine" slots can free, subject to the dispatcher / cuda-graph constraints - encoded in ``should_free_input``. 
Hybrid groups have an "attn" slot - whose semantics differ, but its policy resolves to ``False`` in + encoded in ``should_free_input``. Hybrid groups have a + "pre_dispatch_computation" slot whose semantics differ (it covers a + loop over Mamba/attention/GDN sub-layers, not a single attention + block), but its policy resolves to ``False`` in ``should_free_input``, which is correct: pre-layer outputs are needed for backward through the loop. Override here when a hybrid-specific rule is needed. @@ -286,7 +288,7 @@ def raise_not_implemented(*args): backward_dw["mlp"] = terminal_layer.mlp if pre_bwd_dw: - backward_dw["attn"] = pre_bwd_dw + backward_dw["pre_dispatch_computation"] = pre_bwd_dw forward_funcs = [ pre_dispatch_computation, diff --git a/megatron/core/models/hybrid/model_chunk_schedule_plan.py b/megatron/core/models/hybrid/model_chunk_schedule_plan.py index 7c34da6219a..cd3dd0cab5b 100644 --- a/megatron/core/models/hybrid/model_chunk_schedule_plan.py +++ b/megatron/core/models/hybrid/model_chunk_schedule_plan.py @@ -75,14 +75,16 @@ def create_node(stream, module, name): ) ( - attn_module, + pre_dispatch_module, moe_dispatch_module, mlp_module, moe_combine_module, mtp_post_process_module, ) = fwd_callables - self.attn = create_node(comp_stream, attn_module, "attn") + self.pre_dispatch_computation = create_node( + comp_stream, pre_dispatch_module, "pre_dispatch_computation" + ) self.mlp = create_node(comp_stream, mlp_module, "mlp") if is_moe: self.moe_dispatch = create_node(comm_stream, moe_dispatch_module, "moe_dispatch") From 03453aee5b24d2955c1a5fb3381472367aada8d8 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 8 May 2026 10:29:22 -0700 Subject: [PATCH 08/16] Add backward_dw to MambaMixer and MambaLayer; register Mamba pre-layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MambaLayer used to be skipped in the hybrid EP-overlap pre-layer backward-dw registration loop because MambaLayer is not a TransformerLayer 
and the standard _BackwardDWWrapper(self) constructor asserts isinstance(layer, TransformerLayer). That meant Mamba pre-layers in a grouped HybridStack ran their wgrad inline in the regular backward pass while attention / GDN pre-layers got delayed wgrad — different scheduling for identical-shaped slots. Mirror GatedDeltaNet.backward_dw on Mamba: MambaMixer.backward_dw calls in_proj.backward_dw and out_proj.backward_dw (no-op when the spec uses non-TE linears that lack delayed wgrad); MambaLayer.backward_dw delegates to its mixer. In the hybrid pre-layer registration, add a MAMBA branch that appends the layer directly (not the wrapper, since MambaLayer is not a TransformerLayer). The schedule node already iterates the list and calls .backward_dw() on each, so the direct registration just works. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/models/hybrid/fine_grained_callables.py | 10 ++++++++++ megatron/core/ssm/mamba_layer.py | 11 +++++++++++ megatron/core/ssm/mamba_mixer.py | 15 +++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index 0221e6f6af2..44e3912a492 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -269,8 +269,18 @@ def raise_not_implemented(*args): pre_bwd_dw = [] for item_type, item_layer in pre_layers: if item_type in (LayerSymbols.ATTENTION, LayerSymbols.DS_ATTENTION, LayerSymbols.GDN): + # TransformerLayer-backed pre-layers go through the standard + # _BackwardDWWrapper which coordinates attn / shared-expert wgrad + # with cuda-graph replay scopes. item_layer.init_backward_dw_wrapper() pre_bwd_dw.append(item_layer.backward_dw_wrapper) + elif item_type == LayerSymbols.MAMBA: + # MambaLayer is not a TransformerLayer, so init_backward_dw_wrapper + # would assert.
MambaLayer.backward_dw delegates to its mixer, which + # in turn calls backward_dw on the in_proj / out_proj linears. The + # schedule node iterates this list and calls .backward_dw() on each; + # registering the layer directly is sufficient. + pre_bwd_dw.append(item_layer) if is_moe: # MoELayer.backward_dw default kwargs (routed_experts=True, shared_experts=False) handle # the routed-experts wgrad. The shared-experts wgrad is registered as a sibling callable diff --git a/megatron/core/ssm/mamba_layer.py b/megatron/core/ssm/mamba_layer.py index 17903cebf3b..2a49fcbbade 100644 --- a/megatron/core/ssm/mamba_layer.py +++ b/megatron/core/ssm/mamba_layer.py @@ -151,6 +151,17 @@ def forward( return hidden_states + def backward_dw(self): + """Compute weight gradients for the layer's linear projections. + + Delegates to the mixer; lets the hybrid EP-overlap schedule plan + register a Mamba pre-layer's wgrad alongside attention/GDN pre-layers + so the schedule node iterates a uniform set of callables. No-op when + the linears in the spec do not support delayed wgrad. + """ + if hasattr(self.mixer, "backward_dw"): + self.mixer.backward_dw() + def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None ) -> ShardedStateDict: diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 727c6ef5fd6..a051164522c 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -1227,6 +1227,21 @@ def mamba_state_shapes_per_request(self) -> Tuple[Tuple[int], Tuple[int]]: ssm_states_shape = (self.nheads_local_tp, self.headdim, self.d_state) return (conv_states_shape, ssm_states_shape) + def backward_dw(self): + """Compute weight gradients for the linear layers wrapped by this mixer. + + Mirrors ``GatedDeltaNet.backward_dw``. 
The selective-scan kernel is a + single autograd function whose wgrad runs in the regular backward pass, + so only the input/output projections need delayed wgrad here. Each + ``backward_dw`` call is a no-op unless the underlying linear is built + from a TE primitive that supports delayed wgrad; if the spec uses + non-TE linears, ``backward_dw`` simply does nothing. + """ + if hasattr(self.in_proj, "backward_dw"): + self.in_proj.backward_dw() + if hasattr(self.out_proj, "backward_dw"): + self.out_proj.backward_dw() + def _get_states_from_cache(self, inference_context, batch_size, *, inference_params=None): """Initializes or retrieves the SSM state tensors from the cache. From bc8cb2d147db4fe56dafc1776cb413828754edb5 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 8 May 2026 10:49:36 -0700 Subject: [PATCH 09/16] test_group_overlap: update bwd_dw_callable_map slot name to pre_dispatch_computation Follow-up to the schedule-plan slot rename: the test assertion that the hybrid grouped-overlap callables expose a backward-dw entry for the pre-dispatch slot was still checking the old 'attn' key; update it to 'pre_dispatch_computation'. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/unit_tests/ssm/test_hybrid_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/ssm/test_hybrid_block.py b/tests/unit_tests/ssm/test_hybrid_block.py index efbe59cf9f3..adaf7479017 100644 --- a/tests/unit_tests/ssm/test_hybrid_block.py +++ b/tests/unit_tests/ssm/test_hybrid_block.py @@ -277,7 +277,7 @@ def test_group_overlap_callables_keep_ep_moe_split_visible(self): assert mtp_post_process is None assert is_moe assert num_local_experts == 8 - assert "attn" in bwd_dw_callable_map + assert "pre_dispatch_computation" in bwd_dw_callable_map assert "mlp" in bwd_dw_callable_map def test_invalid_layer_types_cause_failure(self): From 2a2dd2b3d75a1bc2efc9303993c24c3b8af5a69c Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 11 May 2026 08:29:51 -0700 Subject: [PATCH 10/16] Unify build_layer_callables dispatcher; rename mtp attn_forward local MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three follow-up cleanups from review of common/fine_grained_callables.py: * build_layer_callables now dispatches HybridStack alongside TransformerLayer and MultiTokenPredictionLayer, so an mtp_model_layer that happens to be a HybridStack (or a future generic decoder layer) goes through the dispatcher rather than the TransformerLayer-only entrypoint. * The dispatcher now returns (forward_funcs, backward_dw, is_moe, num_local_experts) for every layer type (the build function already knows the layer, so the caller doesn't have to re-derive these). build_mtp_layer_callables calls build_layer_callables on its inner layer so the same pass-through works recursively. * Rename the local variable attn_forward (and the helper submodule_mtp_attn_forward) inside build_mtp_layer_callables to pre_dispatch_forward / submodule_mtp_pre_dispatch_forward — the schedule slot was renamed in 363f4bd59 but the local names inside the MTP wrapper were missed.
Also tighten the assert message in GraphableMegatronModule.init_backward_dw _wrapper: the wrapper is no longer hard-bound to TransformerLayer in documentation tone (MambaLayer.backward_dw exists now; Mamba just doesn't use the init_backward_dw_wrapper path because _BackwardDWWrapper still asserts TransformerLayer in __init__). The assertion checks config, not type, and the message now says so. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../models/common/fine_grained_callables.py | 60 ++++++++++++------- .../common/model_chunk_schedule_plan.py | 12 ++-- megatron/core/transformer/module.py | 14 ++++- .../transformer/test_submodule_callables.py | 2 +- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/megatron/core/models/common/fine_grained_callables.py b/megatron/core/models/common/fine_grained_callables.py index 4d99bab3211..9a100003de6 100644 --- a/megatron/core/models/common/fine_grained_callables.py +++ b/megatron/core/models/common/fine_grained_callables.py @@ -29,16 +29,26 @@ def build_mtp_layer_callables(layer): """Callables for multi-token prediction layer nodes. - This class contains the callable functions for different types of - multi-token prediction layer nodes (attention, MLP, etc.) + Wraps the inner ``layer.mtp_model_layer``'s callables with MTP-specific + pre-process (chunk and concat embeddings) and post-process (gather across + depths) steps. The inner layer is built by ``build_layer_callables`` so + that ``mtp_model_layer`` can be a TransformerLayer (today's case) or a + HybridStack (when an MTP depth uses the hybrid layout). 
""" - forward_funcs, backward_dw = build_transformer_layer_callables(layer.mtp_model_layer) - attn_forward, dispatch_forward, mlp_forward, combine_forward, _ = forward_funcs - is_moe = isinstance(layer.mtp_model_layer.mlp, MoELayer) + forward_funcs, backward_dw, is_moe, num_local_experts = build_layer_callables( + layer.mtp_model_layer + ) + ( + pre_dispatch_forward, + dispatch_forward, + mlp_forward, + combine_forward, + _, + ) = forward_funcs assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now." - def submodule_mtp_attn_forward(node, hidden_states): + def submodule_mtp_pre_dispatch_forward(node, hidden_states): # MTP Block Preprocess if node.is_first_layer: offset = get_mtp_layer_offset(layer.config, node.chunk_state.model.vp_stage) @@ -71,7 +81,7 @@ def submodule_mtp_attn_forward(node, hidden_states): # fp8 context is added in 1f1b schedule, so we don't need to add it here with rng_context: hidden_states = layer._concat_embeddings(hidden_states, decoder_input) - return attn_forward(node, hidden_states) + return pre_dispatch_forward(node, hidden_states) def submodule_mtp_postprocess_forward(node, hidden_states): hidden_states = layer._postprocess(hidden_states) @@ -93,9 +103,9 @@ def rng_context_wrapper(func, *args, **kwargs): return func(*args, **kwargs) # Build forward and backward callable functions. - # pre_dispatch_func already has rng context (rolled into submodule_mtp_attn_forward), - # so it does not need to be wrapped. - pre_dispatch_func = submodule_mtp_attn_forward + # pre_dispatch_func already has rng context (rolled into + # submodule_mtp_pre_dispatch_forward), so it does not need to be wrapped. 
+ pre_dispatch_func = submodule_mtp_pre_dispatch_forward dispatch_func = partial(rng_context_wrapper, dispatch_forward) mlp_func = partial(rng_context_wrapper, mlp_forward) combine_func = partial(rng_context_wrapper, combine_forward) @@ -114,24 +124,28 @@ def rng_context_wrapper(func, *args, **kwargs): else: backward_dw["pre_dispatch_computation"] = [pre_dispatch_bwd, layer.eh_proj] - return forward_funcs, backward_dw + return forward_funcs, backward_dw, is_moe, num_local_experts def build_layer_callables(layer): - """ - Builds the callable functions(forward and dw) for the given layer. - For now, 1f1b overlap only support TransformerLayer and MultiTokenPredictionLayer. - - Args: - layer: The layer to build callables for. + """Dispatch to the appropriate layer-callable builder. - Returns: - forward_funcs: list of callable functions for the layer. - backward_dw: dict of weight gradient functions for the layer. + Returns ``(forward_funcs, backward_dw, is_moe, num_local_experts)`` so the + schedule plan does not need to re-derive ``is_moe`` / + ``num_local_experts`` after the call — the build function already knows + the layer type. ``num_local_experts`` is ``None`` for dense layers. 
""" - if isinstance(layer, TransformerLayer): - return build_transformer_layer_callables(layer) - elif isinstance(layer, MultiTokenPredictionLayer): + from megatron.core.models.hybrid.fine_grained_callables import build_hybrid_stack_callables + from megatron.core.models.hybrid.hybrid_block import HybridStack + + if isinstance(layer, HybridStack): + return build_hybrid_stack_callables(layer) + if isinstance(layer, MultiTokenPredictionLayer): return build_mtp_layer_callables(layer) + if isinstance(layer, TransformerLayer): + forward_funcs, backward_dw = build_transformer_layer_callables(layer) + is_moe = isinstance(layer.mlp, MoELayer) + num_local_experts = layer.mlp.num_local_experts if is_moe else None + return forward_funcs, backward_dw, is_moe, num_local_experts raise ValueError(f"Unsupported layer type: {type(layer)}") diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 9af52f05cc0..ccf30a0df32 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -119,17 +119,15 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): """ from megatron.core.models.common.fine_grained_callables import build_layer_callables from megatron.core.models.common.utils import TransformerLayerNode - from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer - # build the forward and backward callables for the transformer/mtp layer - fwd_callables, bwd_dw_callable_map = build_layer_callables(self.layer) + # The dispatcher returns is_moe / num_local_experts directly since it + # already knows the layer type (saves a separate isinstance dance here). 
+ fwd_callables, bwd_dw_callable_map, is_moe, num_local_experts = build_layer_callables( + self.layer + ) - # get flags for latter use is_mtp = isinstance(self.layer, MultiTokenPredictionLayer) - transformer_layer = self.layer.mtp_model_layer if is_mtp else self.layer - is_moe = isinstance(transformer_layer.mlp, MoELayer) - num_local_experts = transformer_layer.mlp.num_local_experts if is_moe else None extra_args["config"] = self.layer.config extra_args["is_moe"] = is_moe diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index ae2cb08d035..b9535474161 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -199,12 +199,22 @@ def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None): self.cuda_graph_backward_dw_wrapper = None def init_backward_dw_wrapper(self): - """Initialize the backward_dw_wrapper.""" + """Initialize ``self.backward_dw_wrapper`` for delayed-wgrad scheduling. + + The wrapper coordinates the per-layer wgrad callables (attention + wgrad, optional shared-expert wgrad) with cuda-graph replay scope so + captured components are not re-run eagerly. The method is defined on + ``GraphableMegatronModule`` so any graphable subclass can opt in; + ``_BackwardDWWrapper`` itself currently asserts the underlying layer + is a ``TransformerLayer``, so MambaLayer-derived modules implement + ``backward_dw`` directly and skip this helper. + """ from megatron.core.models.common.utils import _BackwardDWWrapper config = getattr(self, 'config', None) assert config is not None, ( - "TransformerLayer must be initialized before calling " "`init_backward_dw_wrapper`." + "Module must be fully constructed (config set) before calling " + "`init_backward_dw_wrapper`." 
) self.backward_dw_wrapper = _BackwardDWWrapper(self) diff --git a/tests/unit_tests/transformer/test_submodule_callables.py b/tests/unit_tests/transformer/test_submodule_callables.py index 67d60c91b46..f4e6cdc0c89 100644 --- a/tests/unit_tests/transformer/test_submodule_callables.py +++ b/tests/unit_tests/transformer/test_submodule_callables.py @@ -65,7 +65,7 @@ def run_model_submodules_with_capture(model, input_tensors, microbatches): output_tensors = [] # get callables - callables, dw = build_layer_callables(model) + callables, dw, _is_moe, _num_local_experts = build_layer_callables(model) attn, dispatch, moe, combine, post_process = callables assert post_process is None dummy_model = DummyState() From 92f00aaca0d1fb046fef78d70089da14135910b3 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 11 May 2026 12:04:58 -0700 Subject: [PATCH 11/16] Finish slot rename: sync docstrings and gpt-side submodule name Round-2 renamed the schedule slot 'attn' to 'pre_dispatch_computation' and round-3 caught the MTP-wrapper local-variable parallel, but left peer docstrings and the GPT-side nested function un-renamed. Sweep the remaining references so a reader of the docs sees the same slot name the code uses: * common/model_chunk_schedule_plan.py: update _build_callable_nodes docstring, the run() overlap diagram (attn_fwd/attn_bwd -> pre_dispatch_fwd/pre_dispatch_bwd), and the two post-overlap comments. * common/utils.py: update should_free_input, TransformerLayerNode, and _BackwardDWWrapper docstrings; update the inline [attn, fake, mlp, fake] dense-layer list comment. * hybrid/fine_grained_callables.py: update the HybridStackNode docstring to reference the pre_dispatch_computation slot. 
* gpt/fine_grained_callables.py: rename submodule_attn_forward -> submodule_pre_dispatch_forward (def + binding), rewrite the build_transformer_layer_callables 5-callables enumeration and Returns block to match the real slot names and backward_dw dict keys, fix the 'attnention' typo, and disambiguate the dispatch_forward 'attn submodule' comment. No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../common/model_chunk_schedule_plan.py | 13 ++++---- megatron/core/models/common/utils.py | 20 +++++++------ .../core/models/gpt/fine_grained_callables.py | 30 +++++++++++-------- .../models/hybrid/fine_grained_callables.py | 15 +++++----- 4 files changed, 44 insertions(+), 34 deletions(-) diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index ccf30a0df32..1b57537f27a 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -115,7 +115,8 @@ def release_state(self): def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): """ Builds the callable nodes for the transformer/mtp layer: - attn, mlp, moe_dispatch and moe_combine, and mtp_post_process. + pre_dispatch_computation, moe_dispatch, mlp, moe_combine, + and mtp_post_process. """ from megatron.core.models.common.fine_grained_callables import build_layer_callables from megatron.core.models.common.utils import TransformerLayerNode @@ -195,12 +196,12 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) """Schedule one-forward-one-backward operations for a single transformer layer. This function interleaves forward and backward operations, overlapping the communications - (dispatch or combine) of one with the computations (att or mlp) of the other + (dispatch or combine) of one with the computations (pre_dispatch or mlp) of the other to maximize parallelism and efficiency. 
When f_layer and b_layer are not None, forward and backward pass are overlapped as follows: - comm_stream: combine_bwd | dispatch_fwd->dispatch_bwd | combine_fwd - comp_stream: attn_fwd | mlp_bwd->mlp_bwd_dw->mlp_fwd| attn_bwd + comm_stream: combine_bwd | dispatch_fwd->dispatch_bwd | combine_fwd + comp_stream: pre_dispatch_fwd | mlp_bwd->mlp_bwd_dw->mlp_fwd| pre_dispatch_bwd For MTP, mtp_post_process_fwd is executed after the combine_fwd in the comp_stream, and mtp_post_process_bwd is executed before the combine_bwd in the comp_stream. @@ -544,13 +545,13 @@ def run( if f_schedule_plan is not None and post_forward is not None: # post_forward()/send_forward_recv_forward() is running in the communication stream, - # so the p2p comm could be overlapped with the attn backward + # so the p2p comm could be overlapped with the pre_dispatch backward with torch.cuda.stream(get_comm_stream()): f_schedule_plan.wait_current_stream() post_forward(f_input, f_schedule_plan.vp_stage) # post_backward()/send_backward_recv_backward() is running in the computation stream, - # so the p2p comm could be overlapped with the wgrad of attn backward + # so the p2p comm could be overlapped with the wgrad of pre_dispatch backward if b_schedule_plan is not None and post_backward is not None: b_schedule_plan.wait_current_stream() post_backward(b_grad, b_schedule_plan.vp_stage) diff --git a/megatron/core/models/common/utils.py b/megatron/core/models/common/utils.py index ec1dd09230d..b231678ba5a 100644 --- a/megatron/core/models/common/utils.py +++ b/megatron/core/models/common/utils.py @@ -46,11 +46,12 @@ def wrapped_func(*args, **kwarg): def should_free_input(name, is_moe, config, num_local_experts): """Whether the schedule node named ``name`` can free its input after forward. 
- The schedule decomposes a transformer layer into ``attn``, ``moe_dispatch``, - ``mlp``, and ``moe_combine`` nodes; the inputs to some of those nodes are - not needed in backward and can be released early to lower peak activation - memory. Dense layers and the ``attn`` node always need their input retained - (the attention residual flows through the post-MLP BDA). + The schedule decomposes a transformer layer into ``pre_dispatch_computation``, + ``moe_dispatch``, ``mlp``, and ``moe_combine`` nodes; the inputs to some of + those nodes are not needed in backward and can be released early to lower + peak activation memory. Dense layers and the ``pre_dispatch_computation`` + node always need their input retained (the attention residual flows through + the post-MLP BDA). Args: name: Schedule node name. @@ -61,7 +62,8 @@ def should_free_input(name, is_moe, config, num_local_experts): Returns: True iff the named node may free its input after forward. """ - # For dense layers [attn, fake, mlp, fake], the input is needed during backward pass + # For dense layers [pre_dispatch_computation, fake, mlp, fake], the input is needed + # during backward pass if not is_moe: return False enable_deepep = ( @@ -208,8 +210,8 @@ def forward_impl(self, hidden_states): class TransformerLayerNode(ScheduleNode): """Schedule node for one slot of a fine-grained transformer layer plan. - Each transformer layer is decomposed into ``attn``, ``moe_dispatch``, - ``mlp``, and ``moe_combine`` slots; this class is the scheduler-side + Each transformer layer is decomposed into ``pre_dispatch_computation``, + ``moe_dispatch``, ``mlp``, and ``moe_combine`` slots; this class is the scheduler-side handle for one slot. It owns the slot's stream / event, the per-slot ``free_input`` policy, and the optional delayed weight-gradient hook. 
Subclasses override ``_resolve_free_input`` to specialize the policy @@ -321,7 +323,7 @@ def __del__(self): class _BackwardDWWrapper: - """Backward weight-gradient wrapper for the ``attn`` slot of a transformer layer. + """Backward weight-gradient wrapper for the ``pre_dispatch_computation`` slot of a transformer layer. Runs the layer's ``self_attention.backward_dw`` plus, on MoE layers, the shared-expert ``backward_dw``; coordinates with the cuda-graph wgrad diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 0b3a17ea12e..1c32570fa07 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -23,12 +23,15 @@ def build_transformer_layer_callables(layer: TransformerLayer): functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE's All-to-All). - The five callables are: - 1. Attention (computation) - 2. Post-Attention (computation) - 3. MoE Dispatch (communication) - 4. MLP / MoE Experts (computation) - 5. MoE Combine (communication) + The five callables align with the schedule plan's slot order: + 1. pre_dispatch_computation (computation): + attention -> pre-MLP layernorm -> router -> dispatch preprocess. + For dense layers this is just the attention pass. + 2. moe_dispatch (communication): MoE dispatch All-to-All. + 3. mlp / moe_experts (computation): dense MLP or routed-experts compute. + 4. moe_combine (communication): MoE combine All-to-All + post-MLP residual. + 5. mtp_post_process (computation): always ``None`` here; only the MTP + wrapper in ``common/fine_grained_callables.py`` fills this slot. 
By assigning these functions to different CUDA streams (e.g., a compute stream and a communication stream), the scheduler can overlap their execution, preventing @@ -40,8 +43,11 @@ def build_transformer_layer_callables(layer: TransformerLayer): Returns: A tuple containing: - - forward_funcs: List of callable functions for the layer - - backward_dw: Dict of weight gradient functions for the layer + - forward_funcs: List of 5 callables, one per slot in the schedule plan + (pre_dispatch_computation, moe_dispatch, mlp, moe_combine, + mtp_post_process=None). + - backward_dw: Dict mapping slot name to the delayed-wgrad callable + (keys: "pre_dispatch_computation", "mlp"). """ is_moe = isinstance(layer.mlp, MoELayer) @@ -54,9 +60,9 @@ def build_transformer_layer_callables(layer: TransformerLayer): and layer.config.moe_flex_dispatcher_backend == "hybridep" ) - def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): + def submodule_pre_dispatch_forward(node: ScheduleNode, hidden_states: torch.Tensor): """ - Performs same attnention forward logic as GPT Model and forward pass for + Performs the same attention forward logic as GPTModel and the forward pass for computations between attention and dispatch: pre mlp layernorm->router->dispatch preprocess """ @@ -154,7 +160,7 @@ def submodule_dispatch_forward( token_dispatcher = layer.mlp.token_dispatcher if enable_deepep or enable_hybridep: # update token_probs to be the detached version, prevents - # backward graph from connecting to attn submodule + # backward graph from connecting to pre_dispatch_computation submodule token_dispatcher._comm_manager.token_probs = probs dispatched_tokens, dispatched_probs = layer.mlp.dispatch(local_tokens, probs) @@ -249,7 +255,7 @@ def raise_not_implemented(*args): raise NotImplementedError("This callable is not implemented for Dense layer.") # Build forward and backward callable functions - pre_dispatch_func = submodule_attn_forward + pre_dispatch_func = 
submodule_pre_dispatch_forward dispatch_func = submodule_dispatch_forward if is_moe else raise_not_implemented mlp_func = submodule_moe_forward if is_moe else mlp_wrapper combine_func = submodule_combine_forward if is_moe else raise_not_implemented diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index 44e3912a492..2e70aa3d48c 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -24,13 +24,14 @@ class HybridStackNode(TransformerLayerNode): Subclassed from ``TransformerLayerNode`` so the runtime backbone (forward / backward / backward_dw plumbing, detach bookkeeping, output-grad release) is shared. The hybrid path keeps a separate node class so its free-input - policy can diverge from the GPT defaults — for example, the ``attn`` slot - here covers the whole pre-dispatch loop (mamba + attention + …) rather than - a single attention block, and group-level decisions about whether the input - is needed in backward may differ from ``should_free_input`` in - ``gpt/fine_grained_callables.py``. Keep this override thin until a hybrid - counter-example forces it to diverge; the explicit subclass exists so the - divergence can be made surgically without touching the GPT class. + policy can diverge from the GPT defaults — for example, the + ``pre_dispatch_computation`` slot here covers the whole pre-dispatch loop + (mamba + attention + …) rather than a single attention block, and + group-level decisions about whether the input is needed in backward may + differ from ``should_free_input`` in ``gpt/fine_grained_callables.py``. + Keep this override thin until a hybrid counter-example forces it to + diverge; the explicit subclass exists so the divergence can be made + surgically without touching the GPT class. 
""" @staticmethod From 8d6c6de7d0c7261e0ec5d8082a0925d82c129efd Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Mon, 11 May 2026 14:30:35 -0700 Subject: [PATCH 12/16] Restore shared-experts wgrad in pre_dispatch slot to fix TE Pop-empty-queue d2daf2925 ('Cleanup hybrid EP-overlap callables...') moved the MoE shared-experts wgrad from the pre_dispatch_computation slot into the mlp slot's bwd_dw_callables list. The commit message argued the relocation was equivalent because TransformerLayerNode.backward_dw iterates the list and calls .backward_dw() on each entry. That equivalence does not hold: TE's delay-wgrad model only puts to the wgrad queue inside the autograd backward (dgrad). Specifically, transformer_engine/.../linear.py calls ctx.wgrad_store.put(...) from inside the autograd Function's backward method, not from forward. So module.backward_dw() requires that module's autograd backward to have already run. For shared experts the relevant autograd backward lives in the pre_dispatch_computation slot (its forward is part of _run_moe_preprocess). The schedule plan calls mlp.backward_dw() before pre_dispatch_computation.backward(), so the shared-experts wgrad queue is still empty when the mlp slot tries to pop it, producing 'RuntimeError: Pop empty queue' from TE's _common.py. Restore the original wiring: register a _SharedExpertBackwardDWWrapper in pre_bwd_dw so the shared-experts wgrad fires after the pre_dispatch slot's autograd backward. The wrapper delegates to mlp.backward_dw(routed_experts=False, shared_experts=True) so the MoELayer-level guards (use_shared_expert / shared_expert_overlap / moe_latent_size + fc1_latent_proj) remain authoritative. Reproduction: DeepSeek-V3-Lite-deter-Hybrid with --hybrid-layer-pattern '*-*-|*-*-|*-*-|*-*-|*E*E|*E*E|*E*E|*E*E', PP=4 EP=2, --overlap-moe-expert-parallel-comm --delay-wgrad-compute, 20 train iterations. 
Bisect identified d2daf2925 as the first failing commit; 55d1e34f6 ran cleanly with bitwise-identical loss to the eager baseline. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../models/hybrid/fine_grained_callables.py | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index 2e70aa3d48c..9763cde355c 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -1,6 +1,7 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. from contextlib import nullcontext +from functools import partial from typing import Optional import torch @@ -18,6 +19,32 @@ from megatron.core.transformer.transformer_layer import make_viewless_tensor +class _SharedExpertBackwardDWWrapper: + """Run MoE shared-experts wgrad as part of the ``pre_dispatch_computation`` slot. + + Why: shared-experts forward is part of ``_run_moe_preprocess`` (which runs in the + pre_dispatch slot), and TE's delay-wgrad model only ``put``s to the wgrad queue + inside the autograd backward (dgrad). So shared-experts' ``backward_dw`` must + fire *after* the pre_dispatch slot's autograd backward — registering it in the + ``mlp`` slot would call it before that dgrad and trigger + ``RuntimeError: Pop empty queue`` from TE. + """ + + def __init__(self, layer): + self.layer = layer + self.shared_expert_dw_callable = None + if layer.mlp.use_shared_expert and not layer.mlp.shared_expert_overlap: + self.shared_expert_dw_callable = partial( + layer.mlp.backward_dw, routed_experts=False, shared_experts=True + ) + + def backward_dw(self): + if self.shared_expert_dw_callable is not None: + self.shared_expert_dw_callable() + self.layer = None + self.shared_expert_dw_callable = None + + class HybridStackNode(TransformerLayerNode): """Schedule node for HybridStack-built fine-grained callables. 
@@ -283,18 +310,15 @@ def raise_not_implemented(*args): # registering the layer directly is sufficient. pre_bwd_dw.append(item_layer) if is_moe: - # MoELayer.backward_dw default kwargs (routed_experts=True, shared_experts=False) handle - # the routed-experts wgrad. The shared-experts wgrad is registered as a sibling callable - # under "mlp" so the schedule node iterates both. Skip registering the shared-experts - # callable when shared_expert_overlap is enabled — in that case the shared-experts - # forward and backward are folded into the dispatcher's overlap handling. - mlp_backward_callables = [terminal_layer.mlp] - if ( - terminal_layer.mlp.use_shared_expert - and not terminal_layer.mlp.shared_expert_overlap - ): - mlp_backward_callables.append(terminal_layer.mlp.shared_experts) - backward_dw["mlp"] = mlp_backward_callables + # MoELayer.backward_dw default kwargs (routed_experts=True, shared_experts=False) + # handle the routed-experts wgrad in the mlp slot. The shared-experts wgrad goes + # into the pre_dispatch_computation slot so it runs after that slot's autograd + # backward (where TE's wgrad_store.put fires); the wrapper is a no-op when + # shared_expert_overlap is enabled. 
+ shared_expert_dw = _SharedExpertBackwardDWWrapper(terminal_layer) + if shared_expert_dw.shared_expert_dw_callable is not None: + pre_bwd_dw.append(shared_expert_dw) + backward_dw["mlp"] = terminal_layer.mlp elif terminal_type == LayerSymbols.MLP: backward_dw["mlp"] = terminal_layer.mlp From 84d195e3f0c7f00c59d6618434617710350038bf Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Tue, 12 May 2026 21:47:11 -0700 Subject: [PATCH 13/16] Materialize _forward_attention output via make_viewless_tensor in grouped pre-dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The grouped HybridStack pre_dispatch_computation closure calls TransformerLayer._forward_attention directly for ATTENTION / DSA / GDN pre-layers (rather than __call__, to avoid double-applying the post-attention residual through mlp_bda + _forward_mlp). The downside: _forward_attention returns the raw bias_dropout_add output, which can be a view tensor produced by the fused/JIT BDA kernel. The full forward() path at transformer_layer.py:895 normalizes this via make_viewless_tensor before returning; the shortcut here did not. When the view's non-canonical strides cross the schedule-node boundary into a downstream cuBLAS matmul (the terminal MLP's pre_mlp_layernorm linear, the next attention's linear_qkv, or the MoE router/dispatcher), the matmul kernel's algorithm-selection heuristics can pick a different algo based on stride layout. Different processes invoking the same model with the same seed can land in different deterministic clusters and produce ~1e-5 bit drift on the forward output of grouped layers containing GDN/attention pre-layers — even with the full Paul Gibbons deterministic recipe applied. 
Reproduction (before this commit): --hybrid-layer-pattern '*-*-|*-*-|*-*-|*-*-|[G-][G-]|[G-][G-]|[G-][G-]|[G-][G-]' Full deterministic recipe: --deterministic-mode true, NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, MAMBA_DETERMINISTIC=1, CAUSAL_CONV1D_DETERMINISTIC=1, TRITON_ENABLE_AUTOTUNE=0, NCCL_ALGO=Ring, --attention-backend flash. Two reruns: iter 1 loss differs by ~1e-5; subsequent iters drift in a bounded 1e-5..1e-4 band. With this commit: iter 1 is bitwise across reruns. Verified on draco 8 H100 GPUs PP=4 EP=2. The Mamba branch is unaffected: it uses __call__, which already routes through the full forward() that ends with make_viewless_tensor. Only the _forward_attention shortcut needed this companion fix. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/models/hybrid/fine_grained_callables.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index 9763cde355c..c1b375c08e2 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -238,6 +238,22 @@ def pre_dispatch_computation(node: ScheduleNode, hidden_states: Tensor): packed_seq_params=node.chunk_state.packed_seq_params, sequence_len_offset=node.chunk_state.sequence_len_offset, ) + # _forward_attention returns the bias_dropout_add output which can be a + # view tensor (the mlp_bda's add into the post-attention residual produces + # a view from a fused/JIT kernel). Downstream cuBLAS matmuls — including + # the terminal MLP/MoE's pre_mlp_layernorm and the next attention's QKV + # projection in a multi-pre-layer group — pick algorithms based on input + # strides; a view's non-canonical strides can lead to different algo + # selection across processes and produce ~1e-5 bit drift on the forward + # output. 
TransformerLayer's full forward() inserts this exact call at the + # MLP exit (transformer_layer.py:895) for the same reason; the + # _forward_attention shortcut here doesn't get that cleanup, so we add it + # explicitly. Same idea as the make_viewless_tensor in _maybe_apply_final_norm. + hidden_states = make_viewless_tensor( + inp=hidden_states, + requires_grad=hidden_states.requires_grad, + keep_graph=True, + ) else: raise ValueError( f"HybridStack overlap does not support layer type '{item_type}' before " From 5b9e9bf9625e078350faf6703d078c83f758610d Mon Sep 17 00:00:00 2001 From: Pingtian Li Date: Thu, 14 May 2026 09:11:39 +0800 Subject: [PATCH 14/16] revert ckpt final_layernorm for hybrid stack & assertion on cuda graph --- megatron/core/models/hybrid/hybrid_block.py | 16 ++++------------ .../models/hybrid/model_chunk_schedule_plan.py | 7 +++++++ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index 1c0209da0c1..f61764eaccd 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -456,7 +456,9 @@ def _sharded_state_dict( if sharded_layer_prefix is None: sharded_layer_prefix = layer_prefix - for local_layer_idx, (layer_type, layer) in enumerate(zip(self.layer_type_list, self.layers)): + for local_layer_idx, (layer_type, layer) in enumerate( + zip(self.layer_type_list, self.layers) + ): state_dict_prefix = f'{layer_prefix}{local_layer_idx}.' 
logical_layer_idx = ( self.logical_layer_offset @@ -490,18 +492,8 @@ def _sharded_state_dict( for name, module in self.named_children(): if not module is self.layers: module_sharded_state_dict = sharded_state_dict_default( - module, - f'{prefix}{name}.', - sharded_offsets, - metadata, - tp_group=self.tp_group, + module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=self.tp_group ) - if name == "final_norm": - replace_prefix_for_sharding( - module_sharded_state_dict, - f'{prefix}{name}.', - f'{prefix}final_layernorm.', - ) sharded_state_dict.update(module_sharded_state_dict) return sharded_state_dict diff --git a/megatron/core/models/hybrid/model_chunk_schedule_plan.py b/megatron/core/models/hybrid/model_chunk_schedule_plan.py index cd3dd0cab5b..3fcc405846a 100644 --- a/megatron/core/models/hybrid/model_chunk_schedule_plan.py +++ b/megatron/core/models/hybrid/model_chunk_schedule_plan.py @@ -121,6 +121,13 @@ class HybridStackModelChunkSchedulePlan(TransformerModelChunkSchedulePlan): LAYER_SCHEDULE_PLAN_CLASS = HybridStackSchedulePlan + def init(self, model, *args, **kwargs): + assert model.config.cuda_graph_impl == "none", ( + "EP A2A overlap with grouped HybridStack patterns (e.g. '[*E]') does not " + "support cuda graphs yet. Set cuda_graph_impl='none' or use an ungrouped pattern." 
+ ) + super().init(model, *args, **kwargs) + def _extra_args_for_layer(self, module, layer_idx, num_layers): extra_args = super()._extra_args_for_layer(module, layer_idx, num_layers) extra_args["layer_type"] = ( From 16513159cefa5ca1e6c4b8506067bdf606b7c97c Mon Sep 17 00:00:00 2001 From: Pingtian Li Date: Thu, 14 May 2026 09:11:59 +0800 Subject: [PATCH 15/16] format --- megatron/core/models/common/fine_grained_callables.py | 8 +------- .../core/models/common/model_chunk_schedule_plan.py | 10 ++-------- megatron/core/models/gpt/fine_grained_callables.py | 6 +----- megatron/core/models/hybrid/fine_grained_callables.py | 7 ++----- megatron/core/models/hybrid/hybrid_model.py | 6 +----- tests/unit_tests/models/test_hybrid_model.py | 5 +---- tests/unit_tests/ssm/test_hybrid_block.py | 9 ++------- tests/unit_tests/ssm/test_hybrid_layer_allocation.py | 9 +-------- 8 files changed, 11 insertions(+), 49 deletions(-) diff --git a/megatron/core/models/common/fine_grained_callables.py b/megatron/core/models/common/fine_grained_callables.py index 9a100003de6..52a9d738eee 100644 --- a/megatron/core/models/common/fine_grained_callables.py +++ b/megatron/core/models/common/fine_grained_callables.py @@ -39,13 +39,7 @@ def build_mtp_layer_callables(layer): forward_funcs, backward_dw, is_moe, num_local_experts = build_layer_callables( layer.mtp_model_layer ) - ( - pre_dispatch_forward, - dispatch_forward, - mlp_forward, - combine_forward, - _, - ) = forward_funcs + (pre_dispatch_forward, dispatch_forward, mlp_forward, combine_forward, _) = forward_funcs assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now." 
def submodule_mtp_pre_dispatch_forward(node, hidden_states): diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 1b57537f27a..a2d167e60a9 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -88,10 +88,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar def release_state(self): """Release reference, this helps avoid memory leak.""" - if ( - hasattr(self, 'pre_dispatch_computation') - and self.pre_dispatch_computation is not None - ): + if hasattr(self, 'pre_dispatch_computation') and self.pre_dispatch_computation is not None: del self.pre_dispatch_computation self.pre_dispatch_computation = None if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None: @@ -398,10 +395,7 @@ def _extra_args_for_layer(self, module, layer_idx, num_layers): Subclasses extend this hook to thread additional metadata (e.g. hybrid layer-type symbols) without overriding ``_build_layer_schedule_plan``. 
""" - return { - "is_first_layer": layer_idx == 0, - "is_last_layer": layer_idx == num_layers - 1, - } + return {"is_first_layer": layer_idx == 0, "is_last_layer": layer_idx == num_layers - 1} @property def event(self): diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 1c32570fa07..09772351570 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -263,9 +263,5 @@ def raise_not_implemented(*args): layer.init_backward_dw_wrapper() forward_funcs = [pre_dispatch_func, dispatch_func, mlp_func, combine_func, None] - backward_dw = { - "pre_dispatch_computation": layer.backward_dw_wrapper, - "mlp": layer.mlp, - } + backward_dw = {"pre_dispatch_computation": layer.backward_dw_wrapper, "mlp": layer.mlp} return forward_funcs, backward_dw - diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index c1b375c08e2..c451b377e93 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -101,9 +101,7 @@ def _maybe_apply_final_norm(node: ScheduleNode, hidden_states: Tensor): final_norm = final_norm or getattr(node.chunk_state.model.decoder, "final_layernorm", None) if not node.is_mtp and final_norm is not None and node.is_last_layer: hidden_states = final_norm(hidden_states) - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True - ) + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) return hidden_states @@ -293,8 +291,7 @@ def mlp(node: ScheduleNode, hidden_states: Tensor): if terminal_type == LayerSymbols.MLP: with _get_inner_quant_context(terminal_layer): hidden_states = terminal_layer._forward_mlp( - hidden_states, - padding_mask=node.chunk_state.padding_mask, + hidden_states, 
padding_mask=node.chunk_state.padding_mask ) return _maybe_apply_final_norm(node, hidden_states) if terminal_type == LayerSymbols.MOE: diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 33415b2a502..fe2ba14cf72 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -487,11 +487,7 @@ def _postprocess( # separately, when running inference, or when speculative decoding is # active). ``self.mtp_process`` guards against models built without an # MTP block. - if ( - mtp_in_postprocess - and self.mtp_process - and not (in_inference_mode or is_spec_decode) - ): + if mtp_in_postprocess and self.mtp_process and not (in_inference_mode or is_spec_decode): hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, diff --git a/tests/unit_tests/models/test_hybrid_model.py b/tests/unit_tests/models/test_hybrid_model.py index 902942f7141..43c18e3adf9 100644 --- a/tests/unit_tests/models/test_hybrid_model.py +++ b/tests/unit_tests/models/test_hybrid_model.py @@ -201,10 +201,7 @@ def test_save_load(self, tmp_path): def test_grouped_sharded_state_dict_uses_transformer_checkpoint_keys(self): """Grouped HybridModel checkpoints should be load-compatible with GPTModel keys.""" model_config = TransformerConfig( - num_layers=2, - hidden_size=256, - num_attention_heads=4, - use_cpu_initialization=True, + num_layers=2, hidden_size=256, num_attention_heads=4, use_cpu_initialization=True ) model = HybridModel( config=model_config, diff --git a/tests/unit_tests/ssm/test_hybrid_block.py b/tests/unit_tests/ssm/test_hybrid_block.py index adaf7479017..1ac11d2ed4f 100644 --- a/tests/unit_tests/ssm/test_hybrid_block.py +++ b/tests/unit_tests/ssm/test_hybrid_block.py @@ -244,15 +244,10 @@ def test_group_forward_matches_equivalent_flat_layers(self): sequence_length = 16 micro_batch_size = 2 hidden_states = torch.randn( - sequence_length, - micro_batch_size, - 
flat_block.config.hidden_size, - device="cuda", + sequence_length, micro_batch_size, flat_block.config.hidden_size, device="cuda" ) attention_mask = torch.ones( - (micro_batch_size, 1, sequence_length, sequence_length), - dtype=bool, - device="cuda", + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool, device="cuda" ) with torch.no_grad(): diff --git a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py index 896fb0bddb4..4b721609804 100644 --- a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py +++ b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py @@ -373,14 +373,7 @@ def test_moe_pattern(self): assert get_hybrid_layer_counts("MEME") == {'*': 0, 'D': 0, 'G': 0, 'M': 2, '-': 0, 'E': 2} def test_group_pattern(self): - assert get_hybrid_layer_counts("M[M*]E") == { - '*': 1, - 'D': 0, - 'G': 0, - 'M': 2, - '-': 0, - 'E': 1, - } + assert get_hybrid_layer_counts("M[M*]E") == {'*': 1, 'D': 0, 'G': 0, 'M': 2, '-': 0, 'E': 1} def test_mtp_with_attention(self): # MTP pattern "*M" repeated 3 depths -> 3 attn + 3 mamba from MTP From d0586c64c9bcc7cfc0b5426cd60bc5a385d55827 Mon Sep 17 00:00:00 2001 From: Pingtian Li Date: Thu, 14 May 2026 10:25:57 +0800 Subject: [PATCH 16/16] format --- megatron/core/models/common/utils.py | 5 ++++- megatron/core/models/hybrid/fine_grained_callables.py | 1 + megatron/core/models/hybrid/model_chunk_schedule_plan.py | 4 +++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/megatron/core/models/common/utils.py b/megatron/core/models/common/utils.py index b231678ba5a..59bf2e40423 100644 --- a/megatron/core/models/common/utils.py +++ b/megatron/core/models/common/utils.py @@ -134,6 +134,7 @@ def __init__(self, model, chunk_state, event, stream): self.chunk_state = chunk_state def forward_impl(self): + """Run model preprocessing and store chunk-level inputs for layer nodes.""" if not self.model.pre_process: self.chunk_state.decoder_input = 
self.model.decoder.input_tensor ( @@ -177,6 +178,7 @@ def __init__(self, model, chunk_state, event, stream): self.chunk_state = chunk_state def forward_impl(self, hidden_states): + """Run model postprocessing for the chunk's final hidden states.""" empty_decoder = len(self.model.decoder.layers) == 0 layer_norm = self.model.decoder.final_layernorm if not self.model.config.mtp_num_layers and empty_decoder and layer_norm: @@ -323,7 +325,7 @@ def __del__(self): class _BackwardDWWrapper: - """Backward weight-gradient wrapper for the ``pre_dispatch_computation`` slot of a transformer layer. + """Backward weight-gradient wrapper for a transformer pre-dispatch slot. Runs the layer's ``self_attention.backward_dw`` plus, on MoE layers, the shared-expert ``backward_dw``; coordinates with the cuda-graph wgrad @@ -352,6 +354,7 @@ def __init__(self, layer): self.cuda_graph_scope = layer.config.cuda_graph_scope def backward_dw(self): + """Run eager or graphed backward wgrad callables for the wrapped layer.""" is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs if self.shared_expert_dw_callable is not None and ( not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py index c451b377e93..0ea38ef6d34 100644 --- a/megatron/core/models/hybrid/fine_grained_callables.py +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -39,6 +39,7 @@ def __init__(self, layer): ) def backward_dw(self): + """Run shared-expert backward wgrad after pre-dispatch autograd backward.""" if self.shared_expert_dw_callable is not None: self.shared_expert_dw_callable() self.layer = None diff --git a/megatron/core/models/hybrid/model_chunk_schedule_plan.py b/megatron/core/models/hybrid/model_chunk_schedule_plan.py index 3fcc405846a..195895eee02 100644 --- a/megatron/core/models/hybrid/model_chunk_schedule_plan.py +++ 
b/megatron/core/models/hybrid/model_chunk_schedule_plan.py @@ -98,6 +98,7 @@ def create_node(stream, module, name): self.mtp_post_process = NoopScheduleNode() def get_fp8_context(self): + """Return an FP8 context only for plain transformer layers.""" # Grouped hybrid layers (and inferred-layer-type entries that point at # a HybridStack rather than a plain TransformerLayer) don't have a # ``layer_number`` we can hand to ``get_fp8_context``; the inner layers @@ -121,7 +122,8 @@ class HybridStackModelChunkSchedulePlan(TransformerModelChunkSchedulePlan): LAYER_SCHEDULE_PLAN_CLASS = HybridStackSchedulePlan - def init(self, model, *args, **kwargs): + def __init__(self, model, *args, **kwargs): + """Initialize the hybrid chunk plan after validating cuda graph support.""" assert model.config.cuda_graph_impl == "none", ( "EP A2A overlap with grouped HybridStack patterns (e.g. '[*E]') does not " "support cuda graphs yet. Set cuda_graph_impl='none' or use an ungrouped pattern."