diff --git a/megatron/core/models/common/fine_grained_callables.py b/megatron/core/models/common/fine_grained_callables.py new file mode 100644 index 00000000000..52a9d738eee --- /dev/null +++ b/megatron/core/models/common/fine_grained_callables.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Layer-callable builders for the combined-1F1B fine-grained schedule plan. + +These build_* functions assemble the per-layer ``(forward_funcs, backward_dw)`` +tuple that the schedule plan plugs into ``TransformerLayerNode``. + +The TransformerLayer-specific builder lives in ``gpt/fine_grained_callables.py`` +because it depends on GPT's MoE wiring; the MTP builder and the dispatcher +``build_layer_callables`` are model-agnostic — both GPTModel and HybridModel +schedule MTP layers identically — so they live here. +""" + +from contextlib import nullcontext +from functools import partial + +import torch + +from megatron.core import tensor_parallel +from megatron.core.models.gpt.fine_grained_callables import build_transformer_layer_callables +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.multi_token_prediction import ( + MultiTokenPredictionLayer, + get_mtp_layer_offset, +) +from megatron.core.transformer.transformer_layer import TransformerLayer + + +def build_mtp_layer_callables(layer): + """Callables for multi-token prediction layer nodes. + + Wraps the inner ``layer.mtp_model_layer``'s callables with MTP-specific + pre-process (chunk and concat embeddings) and post-process (gather across + depths) steps. The inner layer is built by ``build_layer_callables`` so + that ``mtp_model_layer`` can be a TransformerLayer (today's case) or a + HybridStack (when an MTP depth uses the hybrid layout). 
+ """ + + forward_funcs, backward_dw, is_moe, num_local_experts = build_layer_callables( + layer.mtp_model_layer + ) + (pre_dispatch_forward, dispatch_forward, mlp_forward, combine_forward, _) = forward_funcs + assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now." + + def submodule_mtp_pre_dispatch_forward(node, hidden_states): + # MTP Block Preprocess + if node.is_first_layer: + offset = get_mtp_layer_offset(layer.config, node.chunk_state.model.vp_stage) + node.chunk_state.mtp_hidden_states = list(torch.chunk(hidden_states, 1 + offset, dim=0)) + hidden_states = node.chunk_state.mtp_hidden_states[offset] + + input_ids, position_ids, decoder_input, hidden_states = layer._get_embeddings( + input_ids=node.chunk_state.input_ids, + position_ids=node.chunk_state.position_ids, + embedding=node.chunk_state.model.embedding, + hidden_states=hidden_states, + ) + node.chunk_state.input_ids = input_ids + node.chunk_state.position_ids = position_ids + + # MTP Layer Preprocess + # norm, linear projection and transformer + assert ( + node.chunk_state.context is None + ), f"multi token prediction + cross attention is not yet supported." + assert ( + node.chunk_state.packed_seq_params is None + ), f"multi token prediction + sequence packing is not yet supported." 
+ + if layer.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # fp8 context is added in 1f1b schedule, so we don't need to add it here + with rng_context: + hidden_states = layer._concat_embeddings(hidden_states, decoder_input) + return pre_dispatch_forward(node, hidden_states) + + def submodule_mtp_postprocess_forward(node, hidden_states): + hidden_states = layer._postprocess(hidden_states) + node.chunk_state.mtp_hidden_states.append(hidden_states) + if node.is_last_layer: + hidden_states = torch.cat(node.chunk_state.mtp_hidden_states, dim=0) + node.chunk_state.mtp_hidden_states = None + return hidden_states + + def rng_context_wrapper(func, *args, **kwargs): + """ + Wrapper to add rng context to submodule callables + """ + if layer.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + with rng_context: + return func(*args, **kwargs) + + # Build forward and backward callable functions. + # pre_dispatch_func already has rng context (rolled into + # submodule_mtp_pre_dispatch_forward), so it does not need to be wrapped. + pre_dispatch_func = submodule_mtp_pre_dispatch_forward + dispatch_func = partial(rng_context_wrapper, dispatch_forward) + mlp_func = partial(rng_context_wrapper, mlp_forward) + combine_func = partial(rng_context_wrapper, combine_forward) + mtp_post_process_func = submodule_mtp_postprocess_forward + + forward_funcs = [ + pre_dispatch_func, + dispatch_func, + mlp_func, + combine_func, + mtp_post_process_func, + ] + pre_dispatch_bwd = backward_dw["pre_dispatch_computation"] + if isinstance(pre_dispatch_bwd, list): + pre_dispatch_bwd.append(layer.eh_proj) + else: + backward_dw["pre_dispatch_computation"] = [pre_dispatch_bwd, layer.eh_proj] + + return forward_funcs, backward_dw, is_moe, num_local_experts + + +def build_layer_callables(layer): + """Dispatch to the appropriate layer-callable builder. 
+ + Returns ``(forward_funcs, backward_dw, is_moe, num_local_experts)`` so the + schedule plan does not need to re-derive ``is_moe`` / + ``num_local_experts`` after the call — the build function already knows + the layer type. ``num_local_experts`` is ``None`` for dense layers. + """ + from megatron.core.models.hybrid.fine_grained_callables import build_hybrid_stack_callables + from megatron.core.models.hybrid.hybrid_block import HybridStack + + if isinstance(layer, HybridStack): + return build_hybrid_stack_callables(layer) + if isinstance(layer, MultiTokenPredictionLayer): + return build_mtp_layer_callables(layer) + if isinstance(layer, TransformerLayer): + forward_funcs, backward_dw = build_transformer_layer_callables(layer) + is_moe = isinstance(layer.mlp, MoELayer) + num_local_experts = layer.mlp.num_local_experts if is_moe else None + return forward_funcs, backward_dw, is_moe, num_local_experts + + raise ValueError(f"Unsupported layer type: {type(layer)}") diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 9032d337e00..a2d167e60a9 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -35,20 +35,23 @@ class TransformerLayerSchedulePlan: mtp post process nodes. 
layer (TransformerLayerSchedulePlan) - ├── attn (TransformerLayerNode): attention -> layernorm -> router -> dispatch preprocess + ├── pre_dispatch_computation (TransformerLayerNode): + │ attention -> layernorm -> router -> dispatch preprocess ├── moe_dispatch (TransformerLayerNode): dispatch All2All ├── mlp (TransformerLayerNode): mlp module ├── moe_combine (TransformerLayerNode): combine All2All └── mtp_post_process (PostProcessNode): mtp post process Note that MTP layer has the same operation and execution order with TransformerLayer regarding - moe_dispatch, mlp, moe_combine, but contains extra operations in attn and mtp_post_process: - * mtp.attn wraps around transformer_layer.attn with extra norm, proj and embedding operations. + moe_dispatch, mlp, moe_combine, but contains extra operations in + pre_dispatch_computation and mtp_post_process: + * mtp.pre_dispatch_computation wraps around transformer_layer.pre_dispatch_computation with + extra norm, proj and embedding operations. * mtp.mtp_post_process contains output_layer, mtp loss operations, whereas transformer_layer.mtp_post_process is empty. """ - attn = None + pre_dispatch_computation = None moe_dispatch = None mlp = None moe_combine = None @@ -70,10 +73,10 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar The event and chunk_state are binded to the TransformerModelChunkSchedulePlan and shared across all layers in the model chunk. 
""" - from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState + from megatron.core.models.common.utils import LayerState self.config = layer.config - self.layer_state = TransformerLayerState() + self.layer_state = LayerState() self.chunk_state = chunk_state self.layer = layer self.event = event @@ -85,9 +88,9 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar def release_state(self): """Release reference, this helps avoid memory leak.""" - if hasattr(self, 'attn') and self.attn is not None: - del self.attn - self.attn = None + if hasattr(self, 'pre_dispatch_computation') and self.pre_dispatch_computation is not None: + del self.pre_dispatch_computation + self.pre_dispatch_computation = None if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None: del self.moe_dispatch self.moe_dispatch = None @@ -109,23 +112,20 @@ def release_state(self): def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): """ Builds the callable nodes for the transformer/mtp layer: - attn, mlp, moe_dispatch and moe_combine, and mtp_post_process. + pre_dispatch_computation, moe_dispatch, mlp, moe_combine, + and mtp_post_process. """ - from megatron.core.models.gpt.fine_grained_callables import ( - TransformerLayerNode, - build_layer_callables, - ) - from megatron.core.transformer.moe.moe_layer import MoELayer + from megatron.core.models.common.fine_grained_callables import build_layer_callables + from megatron.core.models.common.utils import TransformerLayerNode from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer - # build the forward and backward callables for the transformer/mtp layer - fwd_callables, bwd_dw_callable_map = build_layer_callables(self.layer) + # The dispatcher returns is_moe / num_local_experts directly since it + # already knows the layer type (saves a separate isinstance dance here). 
+ fwd_callables, bwd_dw_callable_map, is_moe, num_local_experts = build_layer_callables( + self.layer + ) - # get flags for latter use is_mtp = isinstance(self.layer, MultiTokenPredictionLayer) - transformer_layer = self.layer.mtp_model_layer if is_mtp else self.layer - is_moe = isinstance(transformer_layer.mlp, MoELayer) - num_local_experts = transformer_layer.mlp.num_local_experts if is_moe else None extra_args["config"] = self.layer.config extra_args["is_moe"] = is_moe @@ -148,7 +148,7 @@ def create_node(stream, module, name): ) ( - attn_module, + pre_dispatch_module, moe_dispatch_module, mlp_module, moe_combine_module, @@ -157,7 +157,9 @@ def create_node(stream, module, name): # Create nodes for different operations in the layer # Each node type has a predefined name that determines its memory strategy - self.attn = create_node(comp_stream, attn_module, "attn") + self.pre_dispatch_computation = create_node( + comp_stream, pre_dispatch_module, "pre_dispatch_computation" + ) self.mlp = create_node(comp_stream, mlp_module, "mlp") if is_moe: self.moe_dispatch = create_node(comm_stream, moe_dispatch_module, "moe_dispatch") @@ -191,12 +193,12 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) """Schedule one-forward-one-backward operations for a single transformer layer. This function interleaves forward and backward operations, overlapping the communications - (dispatch or combine) of one with the computations (att or mlp) of the other + (dispatch or combine) of one with the computations (pre_dispatch or mlp) of the other to maximize parallelism and efficiency. 
When f_layer and b_layer are not None, forward and backward pass are overlapped as follows: - comm_stream: combine_bwd | dispatch_fwd->dispatch_bwd | combine_fwd - comp_stream: attn_fwd | mlp_bwd->mlp_bwd_dw->mlp_fwd| attn_bwd + comm_stream: combine_bwd | dispatch_fwd->dispatch_bwd | combine_fwd + comp_stream: pre_dispatch_fwd | mlp_bwd->mlp_bwd_dw->mlp_fwd| pre_dispatch_bwd For MTP, mtp_post_process_fwd is executed after the combine_fwd in the comp_stream, and mtp_post_process_bwd is executed before the combine_bwd in the comp_stream. @@ -218,7 +220,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) if f_layer is not None: with f_layer.get_fp8_context(): - f_input = f_layer.attn.forward(f_input) + f_input = f_layer.pre_dispatch_computation.forward(f_input) if b_layer is not None: b_grad = b_layer.mlp.backward(b_grad) @@ -232,7 +234,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) b_grad = b_layer.moe_dispatch.backward(b_grad) if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release: - b_grad = b_layer.attn.backward(b_grad) + b_grad = b_layer.pre_dispatch_computation.backward(b_grad) if f_layer is not None: with f_layer.get_fp8_context(): @@ -244,12 +246,12 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False) f_input = f_layer.mtp_post_process.forward(f_input) if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release: - b_grad = b_layer.attn.backward(b_grad) + b_grad = b_layer.pre_dispatch_computation.backward(b_grad) - # Delay the last attn_dw in backward pass (attn_dw of the first layer) - # for overlapping with the p2p comm + # Delay the last pre_dispatch_computation wgrad in backward pass (wgrad + # of the first layer) for overlapping with the p2p comm. 
if b_layer is not None and not is_last_layer_in_bwd: - b_layer.attn.backward_dw() + b_layer.pre_dispatch_computation.backward_dw() return f_input, b_grad @@ -267,8 +269,27 @@ class TransformerModelChunkSchedulePlan(AbstractSchedulePlan): │ ├── layer[1]: TransformerLayerSchedulePlan │ └── ... └── post_process: PostProcessNode + + Subclasses can swap the per-layer schedule plan by overriding the + ``LAYER_SCHEDULE_PLAN_CLASS`` class attribute (e.g. HybridStack uses a + layer plan that understands grouped/inferred layer types). They can also + swap the pre/post-process node classes via ``PRE_PROCESS_NODE_CLASS`` / + ``POST_PROCESS_NODE_CLASS`` so each model owns its own embedding / output + layer node implementations. """ + #: The TransformerLayerSchedulePlan-compatible class used to build per-layer + #: schedule plans. Subclasses override this to inject a layer-plan variant. + LAYER_SCHEDULE_PLAN_CLASS = None + + #: Pre/post-process node classes. Defaults below pull in the GPT-side + #: ``PreProcessNode`` / ``PostProcessNode`` (which call ``GPTModel._preprocess`` / + #: ``GPTModel._postprocess``). Subclasses set these to model-specific node + #: classes so the node calls the right model's ``_preprocess`` / + #: ``_postprocess`` methods. + PRE_PROCESS_NODE_CLASS = None + POST_PROCESS_NODE_CLASS = None + def __init__( self, model, @@ -303,7 +324,10 @@ def __init__( Returns: The model chunk schedule plan. 
""" - from megatron.core.models.gpt.fine_grained_callables import PostProcessNode, PreProcessNode + from megatron.core.models.common.utils import PostProcessNode, PreProcessNode + + pre_process_cls = self.PRE_PROCESS_NODE_CLASS or PreProcessNode + post_process_cls = self.POST_PROCESS_NODE_CLASS or PostProcessNode self._model_chunk_state = ModelChunkState() self._transformer_layers = [] @@ -330,7 +354,7 @@ def __init__( self._model_chunk_state.attention_bias = None # build preprocess - self.pre_process = PreProcessNode( + self.pre_process = pre_process_cls( model, self._model_chunk_state, self._event, get_comp_stream ) @@ -344,20 +368,18 @@ def __init__( # build post process if model.post_process: - self.post_process = PostProcessNode( + self.post_process = post_process_cls( model, self._model_chunk_state, self._event, get_comp_stream ) def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): if module is None: return + plan_cls = self.LAYER_SCHEDULE_PLAN_CLASS or TransformerLayerSchedulePlan num_layers = len(module.layers) for layer_idx in range(num_layers): - extra_args = { - "is_first_layer": layer_idx == 0, - "is_last_layer": layer_idx == num_layers - 1, - } - layer_plan = TransformerLayerSchedulePlan( + extra_args = self._extra_args_for_layer(module, layer_idx, num_layers) + layer_plan = plan_cls( module.layers[layer_idx], self.event, self.state, @@ -367,6 +389,14 @@ def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): ) self._transformer_layers.append(layer_plan) + def _extra_args_for_layer(self, module, layer_idx, num_layers): + """Per-layer ``extra_args`` dict passed to the layer plan constructor. + + Subclasses extend this hook to thread additional metadata (e.g. hybrid + layer-type symbols) without overriding ``_build_layer_schedule_plan``. 
+ """ + return {"is_first_layer": layer_idx == 0, "is_last_layer": layer_idx == num_layers - 1} + @property def event(self): """Gets the CUDA event for synchronization.""" @@ -509,22 +539,22 @@ def run( if f_schedule_plan is not None and post_forward is not None: # post_forward()/send_forward_recv_forward() is running in the communication stream, - # so the p2p comm could be overlapped with the attn backward + # so the p2p comm could be overlapped with the pre_dispatch backward with torch.cuda.stream(get_comm_stream()): f_schedule_plan.wait_current_stream() post_forward(f_input, f_schedule_plan.vp_stage) # post_backward()/send_backward_recv_backward() is running in the computation stream, - # so the p2p comm could be overlapped with the wgrad of attn backward + # so the p2p comm could be overlapped with the wgrad of pre_dispatch backward if b_schedule_plan is not None and post_backward is not None: b_schedule_plan.wait_current_stream() post_backward(b_grad, b_schedule_plan.vp_stage) - # Delay the last attn_dw in backward pass (attn_dw of the first layer) - # for overlapping with the p2p comm + # Delay the last pre_dispatch_computation wgrad in backward pass (wgrad + # of the first layer) for overlapping with the p2p comm. if b_num_layers > 0: assert b_layer is not None - b_layer.attn.backward_dw() + b_layer.pre_dispatch_computation.backward_dw() b_layer.release_state() # post process forward diff --git a/megatron/core/models/common/utils.py b/megatron/core/models/common/utils.py new file mode 100644 index 00000000000..59bf2e40423 --- /dev/null +++ b/megatron/core/models/common/utils.py @@ -0,0 +1,371 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Schedule-plan helpers shared by GPTModel and HybridModel. + +These pieces used to live in ``core/models/gpt/fine_grained_callables.py`` and +were imported by ``core/models/common/model_chunk_schedule_plan.py`` and the +hybrid schedule plan via that path. 
They are model-agnostic in practice — the +``Pre/PostProcessNode`` classes call the model's ``_preprocess`` / +``_postprocess`` methods and don't otherwise care which model implements +them — so they live here and the GPT module re-exports them for backward +compatibility with existing imports. +""" + +import weakref +from functools import partial +from typing import Callable + +import torch + +from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless +from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.module import GraphableMegatronModule, float16_to_fp32 +from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor +from megatron.core.utils import internal_api, nvtx_range_pop, nvtx_range_push + + +def weak_method(method): + """Wrap ``method`` in a weakref-keyed dispatcher to break refcycles. + + ``ScheduleNode`` keeps a reference to the bound forward / backward functions + of every node in the plan; using a strong reference would keep the layer + plan (and the model chunk through it) alive after the iteration completes. + The ``weakref.WeakMethod`` lets the schedule plan be torn down between + iterations without manual ``del`` chains. + """ + method_ref = weakref.WeakMethod(method) + del method + + def wrapped_func(*args, **kwarg): + return method_ref()(*args, **kwarg) + + return wrapped_func + + +@internal_api +def should_free_input(name, is_moe, config, num_local_experts): + """Whether the schedule node named ``name`` can free its input after forward. + + The schedule decomposes a transformer layer into ``pre_dispatch_computation``, + ``moe_dispatch``, ``mlp``, and ``moe_combine`` nodes; the inputs to some of + those nodes are not needed in backward and can be released early to lower + peak activation memory. Dense layers and the ``pre_dispatch_computation`` + node always need their input retained (the attention residual flows through + the post-MLP BDA). 
+ + Args: + name: Schedule node name. + is_moe: True for MoE layers; dense layers always retain inputs. + config: ``TransformerConfig`` for the layer. + num_local_experts: Local expert count on this rank (None for dense). + + Returns: + True iff the named node may free its input after forward. + """ + # For dense layers [pre_dispatch_computation, fake, mlp, fake], the input is needed + # during backward pass + if not is_moe: + return False + enable_deepep = ( + config.moe_token_dispatcher_type == "flex" + and config.moe_flex_dispatcher_backend == "deepep" + ) + enable_hybridep = ( + config.moe_token_dispatcher_type == "flex" + and config.moe_flex_dispatcher_backend == "hybridep" + ) + # Define which nodes should free input memory. + # Since we split the computing graph into multiple nodes, we can manually control + # when and how to free the input memory. + # The input and output of A2A are not needed anymore after the forward pass, + # so we can free the input memory after the forward pass. + + # When low precision fp8/4 is enabled, the casted tensors are saved and the + # original bf16 tensors are safe to be freed. + free_mlp = config.fp8 is not None or config.fp4 is not None + if not free_mlp: + # AlltoAll dispatcher with local_num_experts=1 and HybridEP both use identity + # operation for `dispatch_postprocess`, hence the mlp inputs will be directly + # passed to GroupedGemm and should be saved for backward pass. + free_mlp = num_local_experts > 1 or config.moe_token_dispatcher_type != "alltoall" + free_mlp = free_mlp and not enable_hybridep + + free_input_nodes = { + "mlp": free_mlp, + "moe_combine": True, + # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched + # tokens and probs before dispatch A2A and it's not needed anymore after the + # forward pass. For DeepEP and HybridEP dispatcher mode, they are both needed in + # backward pass and cannot be freed. 
+ # If moe_preprocess is in cuda graph scope, tokens and probs are fixed size + # tensors, so they cannot be freed. + "moe_dispatch": not (enable_deepep or enable_hybridep) + and (CudaGraphScope.moe_preprocess not in config.cuda_graph_scope), + } + + return free_input_nodes.get(name, False) + + +class LayerState: + """State shared between the schedule nodes that come from one logical layer. + + Empty placeholder; nodes attach their own attributes (residual, dispatched + probs, shared-expert outputs) for downstream nodes in the same layer to + consume. Kept as a real class so weakrefs work uniformly. + """ + + pass + + +class PreProcessNode(ScheduleNode): + """Run the model's ``_preprocess`` (embedding + rotary + padding mask). + + The schedule plan wraps a model that exposes a ``_preprocess`` method + returning the canonical 6-tuple ``(decoder_input, rotary_pos_emb, + rotary_pos_cos, rotary_pos_sin, sequence_len_offset, padding_mask)`` + (slots a given model doesn't use are returned as ``None``). The chunk + state is mutated in-place so layer nodes can read the same fields by + name. 
+ """ + + def __init__(self, model, chunk_state, event, stream): + super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process") + self.model = model + self.chunk_state = chunk_state + + def forward_impl(self): + """Run model preprocessing and store chunk-level inputs for layer nodes.""" + if not self.model.pre_process: + self.chunk_state.decoder_input = self.model.decoder.input_tensor + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = self.model._preprocess( + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + decoder_input=self.chunk_state.decoder_input, + packed_seq_params=self.chunk_state.packed_seq_params, + padding_mask=self.chunk_state.padding_mask, + ) + + self.chunk_state.decoder_input = decoder_input + self.chunk_state.rotary_pos_emb = rotary_pos_emb + self.chunk_state.rotary_pos_cos = rotary_pos_cos + self.chunk_state.rotary_pos_sin = rotary_pos_sin + self.chunk_state.sequence_len_offset = sequence_len_offset + self.chunk_state.padding_mask = padding_mask + return decoder_input + + +class PostProcessNode(ScheduleNode): + """Run the model's ``_postprocess`` (final norm, output layer, loss). + + Calls ``_postprocess`` with ``mtp_in_postprocess=False`` because the + schedule plan handles MTP layers as sibling layer nodes inside the same + chunk; the model's MTP block is not invoked here. The optional final + layernorm — applied only when this rank holds an empty decoder shard + (early stage of pipeline parallel) — is handled here so the chunk plan + does not need a separate node for it. 
+ """ + + def __init__(self, model, chunk_state, event, stream): + super().__init__(weak_method(self.forward_impl), stream, event, name="post_process") + self.model = model + self.chunk_state = chunk_state + + def forward_impl(self, hidden_states): + """Run model postprocessing for the chunk's final hidden states.""" + empty_decoder = len(self.model.decoder.layers) == 0 + layer_norm = self.model.decoder.final_layernorm + if not self.model.config.mtp_num_layers and empty_decoder and layer_norm: + hidden_states = layer_norm(hidden_states) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + loss = self.model._postprocess( + hidden_states=hidden_states, + input_ids=self.chunk_state.input_ids, + position_ids=self.chunk_state.position_ids, + labels=self.chunk_state.labels, + decoder_input=self.chunk_state.decoder_input, + rotary_pos_emb=self.chunk_state.rotary_pos_emb, + rotary_pos_cos=self.chunk_state.rotary_pos_cos, + rotary_pos_sin=self.chunk_state.rotary_pos_sin, + mtp_in_postprocess=False, + loss_mask=self.chunk_state.loss_mask, + attention_mask=self.chunk_state.attention_mask, + packed_seq_params=self.chunk_state.packed_seq_params, + sequence_len_offset=self.chunk_state.sequence_len_offset, + runtime_gather_output=self.chunk_state.runtime_gather_output, + extra_block_kwargs=self.chunk_state.extra_block_kwargs, + ) + + # combined-1F1B currently expects fp32 loss output. + return float16_to_fp32(loss) + + +class TransformerLayerNode(ScheduleNode): + """Schedule node for one slot of a fine-grained transformer layer plan. + + Each transformer layer is decomposed into ``pre_dispatch_computation``, + ``moe_dispatch``, ``mlp``, and ``moe_combine`` slots; this class is the scheduler-side + handle for one slot. It owns the slot's stream / event, the per-slot + ``free_input`` policy, and the optional delayed weight-gradient hook. 
+ Subclasses override ``_resolve_free_input`` to specialize the policy + (HybridStackNode does this for grouped layers). + """ + + def __init__( + self, + stream, + event, + layer_state, + chunk_state, + submodule, + name="default", + bwd_dw_callables=None, + extra_args={}, + ): + config = extra_args.get("config", None) + assert config is not None, "model config must be passed to TransformerLayerNode." + is_moe = extra_args.get("is_moe", False) + num_local_experts = extra_args.get("num_local_experts", None) + free_input = self._resolve_free_input(name, is_moe, config, num_local_experts) + self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False) + + super().__init__( + weak_method(self.forward_impl), + stream, + event, + weak_method(self.backward_impl), + free_input=free_input, + name=name, + ) + self.layer_state = layer_state + self.chunk_state = chunk_state + self.submodule = submodule + self.detached = tuple() + self.before_detached = tuple() + self.is_mtp = extra_args.get("is_mtp", False) + + self.is_first_layer = extra_args.get("is_first_layer", False) + self.is_last_layer = extra_args.get("is_last_layer", False) + + self.bwd_dw_callables = [] + if bwd_dw_callables is not None: + self.bwd_dw_callables = ( + bwd_dw_callables if isinstance(bwd_dw_callables, list) else [bwd_dw_callables] + ) + + @staticmethod + def _resolve_free_input(name, is_moe, config, num_local_experts): + """Free-input policy hook. 
Subclasses override to specialize.""" + return should_free_input(name, is_moe, config, num_local_experts) + + def detach(self, t): + """Detach a tensor and remember it for backward through the schedule node.""" + detached = make_viewless(t).detach() + detached.requires_grad = t.requires_grad + self.before_detached = self.before_detached + (t,) + self.detached = self.detached + (detached,) + return detached + + def forward_impl(self, *args): + """Invoke the slot's submodule forward.""" + return self.submodule(self, *args) + + def backward_impl(self, outputs, output_grad): + """Run the slot's backward, holding output_grads when wgrad is delayed.""" + detached_grad = tuple([e.grad for e in self.detached]) + grads = output_grad + detached_grad + self.default_backward_func(outputs + self.before_detached, grads) + # Release the output grad memory after backward finishes, except when + # delay_wgrad_compute is enabled — then the grads are kept until every + # registered ``backward_dw`` callable has run. + if self.delay_wgrad_compute: + self.output_grads = grads + self.delay_grads_release = len(self.bwd_dw_callables) > 0 + + return grads + + def backward_dw(self): + """Run the slot's delayed weight-gradient callables on the slot's stream.""" + if not self.delay_wgrad_compute: + return + if isinstance(self.stream, Callable): + self.stream = self.stream() + with torch.cuda.stream(self.stream): + nvtx_msg = f"{self.name} wgrad" + nvtx_range_push(nvtx_msg) + for module in self.bwd_dw_callables: + module.backward_dw() + nvtx_range_pop(nvtx_msg) + + # The output grad memory is last used in wgrad compute; safe to release now. + assert self.delay_grads_release, "output grad memory should be valid before wgrad." + if self.manual_release_grads: + for tensor in self.output_grads: + tensor.untyped_storage().resize_(0) + self.output_grads = None + + self.bwd_dw_callables = None + + def __del__(self): + # Release references early to help avoid leaks across iterations. 
+ self.before_detached = None + self.detached = None + self.layer_state = None + self.chunk_state = None + self.submodule = None + + +class _BackwardDWWrapper: + """Backward weight-gradient wrapper for a transformer pre-dispatch slot. + + Runs the layer's ``self_attention.backward_dw`` plus, on MoE layers, the + shared-expert ``backward_dw``; coordinates with the cuda-graph wgrad + capture (``set_graphed_backward_dw_callable``) so that scopes covered by + the graph are not re-run eagerly. Used when + ``overlap_moe_expert_parallel_comm`` and ``delay_wgrad_compute`` are both + enabled. + """ + + def __init__(self, layer): + assert isinstance( + layer, GraphableMegatronModule + ), "cuda graphed ep overlap only supports GraphableMegatronModule." + assert isinstance( + layer, TransformerLayer + ), "cuda graphed ep overlap only supports TransformerLayer for now." + self.layer = layer + self.graphed_backward_dw_callable = None + self.attn_dw_callable = layer.self_attention.backward_dw + if layer.is_moe_layer: + self.shared_expert_dw_callable = partial( + layer.mlp.backward_dw, routed_experts=False, shared_experts=True + ) + else: + self.shared_expert_dw_callable = None + self.cuda_graph_scope = layer.config.cuda_graph_scope + + def backward_dw(self): + """Run eager or graphed backward wgrad callables for the wrapped layer.""" + is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs + if self.shared_expert_dw_callable is not None and ( + not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope + ): + self.shared_expert_dw_callable() + if not is_replay or CudaGraphScope.attn not in self.cuda_graph_scope: + self.attn_dw_callable() + if is_replay and self.graphed_backward_dw_callable is not None: + self.graphed_backward_dw_callable() + self.layer = None + + def set_graphed_backward_dw_callable(self, graphed_backward_dw_callable): + """Plug the cuda-graph backward wgrad replay callable.""" + self.graphed_backward_dw_callable = 
graphed_backward_dw_callable diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index fa2a2ec4934..09772351570 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -1,9 +1,6 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -import weakref -from contextlib import nullcontext -from functools import partial -from typing import Callable, Optional +from typing import Optional import torch from torch import Tensor @@ -13,397 +10,11 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless -from megatron.core.transformer.enums import CudaGraphScope -from megatron.core.transformer.module import GraphableMegatronModule, float16_to_fp32 +from megatron.core.pipeline_parallel.utils import ScheduleNode +from megatron.core.transformer.module import GraphableMegatronModule from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.multi_token_prediction import ( - MultiTokenPredictionLayer, - get_mtp_layer_offset, -) from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor from megatron.core.typed_torch import apply_module, copy_signature -from megatron.core.utils import internal_api, nvtx_range_pop, nvtx_range_push - - -def weak_method(method): - """Creates a weak reference to a method to prevent circular references. - - This function creates a weak reference to a method and returns a wrapper function - that calls the method when invoked. This helps prevent memory leaks from circular - references. 
- """ - method_ref = weakref.WeakMethod(method) - del method - - def wrapped_func(*args, **kwarg): - # nonlocal object_ref - return method_ref()(*args, **kwarg) - - return wrapped_func - - -@internal_api -def should_free_input(name, is_moe, config, num_local_experts): - """Determine if the node should free its input memory. - - Args: - name: Node name - is_moe: Whether it's a MoE model - config: TransformerConfig object - num_local_experts: Number of local experts in MoE module - - Returns: - bool: Whether to free input memory - """ - # For dense layers [attn, fake, mlp, fake], the input is needed during backward pass - if not is_moe: - return False - enable_deepep = ( - config.moe_token_dispatcher_type == "flex" - and config.moe_flex_dispatcher_backend == "deepep" - ) - enable_hybridep = ( - config.moe_token_dispatcher_type == "flex" - and config.moe_flex_dispatcher_backend == "hybridep" - ) - # Define which nodes should free input memory - # Since we split the computing graph into multiple nodes, we can manually control - # when and how to free the input memory. - # The input and output of A2A are not needed anymore after the forward pass, - # so we can free the input memory after the forward pass. - - # When low precision fp8/4 is enabled, the casted tensors are saved and the - # original bf16 tensors are safe to be freed. - free_mlp = config.fp8 is not None or config.fp4 is not None - if not free_mlp: - # AlltoAll dispatcher with local_num_experts=1 and HybridEP both use identity - # operation for `dispatch_postprocess`, hence the mlp inputs will be directly - # passed to GroupedGemm and should be saved for backward pass. 
- free_mlp = num_local_experts > 1 or config.moe_token_dispatcher_type != "alltoall" - free_mlp = free_mlp and not enable_hybridep - - free_input_nodes = { - "mlp": free_mlp, - "moe_combine": True, - # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens - # and probs before dispatch A2A and it's not needed anymore after the forward pass - # For DeepEP and HybridEP dispatcher mode, they are both needed in backward pass - # and cannot be freed. - # If moe_preprocess is in cuda graph scope, tokens and probs are fixed size tensors, - # so they cannot be freed. - "moe_dispatch": not (enable_deepep or enable_hybridep) - and (CudaGraphScope.moe_preprocess not in config.cuda_graph_scope), - } - - return free_input_nodes.get(name, False) - - -class TransformerLayerState: - """State shared within a transformer layer. - - This class holds state that is shared between different nodes - within a transformer layer. - """ - - pass - - -class PreProcessNode(ScheduleNode): - """Node responsible for preprocessing operations in the model. - - This node handles embedding and rotary positional embedding computations - before the main transformer layers. - """ - - def __init__(self, gpt_model, chunk_state, event, stream): - """Initializes a preprocessing node. - - Args: - gpt_model: The GPT model instance. - chunk_state (TransformerChunkState): State shared within a chunk - event: CUDA event for synchronization. - stream: CUDA stream for execution. - """ - super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process") - self.gpt_model = gpt_model - self.chunk_state = chunk_state - - def forward_impl(self): - """forward pass for pre-processing. - - This method handles: - 1. Decoder embedding computation - 2. Rotary positional embedding computation - 3. Sequence length offset computation for flash decoding - - Returns: - The processed decoder input tensor. 
- """ - # Get decoder input - if not self.gpt_model.pre_process: - self.chunk_state.decoder_input = self.gpt_model.decoder.input_tensor - # Run GPTModel._preprocess - ( - decoder_input, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - padding_mask, - ) = self.gpt_model._preprocess( - input_ids=self.chunk_state.input_ids, - position_ids=self.chunk_state.position_ids, - decoder_input=self.chunk_state.decoder_input, - packed_seq_params=self.chunk_state.packed_seq_params, - padding_mask=self.chunk_state.padding_mask, - ) - - # Saved for later use - self.chunk_state.decoder_input = decoder_input - self.chunk_state.rotary_pos_emb = rotary_pos_emb - self.chunk_state.rotary_pos_cos = rotary_pos_cos - self.chunk_state.rotary_pos_sin = rotary_pos_sin - self.chunk_state.sequence_len_offset = sequence_len_offset - self.chunk_state.padding_mask = padding_mask - return decoder_input - - -class PostProcessNode(ScheduleNode): - """Node responsible for postprocessing operations in the model. - - This node handles final layer normalization and output layer computation - after the main transformer layers. - """ - - def __init__(self, gpt_model, chunk_state, event, stream): - """Initializes a postprocessing node. - - Args: - gpt_model: The GPT model instance. - chunk_state (TransformerChunkState): State shared within a chunk - event: CUDA event for synchronization. - stream: CUDA stream for execution. - """ - super().__init__(weak_method(self.forward_impl), stream, event, name="post_process") - self.gpt_model = gpt_model - self.chunk_state = chunk_state - - def forward_impl(self, hidden_states): - """Implements the forward pass for postprocessing. - - This method handles: - 1. Output layer computation - 2. Loss computation if labels are provided - - Args: - hidden_states: The hidden states from the transformer layers. - - Returns: - The logits or loss depending on whether labels are provided. 
- """ - - empty_decoder = len(self.gpt_model.decoder.layers) == 0 - layer_norm = self.gpt_model.decoder.final_layernorm - if not self.gpt_model.config.mtp_num_layers and empty_decoder and layer_norm: - hidden_states = layer_norm(hidden_states) - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True - ) - - # Run GPTModel._postprocess - loss = self.gpt_model._postprocess( - hidden_states=hidden_states, - input_ids=self.chunk_state.input_ids, - position_ids=self.chunk_state.position_ids, - labels=self.chunk_state.labels, - decoder_input=self.chunk_state.decoder_input, - rotary_pos_emb=self.chunk_state.rotary_pos_emb, - rotary_pos_cos=self.chunk_state.rotary_pos_cos, - rotary_pos_sin=self.chunk_state.rotary_pos_sin, - mtp_in_postprocess=False, - loss_mask=self.chunk_state.loss_mask, - attention_mask=self.chunk_state.attention_mask, - packed_seq_params=self.chunk_state.packed_seq_params, - sequence_len_offset=self.chunk_state.sequence_len_offset, - runtime_gather_output=self.chunk_state.runtime_gather_output, - extra_block_kwargs=self.chunk_state.extra_block_kwargs, - ) - - # For now, 1f1b only supports fp16 module - return float16_to_fp32(loss) - - -class TransformerLayerNode(ScheduleNode): - """Base class for transformer layer computation nodes. - - This class provides common functionality for different types of - transformer layer nodes (attention, MLP, etc.) - """ - - def __init__( - self, - stream, - event, - layer_state, - chunk_state, - submodule, - name="default", - bwd_dw_callables=None, - extra_args={}, - ): - """Initialize a transformer layer node. - - Args: - stream (torch.cuda.Stream): CUDA stream for execution - event (torch.cuda.Event): Synchronization event - layer_state (TransformerLayerState): State shared within a layer - chunk_state (TransformerChunkState): State shared within a chunk - submodule (function): The submodule contain forward and dw function - it's the per_batch_state_context, o.w. 
nullcontext - name (str): Node name, also used to determine memory strategy - bwd_dw_callables (list): List of weight gradient functions for the layer. - extra_args (dict): Extra arguments for the node: is_moe, config. - """ - # determine whether to free input memory - config = extra_args.get("config", None) - assert config is not None, "model config must be passed to TransformerLayerNode." - is_moe = extra_args.get("is_moe", False) - num_local_experts = extra_args.get("num_local_experts", None) - free_input = should_free_input(name, is_moe, config, num_local_experts) - self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False) - - super().__init__( - weak_method(self.forward_impl), - stream, - event, - weak_method(self.backward_impl), - free_input=free_input, - name=name, - ) - self.layer_state = layer_state - self.chunk_state = chunk_state - self.submodule = submodule - self.detached = tuple() - self.before_detached = tuple() - self.is_mtp = extra_args.get("is_mtp", False) - - # Create flags to indicate first and last layer - self.is_first_layer = extra_args.get("is_first_layer", False) - self.is_last_layer = extra_args.get("is_last_layer", False) - - # Initialize list to store registered dw callables - self.bwd_dw_callables = [] - if bwd_dw_callables is not None: - self.bwd_dw_callables = ( - bwd_dw_callables if isinstance(bwd_dw_callables, list) else [bwd_dw_callables] - ) - - def detach(self, t): - """Detaches a tensor and stores it for backward computation.""" - detached = make_viewless(t).detach() - detached.requires_grad = t.requires_grad - self.before_detached = self.before_detached + (t,) - self.detached = self.detached + (detached,) - return detached - - def forward_impl(self, *args): - """Calls the submodule as the forward pass.""" - return self.submodule(self, *args) - - def backward_impl(self, outputs, output_grad): - """Implements the backward pass for the transformer layer node.""" - detached_grad = tuple([e.grad for e in 
self.detached]) - grads = output_grad + detached_grad - self.default_backward_func(outputs + self.before_detached, grads) - # release the output grad memory after backward finishes, - # except when delay_wgrad_comptue is enabled, the grad should be - # kept until all modules' backward_dw has been invoked. - if self.delay_wgrad_compute: - self.output_grads = grads - self.delay_grads_release = len(self.bwd_dw_callables) > 0 - - # return grads for record stream - return grads - - def backward_dw(self): - """Computes the weight gradients for the transformer layer node.""" - if not self.delay_wgrad_compute: - return - if isinstance(self.stream, Callable): - self.stream = self.stream() - with torch.cuda.stream(self.stream): - nvtx_msg = f"{self.name} wgrad" - nvtx_range_push(nvtx_msg) - for module in self.bwd_dw_callables: - module.backward_dw() - nvtx_range_pop(nvtx_msg) - - # the output grad memory is last used in wgrad compute, should be safe to release. - assert self.delay_grads_release, "output grad memory should be valid before wgrad." - if self.manual_release_grads: - for tensor in self.output_grads: - tensor.untyped_storage().resize_(0) - self.output_grads = None - - self.bwd_dw_callables = None - - def __del__(self): - # Release reference as early as possible, this helps avoid memory leak. - self.before_detached = None - self.detached = None - self.layer_state = None - self.chunk_state = None - self.submodule = None - - -class _BackwardDWWrapper: - """Wrapper for managing backward weight gradient computation of attn module. - - This class handles the execution of weight gradient computations for transformer layers, - coordinating between CUDA graphed and non-graphed components. It is used when - overlap_moe_expert_parallel_comm and delay_wgrad_compute are enabled to manage - the delayed weight gradient computation in MoE models. 
- - The wrapper stores references to the attention and shared expert backward weight gradient - callables, and determines which components should be executed based on whether CUDA graphs - are being replayed and which scopes are covered by the graphs. - """ - - def __init__(self, layer): - assert isinstance( - layer, GraphableMegatronModule - ), "cuda graphed ep overlap only supports GraphableMegatronModule." - assert isinstance( - layer, TransformerLayer - ), "cuda graphed ep overlap only supports TransformerLayer for now." - self.layer = layer - self.graphed_backward_dw_callable = None - self.attn_dw_callable = layer.self_attention.backward_dw - if layer.is_moe_layer: - self.shared_expert_dw_callable = partial( - layer.mlp.backward_dw, routed_experts=False, shared_experts=True - ) - else: - self.shared_expert_dw_callable = None - self.cuda_graph_scope = layer.config.cuda_graph_scope - - def backward_dw(self): - """Execute weight gradients, skipping CUDA graphed components during replay.""" - is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs - if self.shared_expert_dw_callable is not None and ( - not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope - ): - self.shared_expert_dw_callable() - if not is_replay or CudaGraphScope.attn not in self.cuda_graph_scope: - self.attn_dw_callable() - if is_replay and self.graphed_backward_dw_callable is not None: - self.graphed_backward_dw_callable() - self.layer = None - - def set_graphed_backward_dw_callable(self, graphed_backward_dw_callable): - """Store the CUDA graphed backward weight gradient callable.""" - self.graphed_backward_dw_callable = graphed_backward_dw_callable def build_transformer_layer_callables(layer: TransformerLayer): @@ -412,12 +23,15 @@ def build_transformer_layer_callables(layer: TransformerLayer): functions. This decomposition separates computation-heavy tasks (e.g., self-attention, MLP) from communication-heavy tasks (e.g., MoE's All-to-All). 
- The five callables are: - 1. Attention (computation) - 2. Post-Attention (computation) - 3. MoE Dispatch (communication) - 4. MLP / MoE Experts (computation) - 5. MoE Combine (communication) + The five callables align with the schedule plan's slot order: + 1. pre_dispatch_computation (computation): + attention -> pre-MLP layernorm -> router -> dispatch preprocess. + For dense layers this is just the attention pass. + 2. moe_dispatch (communication): MoE dispatch All-to-All. + 3. mlp / moe_experts (computation): dense MLP or routed-experts compute. + 4. moe_combine (communication): MoE combine All-to-All + post-MLP residual. + 5. mtp_post_process (computation): always ``None`` here; only the MTP + wrapper in ``common/fine_grained_callables.py`` fills this slot. By assigning these functions to different CUDA streams (e.g., a compute stream and a communication stream), the scheduler can overlap their execution, preventing @@ -429,8 +43,11 @@ def build_transformer_layer_callables(layer: TransformerLayer): Returns: A tuple containing: - - forward_funcs: List of callable functions for the layer - - backward_dw: Dict of weight gradient functions for the layer + - forward_funcs: List of 5 callables, one per slot in the schedule plan + (pre_dispatch_computation, moe_dispatch, mlp, moe_combine, + mtp_post_process=None). + - backward_dw: Dict mapping slot name to the delayed-wgrad callable + (keys: "pre_dispatch_computation", "mlp"). 
""" is_moe = isinstance(layer.mlp, MoELayer) @@ -443,9 +60,9 @@ def build_transformer_layer_callables(layer: TransformerLayer): and layer.config.moe_flex_dispatcher_backend == "hybridep" ) - def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): + def submodule_pre_dispatch_forward(node: ScheduleNode, hidden_states: torch.Tensor): """ - Performs same attnention forward logic as GPT Model and forward pass for + Performs the same attention forward logic as GPTModel and the forward pass for computations between attention and dispatch: pre mlp layernorm->router->dispatch preprocess """ @@ -543,7 +160,7 @@ def submodule_dispatch_forward( token_dispatcher = layer.mlp.token_dispatcher if enable_deepep or enable_hybridep: # update token_probs to be the detached version, prevents - # backward graph from connecting to attn submodule + # backward graph from connecting to pre_dispatch_computation submodule token_dispatcher._comm_manager.token_probs = probs dispatched_tokens, dispatched_probs = layer.mlp.dispatch(local_tokens, probs) @@ -638,116 +255,13 @@ def raise_not_implemented(*args): raise NotImplementedError("This callable is not implemented for Dense layer.") # Build forward and backward callable functions - attn_func = submodule_attn_forward + pre_dispatch_func = submodule_pre_dispatch_forward dispatch_func = submodule_dispatch_forward if is_moe else raise_not_implemented mlp_func = submodule_moe_forward if is_moe else mlp_wrapper combine_func = submodule_combine_forward if is_moe else raise_not_implemented layer.init_backward_dw_wrapper() - forward_funcs = [attn_func, dispatch_func, mlp_func, combine_func, None] - backward_dw = {"attn": layer.backward_dw_wrapper, "mlp": layer.mlp} + forward_funcs = [pre_dispatch_func, dispatch_func, mlp_func, combine_func, None] + backward_dw = {"pre_dispatch_computation": layer.backward_dw_wrapper, "mlp": layer.mlp} return forward_funcs, backward_dw - - -def build_mtp_layer_callables(layer): - """Callables for 
multi-token prediction layer nodes. - - This class contains the callable functions for different types of - multi-token prediction layer nodes (attention, MLP, etc.) - """ - - forward_funcs, backward_dw = build_transformer_layer_callables(layer.mtp_model_layer) - attn_forward, dispatch_forward, mlp_forward, combine_forward, _ = forward_funcs - is_moe = isinstance(layer.mtp_model_layer.mlp, MoELayer) - assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now." - - def submodule_mtp_attn_forward(node, hidden_states): - # MTP Block Preprocess - if node.is_first_layer: - offset = get_mtp_layer_offset(layer.config, node.chunk_state.model.vp_stage) - node.chunk_state.mtp_hidden_states = list(torch.chunk(hidden_states, 1 + offset, dim=0)) - hidden_states = node.chunk_state.mtp_hidden_states[offset] - - input_ids, position_ids, decoder_input, hidden_states = layer._get_embeddings( - input_ids=node.chunk_state.input_ids, - position_ids=node.chunk_state.position_ids, - embedding=node.chunk_state.model.embedding, - hidden_states=hidden_states, - ) - node.chunk_state.input_ids = input_ids - node.chunk_state.position_ids = position_ids - - # MTP Layer Preprocess - # norm, linear projection and transformer - assert ( - node.chunk_state.context is None - ), f"multi token prediction + cross attention is not yet supported." - assert ( - node.chunk_state.packed_seq_params is None - ), f"multi token prediction + sequence packing is not yet supported." 
- - if layer.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # fp8 context is added in 1f1b schedule, so we don't need to add it here - with rng_context: - hidden_states = layer._concat_embeddings(hidden_states, decoder_input) - return attn_forward(node, hidden_states) - - def submodule_mtp_postprocess_forward(node, hidden_states): - hidden_states = layer._postprocess(hidden_states) - node.chunk_state.mtp_hidden_states.append(hidden_states) - if node.is_last_layer: - hidden_states = torch.cat(node.chunk_state.mtp_hidden_states, dim=0) - node.chunk_state.mtp_hidden_states = None - return hidden_states - - def rng_context_wrapper(func, *args, **kwargs): - """ - Wrapper to add rng context to submodule callables - """ - if layer.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - with rng_context: - return func(*args, **kwargs) - - # Build forward and backward callable functions - # attn_forward already has rng context, no need to wrap - attn_func = submodule_mtp_attn_forward - dispatch_func = partial(rng_context_wrapper, dispatch_forward) - mlp_func = partial(rng_context_wrapper, mlp_forward) - combine_func = partial(rng_context_wrapper, combine_forward) - mtp_post_process_func = submodule_mtp_postprocess_forward - - forward_funcs = [attn_func, dispatch_func, mlp_func, combine_func, mtp_post_process_func] - if isinstance(backward_dw["attn"], list): - backward_dw["attn"].append(layer.eh_proj) - else: - backward_dw["attn"] = [backward_dw["attn"], layer.eh_proj] - - return forward_funcs, backward_dw - - -def build_layer_callables(layer): - """ - Builds the callable functions(forward and dw) for the given layer. - For now, 1f1b overlap only support TransformerLayer and MultiTokenPredictionLayer. - - Args: - layer: The layer to build callables for. 
- - Returns: - forward_funcs: list of callable functions for the layer. - backward_dw: dict of weight gradient functions for the layer. - """ - if isinstance(layer, TransformerLayer): - return build_transformer_layer_callables(layer) - elif isinstance(layer, MultiTokenPredictionLayer): - return build_mtp_layer_callables(layer) - - raise ValueError(f"Unsupported layer type: {type(layer)}") diff --git a/megatron/core/models/hybrid/fine_grained_callables.py b/megatron/core/models/hybrid/fine_grained_callables.py new file mode 100644 index 00000000000..0ea38ef6d34 --- /dev/null +++ b/megatron/core/models/hybrid/fine_grained_callables.py @@ -0,0 +1,349 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +from contextlib import nullcontext +from functools import partial +from typing import Optional + +import torch +from torch import Tensor + +from megatron.core.enums import Fp8Recipe +from megatron.core.fp4_utils import get_fp4_context +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.models.common.utils import TransformerLayerNode, should_free_input +from megatron.core.models.hybrid.hybrid_block import HybridStack +from megatron.core.models.hybrid.hybrid_layer_allocation import LayerPatternItem +from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.models.hybrid.hybrid_layer_allocation import is_layer_group +from megatron.core.pipeline_parallel.utils import ScheduleNode +from megatron.core.transformer.transformer_layer import make_viewless_tensor + + +class _SharedExpertBackwardDWWrapper: + """Run MoE shared-experts wgrad as part of the ``pre_dispatch_computation`` slot. + + Why: shared-experts forward is part of ``_run_moe_preprocess`` (which runs in the + pre_dispatch slot), and TE's delay-wgrad model only ``put``s to the wgrad queue + inside the autograd backward (dgrad). 
So shared-experts' ``backward_dw`` must + fire *after* the pre_dispatch slot's autograd backward — registering it in the + ``mlp`` slot would call it before that dgrad and trigger + ``RuntimeError: Pop empty queue`` from TE. + """ + + def __init__(self, layer): + self.layer = layer + self.shared_expert_dw_callable = None + if layer.mlp.use_shared_expert and not layer.mlp.shared_expert_overlap: + self.shared_expert_dw_callable = partial( + layer.mlp.backward_dw, routed_experts=False, shared_experts=True + ) + + def backward_dw(self): + """Run shared-expert backward wgrad after pre-dispatch autograd backward.""" + if self.shared_expert_dw_callable is not None: + self.shared_expert_dw_callable() + self.layer = None + self.shared_expert_dw_callable = None + + +class HybridStackNode(TransformerLayerNode): + """Schedule node for HybridStack-built fine-grained callables. + + Subclassed from ``TransformerLayerNode`` so the runtime backbone (forward / + backward / backward_dw plumbing, detach bookkeeping, output-grad release) + is shared. The hybrid path keeps a separate node class so its free-input + policy can diverge from the GPT defaults — for example, the + ``pre_dispatch_computation`` slot here covers the whole pre-dispatch loop + (mamba + attention + …) rather than a single attention block, and + group-level decisions about whether the input is needed in backward may + differ from ``should_free_input`` in ``gpt/fine_grained_callables.py``. + Keep this override thin until a hybrid counter-example forces it to + diverge; the explicit subclass exists so the divergence can be made + surgically without touching the GPT class. + """ + + @staticmethod + def _resolve_free_input(name, is_moe, config, num_local_experts): + """Hybrid free-input policy. 
+ + Currently mirrors the GPT default: dense layers always retain their + input for backward; MoE-only "moe_dispatch", "mlp", and "moe_combine" + slots can free, subject to the dispatcher / cuda-graph constraints + encoded in ``should_free_input``. Hybrid groups have a + "pre_dispatch_computation" slot whose semantics differ (it covers a + loop over Mamba/attention/GDN sub-layers, not a single attention + block), but its policy resolves to ``False`` in + ``should_free_input``, which is correct: pre-layer outputs are needed + for backward through the loop. Override here when a hybrid-specific + rule is needed. + """ + return should_free_input(name, is_moe, config, num_local_experts) + + +def _get_inner_quant_context(layer): + config = layer.config + if config.fp8 and config.fp8_recipe != Fp8Recipe.delayed: + return get_fp8_context(config, layer.layer_number - 1) + if config.fp4: + return get_fp4_context(config, layer.layer_number - 1) + return nullcontext() + + +def _as_hybrid_layers(layer, layer_type: Optional[LayerPatternItem]): + """Return ``(layer_type, layer)`` pairs for a hybrid logical layer.""" + if isinstance(layer, HybridStack): + return list(zip(layer.layer_type_list, layer.layers)) + assert layer_type is not None, "Hybrid layer scheduling requires the layer type symbol." 
+ return [(layer_type, layer)] + + +def _maybe_apply_final_norm(node: ScheduleNode, hidden_states: Tensor): + final_norm = getattr(node.chunk_state.model.decoder, "final_norm", None) + final_norm = final_norm or getattr(node.chunk_state.model.decoder, "final_layernorm", None) + if not node.is_mtp and final_norm is not None and node.is_last_layer: + hidden_states = final_norm(hidden_states) + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + return hidden_states + + +def _get_moe_padding_mask(node: ScheduleNode): + padding_mask = node.chunk_state.padding_mask + if padding_mask is not None: + # MoELayer.forward receives [batch, seq] and transposes before routing. + padding_mask = padding_mask.transpose(0, 1).bool() + return padding_mask + + +def _run_moe_preprocess(layer, node: ScheduleNode, hidden_states: Tensor): + pre_mlp_layernorm_output = layer._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " + f"2 elements (output, residual), but got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + residual = hidden_states + + if layer.config.fp32_residual_connection: + residual = residual.float() + + shared_expert_output = layer.mlp.shared_experts_compute(pre_mlp_layernorm_output) + probs, routing_map = layer.mlp.route(pre_mlp_layernorm_output, _get_moe_padding_mask(node)) + local_tokens, probs = layer.mlp.preprocess(pre_mlp_layernorm_output, probs, routing_map) + + node.layer_state.residual = node.detach(residual) + if layer.mlp.use_shared_expert and not layer.mlp.shared_expert_overlap: + node.layer_state.shared_expert_output = node.detach(shared_expert_output) + + return local_tokens, probs + + +def _run_moe_experts(layer, node: ScheduleNode, dispatched_tokens: Tensor): + dispatched_probs = 
def _run_moe_combine(layer, node: ScheduleNode, output: Tensor):
    """Run the MoE combine step and the post-MLP tail for one layer.

    Combines the expert output via ``layer.mlp.combine``, merges the optional
    shared-expert output via ``layer.mlp.postprocess``, applies
    ``layer._forward_post_mlp`` with the residual saved on ``node.layer_state``,
    and finally hands the result to ``_maybe_apply_final_norm``.

    Side effects: ``node.layer_state.residual`` and
    ``node.layer_state.shared_expert_output`` are recorded on the current CUDA
    stream and then reset to ``None``.
    """
    residual = node.layer_state.residual
    # shared_expert_output may never have been set on layer_state; treat as absent.
    shared_expert_output = getattr(node.layer_state, 'shared_expert_output', None)
    output = layer.mlp.combine(output)
    output = layer.mlp.postprocess(output, shared_expert_output)
    output = layer._forward_post_mlp((output, None), residual)

    # NOTE(review): presumably these tensors were produced on a different stream
    # than the one running combine; record_stream marks them as in use on the
    # current stream before the references are dropped below -- confirm.
    node.layer_state.residual.record_stream(torch.cuda.current_stream())
    if shared_expert_output is not None:
        shared_expert_output.record_stream(torch.cuda.current_stream())

    # Drop the references held by the layer state.
    node.layer_state.residual = None
    node.layer_state.shared_expert_output = None

    return _maybe_apply_final_norm(node, output)


def build_hybrid_stack_callables(layer, layer_type: Optional[LayerPatternItem] = None):
    """Create fine-grained callables for one logical HybridStack layer.

    A logical layer may be a bracketed nested ``HybridStack`` (for example ``[M*E]``)
    or a single legacy hybrid layer symbol. The split is:
    pre-dispatch compute -> dispatch -> MLP/experts -> combine.

    Args:
        layer: The logical layer; together with ``layer_type`` it is expanded by
            ``_as_hybrid_layers`` into ``(item_type, item_layer)`` pairs.
        layer_type: Optional parsed pattern item describing ``layer``.

    Returns:
        A 4-tuple ``(forward_funcs, backward_dw, is_moe, num_local_experts)``:

        * ``forward_funcs`` is the list ``[pre_dispatch_computation, dispatch, mlp,
          combine, None]``. When the terminal layer is not MoE, the ``dispatch`` and
          ``combine`` slots hold a callable that raises ``NotImplementedError``.
        * ``backward_dw`` is a dict with an optional ``"pre_dispatch_computation"``
          list and an optional ``"mlp"`` entry (``terminal_layer.mlp``).
        * ``is_moe`` is True when the terminal layer symbol is MoE.
        * ``num_local_experts`` is read from ``terminal_layer.mlp`` when ``is_moe``,
          otherwise ``None``.

    Raises:
        ValueError: If any expanded item is itself a layer group, or if an MLP/MoE
            item is present but is not the last item.
    """
    layer_items = _as_hybrid_layers(layer, layer_type)
    if any(is_layer_group(item_type) for item_type, _ in layer_items):
        raise ValueError("Nested HybridStack groups are not supported in overlap scheduling.")

    # Locate the first MLP/MoE item; it must terminate the group (checked below).
    terminal_idx = None
    for idx, (item_type, _) in enumerate(layer_items):
        if item_type in (LayerSymbols.MLP, LayerSymbols.MOE):
            terminal_idx = idx
            break

    if terminal_idx is not None and terminal_idx != len(layer_items) - 1:
        raise ValueError("HybridStack overlap requires MLP/MoE to be the last layer in a group.")

    # With no terminal MLP/MoE, every item is a "pre" layer and terminal_* stay None.
    terminal_type = layer_items[terminal_idx][0] if terminal_idx is not None else None
    terminal_layer = layer_items[terminal_idx][1] if terminal_idx is not None else None
    pre_layers = layer_items[:terminal_idx] if terminal_idx is not None else layer_items
    is_moe = terminal_type == LayerSymbols.MOE
    num_local_experts = terminal_layer.mlp.num_local_experts if is_moe else None

    def pre_dispatch_computation(node: ScheduleNode, hidden_states: Tensor):
        """Run all pre-layers, then the MoE preprocess when the terminal layer is MoE.

        Returns the result of ``_run_moe_preprocess`` for a MoE terminal layer, the
        (optionally final-normed) hidden states when there is no terminal layer, and
        the plain hidden states when the terminal layer is a dense MLP.
        """
        for item_type, item_layer in pre_layers:
            with _get_inner_quant_context(item_layer):
                if item_type == LayerSymbols.MAMBA:
                    hidden_states = item_layer(
                        hidden_states=hidden_states,
                        attention_mask=node.chunk_state.attention_mask,
                        inference_context=getattr(node.chunk_state, "inference_context", None),
                        packed_seq_params=node.chunk_state.packed_seq_params,
                    )
                elif item_type in (
                    LayerSymbols.ATTENTION,
                    LayerSymbols.DS_ATTENTION,
                    LayerSymbols.GDN,
                ):
                    # Use _forward_attention rather than __call__: an attention half-layer has
                    # mlp=IdentityOp / mlp_bda=IdentityFuncOp by default, and TransformerLayer's
                    # __call__ would route through _forward_mlp + mlp_bda, double-applying the
                    # post-attention residual.
                    hidden_states, _ = item_layer._forward_attention(
                        hidden_states=hidden_states,
                        attention_mask=node.chunk_state.attention_mask,
                        rotary_pos_emb=node.chunk_state.rotary_pos_emb,
                        rotary_pos_cos=node.chunk_state.rotary_pos_cos,
                        rotary_pos_sin=node.chunk_state.rotary_pos_sin,
                        packed_seq_params=node.chunk_state.packed_seq_params,
                        sequence_len_offset=node.chunk_state.sequence_len_offset,
                    )
                    # _forward_attention returns the bias_dropout_add output which can be a
                    # view tensor (the mlp_bda's add into the post-attention residual produces
                    # a view from a fused/JIT kernel). Downstream cuBLAS matmuls — including
                    # the terminal MLP/MoE's pre_mlp_layernorm and the next attention's QKV
                    # projection in a multi-pre-layer group — pick algorithms based on input
                    # strides; a view's non-canonical strides can lead to different algo
                    # selection across processes and produce ~1e-5 bit drift on the forward
                    # output. TransformerLayer's full forward() inserts this exact call at the
                    # MLP exit (transformer_layer.py:895) for the same reason; the
                    # _forward_attention shortcut here doesn't get that cleanup, so we add it
                    # explicitly. Same idea as the make_viewless_tensor in _maybe_apply_final_norm.
                    hidden_states = make_viewless_tensor(
                        inp=hidden_states,
                        requires_grad=hidden_states.requires_grad,
                        keep_graph=True,
                    )
                else:
                    raise ValueError(
                        f"HybridStack overlap does not support layer type '{item_type}' before "
                        "the terminal MLP/MoE layer."
                    )

            # A pre-layer may hand back a tuple; only its first element is carried forward.
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]

        if terminal_type == LayerSymbols.MOE:
            with _get_inner_quant_context(terminal_layer):
                return _run_moe_preprocess(terminal_layer, node, hidden_states)

        # No MLP/MoE follows, so this slot is the last compute for the layer.
        if terminal_type is None:
            return _maybe_apply_final_norm(node, hidden_states)

        # Dense MLP terminal: the MLP itself runs in the ``mlp`` slot.
        return hidden_states

    def dispatch(node: ScheduleNode, local_tokens: Tensor, probs: Tensor):
        """Dispatch tokens to experts and stash the detached dispatched probs on the node."""
        enable_hybridep = (
            terminal_layer.config.moe_token_dispatcher_type == "flex"
            and terminal_layer.config.moe_flex_dispatcher_backend == "hybridep"
        )
        enable_deepep = (
            terminal_layer.config.moe_token_dispatcher_type == "flex"
            and terminal_layer.config.moe_flex_dispatcher_backend == "deepep"
        )
        token_dispatcher = terminal_layer.mlp.token_dispatcher
        # For the flex dispatcher backends the probs are set on the comm manager
        # before dispatch is invoked.
        if enable_deepep or enable_hybridep:
            token_dispatcher._comm_manager.token_probs = probs
        with _get_inner_quant_context(terminal_layer):
            dispatched_tokens, dispatched_probs = terminal_layer.mlp.dispatch(local_tokens, probs)
        node.layer_state.dispatched_probs = node.detach(dispatched_probs)
        return dispatched_tokens

    def mlp(node: ScheduleNode, hidden_states: Tensor):
        """Run the dense MLP or the routed experts, depending on the terminal layer type.

        When there is no terminal layer the input is returned unchanged.
        """
        if terminal_type == LayerSymbols.MLP:
            with _get_inner_quant_context(terminal_layer):
                hidden_states = terminal_layer._forward_mlp(
                    hidden_states, padding_mask=node.chunk_state.padding_mask
                )
            return _maybe_apply_final_norm(node, hidden_states)
        if terminal_type == LayerSymbols.MOE:
            with _get_inner_quant_context(terminal_layer):
                return _run_moe_experts(terminal_layer, node, hidden_states)
        return hidden_states

    def combine(node: ScheduleNode, output: Tensor):
        """Run the MoE combine step for the terminal layer under its quant context."""
        with _get_inner_quant_context(terminal_layer):
            return _run_moe_combine(terminal_layer, node, output)

    def raise_not_implemented(*args):
        """Placeholder for the dispatch/combine slots of non-MoE layers; always raises."""
        raise NotImplementedError("This callable is not implemented for non-MoE hybrid layers.")

    backward_dw = {}
    pre_bwd_dw = []
    for item_type, item_layer in pre_layers:
        if item_type in (LayerSymbols.ATTENTION, LayerSymbols.DS_ATTENTION, LayerSymbols.GDN):
            # TransformerLayer-backed pre-layers go through the standard
            # _BackwardDWWrapper which coordinates attn / shared-expert wgrad
            # with cuda-graph replay scopes.
            item_layer.init_backward_dw_wrapper()
            pre_bwd_dw.append(item_layer.backward_dw_wrapper)
        elif item_type == LayerSymbols.MAMBA:
            # MambaLayer is not a TransformerLayer, so init_backward_dw_wrapper
            # would assert. MambaLayer.backward_dw delegates to its mixer, which
            # in turn calls backward_dw on the in_proj / out_proj linears. The
            # schedule node iterates this list and calls .backward_dw() on each;
            # registering the layer directly is sufficient.
            pre_bwd_dw.append(item_layer)
    if is_moe:
        # MoELayer.backward_dw default kwargs (routed_experts=True, shared_experts=False)
        # handle the routed-experts wgrad in the mlp slot. The shared-experts wgrad goes
        # into the pre_dispatch_computation slot so it runs after that slot's autograd
        # backward (where TE's wgrad_store.put fires); the wrapper is a no-op when
        # shared_expert_overlap is enabled.
        shared_expert_dw = _SharedExpertBackwardDWWrapper(terminal_layer)
        if shared_expert_dw.shared_expert_dw_callable is not None:
            pre_bwd_dw.append(shared_expert_dw)
        backward_dw["mlp"] = terminal_layer.mlp
    elif terminal_type == LayerSymbols.MLP:
        backward_dw["mlp"] = terminal_layer.mlp

    # Only register the pre-dispatch slot when there is something to run in it.
    if pre_bwd_dw:
        backward_dw["pre_dispatch_computation"] = pre_bwd_dw

    # Slot order: pre-dispatch, dispatch, mlp/experts, combine, and a trailing None.
    forward_funcs = [
        pre_dispatch_computation,
        dispatch if is_moe else raise_not_implemented,
        mlp,
        combine if is_moe else raise_not_implemented,
        None,
    ]
    return forward_funcs, backward_dw, is_moe, num_local_experts
post_process self.is_mtp_layer = is_mtp_layer + self.logical_layer_offset = logical_layer_offset + self.is_layer_group_stack = is_layer_group_stack assert pg_collection is not None, "pg_collection must be provided for HybridStack" @@ -109,16 +118,39 @@ def __init__( # Build layers from the pre-selected segment self.layers = nn.ModuleList() - for i, layer_type in enumerate(self.layer_type_list): - layer_number = i + 1 + pp_layer_offset - if self.config.fp8: - quant_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) + physical_layer_offset = pp_layer_offset + for layer_type in self.layer_type_list: + layer_number = physical_layer_offset + 1 + if is_layer_group(layer_type): + quant_init_context = nullcontext() + elif self.config.fp8: + quant_init_context = get_fp8_context( + self.config, physical_layer_offset, is_init=True + ) elif self.config.fp4: - quant_init_context = get_fp4_context(self.config, i + pp_layer_offset, is_init=True) + quant_init_context = get_fp4_context( + self.config, physical_layer_offset, is_init=True + ) else: quant_init_context = nullcontext() with quant_init_context: - if layer_type == LayerSymbols.MAMBA: + if is_layer_group(layer_type): + layer = HybridStack( + config=self.config, + submodules=submodules, + pre_process=True, + layer_type_list=list(layer_type), + pp_layer_offset=physical_layer_offset, + logical_layer_offset=logical_layer_offset + len(self.layers), + is_layer_group_stack=True, + post_layer_norm=False, + post_process=False, + device=device, + dtype=dtype, + pg_collection=pg_collection, + is_mtp_layer=is_mtp_layer, + ) + elif layer_type == LayerSymbols.MAMBA: layer = build_module( submodules.mamba_layer, config=self.config, @@ -174,6 +206,7 @@ def __init__( else: raise ValueError("unexpected layer_type") self.layers.append(layer) + physical_layer_offset += get_layer_type_physical_count(layer_type) # Required for activation recomputation self.num_layers_per_pipeline_rank = len(self.layers) @@ -186,6 
+219,17 @@ def __init__( eps=self.config.layernorm_epsilon, ) + @property + def final_layernorm(self): + """Alias for ``final_norm`` matching the attribute name on TransformerBlock. + + Lets generic decoder consumers (e.g. ``GPTModel.PostProcessNode``) discover the + final norm via the same attribute name they use for non-hybrid decoders, while + keeping ``final_norm`` as the registered submodule so existing hybrid checkpoint + keys are unchanged. + """ + return getattr(self, "final_norm", None) + def set_input_tensor(self, input_tensor: Tensor): """Set input tensor to be used instead of forward()'s input. @@ -202,7 +246,11 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int if this block contains Mamba layers (this may not be the case with PP > 1). """ for layer_type, layer in zip(self.layer_type_list, self.layers): - if layer_type == LayerSymbols.MAMBA: + if is_layer_group(layer_type): + shapes = layer.mamba_state_shapes_per_request() + if shapes is not None: + return shapes + elif layer_type == LayerSymbols.MAMBA: return layer.mamba_state_shapes_per_request() return None @@ -212,6 +260,10 @@ def forward( attention_mask: Tensor, inference_context: Optional[BaseInferenceContext] = None, rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + sequence_len_offset: Optional[Tensor] = None, *, inference_params: Optional[BaseInferenceContext] = None, packed_seq_params: Optional[PackedSeqParams] = None, @@ -252,7 +304,9 @@ def forward( inference_context.max_seqlen = inference_context.max_sequence_length inference_context.seqlen_offset = inference_context.sequence_len_offset - if ( + if sequence_len_offset is not None: + pass + elif ( ( ( self.config.cuda_graph_impl == "local" @@ -301,14 +355,35 @@ def get_inner_quant_context(config, layer_number): with outer_fp8_context: for layer in self.layers: # Layers have 1-indexed 
layer numbers attribute. - inner_quant_context = get_inner_quant_context(self.config, layer.layer_number - 1) + if isinstance(layer, HybridStack): + inner_quant_context = nullcontext() + else: + inner_quant_context = get_inner_quant_context( + self.config, layer.layer_number - 1 + ) with inner_quant_context: - if isinstance(layer, TransformerLayer): + if isinstance(layer, HybridStack): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + sequence_len_offset=sequence_len_offset, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + elif isinstance(layer, TransformerLayer): hidden_states, _ = layer( hidden_states=hidden_states, attention_mask=attention_mask, inference_context=inference_context, rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, sequence_len_offset=sequence_len_offset, packed_seq_params=packed_seq_params, padding_mask=padding_mask, @@ -361,17 +436,48 @@ def sharded_state_dict( dict: The sharded state dictionary for the current object. """ + return self._sharded_state_dict( + prefix=prefix, + sharded_offsets=sharded_offsets, + metadata=metadata, + sharded_layer_prefix=None, + ) + + def _sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Optional[tuple] = None, + metadata: Optional[dict] = None, + sharded_layer_prefix: Optional[str] = None, + ) -> ShardedStateDict: + sharded_offsets = sharded_offsets or () sharded_state_dict = {} layer_prefix = f'{prefix}layers.' 
+ if sharded_layer_prefix is None: + sharded_layer_prefix = layer_prefix - for local_layer_idx, layer in enumerate(self.layers): - - global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 - state_dict_prefix = ( - f'{layer_prefix}{local_layer_idx}.' # module list index in HybridStack + for local_layer_idx, (layer_type, layer) in enumerate( + zip(self.layer_type_list, self.layers) + ): + state_dict_prefix = f'{layer_prefix}{local_layer_idx}.' + logical_layer_idx = ( + self.logical_layer_offset + if self.is_layer_group_stack + else self.logical_layer_offset + local_layer_idx ) - sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + if is_layer_group(layer_type): + sharded_state_dict.update( + layer._sharded_state_dict( + state_dict_prefix, + sharded_offsets, + metadata, + sharded_layer_prefix=sharded_layer_prefix, + ) + ) + continue + + sharded_prefix = f'{sharded_layer_prefix}{logical_layer_idx}.' sharded_pp_offset = [] layer_sharded_state_dict = layer.sharded_state_dict( @@ -385,15 +491,10 @@ def sharded_state_dict( # Add modules other than self.layers for name, module in self.named_children(): if not module is self.layers: - sharded_state_dict.update( - sharded_state_dict_default( - module, - f'{prefix}{name}.', - sharded_offsets, - metadata, - tp_group=self.tp_group, - ) + module_sharded_state_dict = sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=self.tp_group ) + sharded_state_dict.update(module_sharded_state_dict) return sharded_state_dict diff --git a/megatron/core/models/hybrid/hybrid_layer_allocation.py b/megatron/core/models/hybrid/hybrid_layer_allocation.py index f1ba94ef7fa..fdd51a94459 100644 --- a/megatron/core/models/hybrid/hybrid_layer_allocation.py +++ b/megatron/core/models/hybrid/hybrid_layer_allocation.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union 
# A parsed pattern item: either a single layer symbol or a bracketed group of symbols.
LayerPatternItem = Union[str, Tuple[str, ...]]


def is_layer_group(layer_type: LayerPatternItem) -> bool:
    """Tell whether a parsed item is a bracketed group (stored as a tuple)."""
    return isinstance(layer_type, tuple)


def flatten_layer_type_list(layer_type_list: List[LayerPatternItem]) -> List[str]:
    """Expand every bracketed group so the result lists one symbol per physical layer."""
    symbols: List[str] = []
    for item in layer_type_list:
        # A group contributes each of its members; a bare symbol contributes itself.
        symbols += item if is_layer_group(item) else (item,)
    return symbols


def get_layer_type_physical_count(layer_type: LayerPatternItem) -> int:
    """Count the physical layers behind one parsed item (group size, or 1 for a symbol)."""
    if is_layer_group(layer_type):
        return len(layer_type)
    return 1


def get_layer_type_logical_count(layer_type: LayerPatternItem) -> int:
    """Count the logical layers behind one parsed item; every item is exactly one."""
    return 1


def get_layer_type_list_physical_count(layer_type_list: List[LayerPatternItem]) -> int:
    """Total the physical layers across a parsed layer list."""
    return sum(map(get_layer_type_physical_count, layer_type_list))


def get_layer_type_list_logical_count(layer_type_list: List[LayerPatternItem]) -> int:
    """Total the logical layers across a parsed layer list."""
    return sum(map(get_layer_type_logical_count, layer_type_list))


def layer_type_item_to_str(layer_type: LayerPatternItem) -> str:
    """Convert one parsed item back into its pattern-string form."""
    if not is_layer_group(layer_type):
        return layer_type
    members = ''.join(layer_type)
    return Symbols.GROUP_START + members + Symbols.GROUP_END


def layer_type_list_to_str(layer_type_list: List[LayerPatternItem]) -> str:
    """Convert a parsed layer list back into its pattern-string form."""
    return ''.join(map(layer_type_item_to_str, layer_type_list))
+ if Symbols.GROUP_START in mtp_pattern or Symbols.GROUP_END in mtp_pattern: + raise ValueError( + f"In MTP pattern, layer groups '{Symbols.GROUP_START}...{Symbols.GROUP_END}' " + f"are not supported because each MTP depth is already a fused unit. " + f"Got MTP pattern: '{mtp_pattern}'." + ) + return ParsedHybridPattern( main_pattern=main_pattern if main_pattern else None, mtp_pattern=mtp_pattern, @@ -284,20 +349,90 @@ def _validate_pattern(pattern: str, pattern_name: str, allow_pipe: bool = False) Raises: ValueError: If pattern contains invalid symbols """ - valid_chars = Symbols.VALID_LAYERS | {Symbols.PIPE} if allow_pipe else Symbols.VALID_LAYERS - for char in pattern: - if char not in valid_chars: + valid_chars = ( + Symbols.VALID_LAYERS + | {Symbols.GROUP_START, Symbols.GROUP_END} + | ({Symbols.PIPE} if allow_pipe else set()) + ) + if not allow_pipe and Symbols.PIPE in pattern: + raise ValueError( + f"In {pattern_name} pattern, '{Symbols.PIPE}' is not a valid layer symbol. " + f"Valid symbols are: {valid_chars}" + ) + flat_layers = [] + for segment in pattern.split(Symbols.PIPE): + flat_layers.extend( + flatten_layer_type_list( + _parse_segment_layers(segment, pattern_name, valid_chars=valid_chars) + ) + ) + + # Disallow Attention + MLA/DSA hybridity. 
def _parse_segment_layers(
    segment: str, pattern_name: str, valid_chars: Optional[set[str]] = None
) -> List[LayerPatternItem]:
    """Turn a pipe-free pattern segment into a list of symbols and group tuples.

    Args:
        segment: Pattern text with no pipe separators.
        pattern_name: Name used in error messages.
        valid_chars: Symbol set reported in "not a valid layer symbol" errors;
            defaults to the valid layer symbols plus the group brackets.

    Returns:
        One entry per logical layer: a single-character string for a plain layer,
        or a tuple of characters for a bracketed group.

    Raises:
        ValueError: On an unmatched, empty or nested group, an invalid symbol, a MoE
            symbol that is not last inside a group, or when the segment mixes
            Attention with MLA/DSA.
    """
    if valid_chars is None:
        valid_chars = Symbols.VALID_LAYERS | {Symbols.GROUP_START, Symbols.GROUP_END}

    parsed: List[LayerPatternItem] = []
    seen_symbols = set()
    pos = 0
    segment_len = len(segment)
    while pos < segment_len:
        layer_char = segment[pos]

        if layer_char == Symbols.GROUP_END:
            raise ValueError(
                f"In {pattern_name} pattern, ']' closes a layer group that was not opened."
            )

        if layer_char != Symbols.GROUP_START:
            # Plain single-symbol layer.
            if layer_char not in Symbols.VALID_LAYERS:
                raise ValueError(
                    f"In {pattern_name} pattern, '{layer_char}' is not a valid layer symbol. "
                    f"Valid symbols are: {valid_chars}"
                )
            parsed.append(layer_char)
            seen_symbols.add(layer_char)
            pos += 1
            continue

        # Bracketed group: everything up to the next closing bracket.
        close_pos = segment.find(Symbols.GROUP_END, pos + 1)
        if close_pos == -1:
            raise ValueError(
                f"In {pattern_name} pattern, '[' starts a layer group without a matching ']'."
            )
        group = segment[pos + 1 : close_pos]
        if group == "":
            raise ValueError(f"In {pattern_name} pattern, layer groups cannot be empty.")
        if Symbols.GROUP_START in group or Symbols.GROUP_END in group:
            raise ValueError(f"In {pattern_name} pattern, nested layer groups are not supported.")
        for group_char in group:
            if group_char not in Symbols.VALID_LAYERS:
                raise ValueError(
                    f"In {pattern_name} pattern, '{group_char}' is not a valid layer symbol. "
                    f"Valid symbols are: {valid_chars}"
                )
        if Symbols.MOE in group[:-1]:
            raise ValueError(
                f"In {pattern_name} pattern, MoE layer '{Symbols.MOE}' must be the last "
                f"symbol inside a layer group."
            )
        parsed.append(tuple(group))
        seen_symbols.update(group)
        pos = close_pos + 1

    # Attention and MLA/DSA may not coexist.
    if Symbols.ATTENTION in seen_symbols and Symbols.DS_ATTENTION in seen_symbols:
        raise ValueError("Not supported to have both Attention and MLA/DSA in one model")

    return parsed


def _slice_layer_type_list_by_physical_range(
    layer_type_list: List[LayerPatternItem], offset: int, count: int
) -> List[LayerPatternItem]:
    """Pick the items covering physical layers ``[offset, offset + count)``.

    Items are never split: a group that straddles either boundary raises
    ``ValueError`` instead of being cut.
    """
    stop = offset + count
    picked = []
    start = 0
    for item in layer_type_list:
        finish = start + get_layer_type_physical_count(item)
        if finish > offset:
            if start >= stop:
                # Everything from here on lies past the requested range.
                break
            if start < offset or finish > stop:
                raise ValueError(
                    "Pipeline splitting would split a bracketed hybrid layer group. "
                    "Add pipe ('|') separators around bracketed groups to define valid boundaries."
                )
            picked.append(item)
        start = finish
    return picked


def _get_logical_offset_from_physical_offset(
    layer_type_list: List[LayerPatternItem], offset: int
) -> int:
    """Count the logical items that lie entirely before physical layer ``offset``.

    Raises:
        ValueError: If ``offset`` falls inside an item, or lies beyond the end of
            the list.
    """
    logical = 0
    position = 0
    for item in layer_type_list:
        boundary = position + get_layer_type_physical_count(item)
        if boundary > offset:
            # First item reaching past the offset: it must start exactly there.
            if position != offset:
                raise ValueError(
                    "Pipeline splitting would split a bracketed hybrid layer group. "
                    "Add pipe ('|') separators around bracketed groups to define valid boundaries."
                )
            return logical
        logical += get_layer_type_logical_count(item)
        position = boundary
    if position != offset:
        raise ValueError(
            f"Physical layer offset {offset} is out of range for hybrid layer pattern."
        )
    return logical


def select_pipeline_segment_with_logical_offset(
    main_pattern: str,
    pp_group: Optional[torch.distributed.ProcessGroup],
    vp_stage: Optional[int],
    first_stage_layers: Optional[int] = None,
    last_stage_layers: Optional[int] = None,
) -> Tuple[List[LayerPatternItem], int, int]:
    """Select this rank's pipeline segment plus its physical and logical offsets.

    Delegates segment selection to ``select_pipeline_segment`` and then derives the
    logical offset: from the physical offset when the pattern has no pipe
    separators, otherwise by summing the logical counts of the preceding segments.

    Returns:
        ``(layer_type_list, layer_offset, logical_layer_offset)``.
    """
    layer_type_list, layer_offset = select_pipeline_segment(
        main_pattern,
        pp_group,
        vp_stage,
        first_stage_layers=first_stage_layers,
        last_stage_layers=last_stage_layers,
    )

    segments = main_pattern.split(Symbols.PIPE) if main_pattern else ['']

    if len(segments) == 1:
        # No explicit pipe boundaries: translate the physical offset directly.
        whole_pattern_items = validate_segment_layers(segments[0])
        logical_layer_offset = _get_logical_offset_from_physical_offset(
            whole_pattern_items, layer_offset
        )
        return layer_type_list, layer_offset, logical_layer_offset

    # Explicit segments: locate this rank's segment and count what precedes it.
    if pp_group is not None:
        pp_rank = torch.distributed.get_rank(pp_group)
        pp_size = torch.distributed.get_world_size(pp_group)
    else:
        pp_rank = 0
        pp_size = 1
    vp_rel = 0 if vp_stage is None else vp_stage
    segment_index = vp_rel * pp_size + pp_rank

    logical_layer_offset = 0
    for i in range(segment_index):
        logical_layer_offset += get_layer_type_list_logical_count(
            validate_segment_layers(segments[i])
        )

    return layer_type_list, layer_offset, logical_layer_offset
) - layer_offset = sum(len(segments[i]) for i in range(segment_index)) + layer_offset = sum( + get_layer_type_list_physical_count(validate_segment_layers(segments[i])) + for i in range(segment_index) + ) my_segment = segments[segment_index] layer_type_list = validate_segment_layers(my_segment) @@ -477,21 +688,23 @@ def select_pipeline_segment( logging.INFO, f"HybridModel: pp_rank={pp_rank}/{pp_size}, vp_stage={vp_rel}, " f"segment_index={segment_index}/{len(segments)}, " - f"layers='{my_segment}' ({len(layer_type_list)} layers), " + f"layers='{my_segment}' ({get_layer_type_list_physical_count(layer_type_list)} layers), " f"layer_offset={layer_offset}", ) return layer_type_list, layer_offset -def get_layer_maps_from_layer_type_list(layer_type_list: list[str]) -> dict[str, dict[int, int]]: +def get_layer_maps_from_layer_type_list( + layer_type_list: list[LayerPatternItem], +) -> dict[str, dict[int, int]]: """ Returns maps from global layer index to the corresponding layer index for each valid layer type (those in Symbols.VALID_LAYERS) given a layer type list. """ layer_types = [symbol for symbol in Symbols.name_sorted_valid_layer_symbols()] layer_maps = {layer_type: {} for layer_type in layer_types} - for global_layer_idx, layer_type in enumerate(layer_type_list): + for global_layer_idx, layer_type in enumerate(flatten_layer_type_list(layer_type_list)): layer_map = layer_maps[layer_type] local_layer_idx = len(layer_map) layer_map[global_layer_idx] = local_layer_idx diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 4399c6984a7..fe2ba14cf72 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -1,12 +1,13 @@ # Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved. 
import logging -from typing import Literal, Optional +from typing import Dict, Literal, Optional from torch import Tensor from megatron.core import tensor_parallel from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding @@ -179,19 +180,21 @@ def __init__( # determine the pipeline segment for this model instance. from megatron.core.models.hybrid.hybrid_layer_allocation import ( parse_hybrid_pattern, - select_pipeline_segment, + select_pipeline_segment_with_logical_offset, ) parsed = parse_hybrid_pattern(self.hybrid_layer_pattern) self.mtp_pattern = parsed.mtp_pattern self.mtp_num_depths = parsed.mtp_num_depths - layer_type_list, layer_offset = select_pipeline_segment( - parsed.main_pattern or '', - self.pg_collection.pp, - vp_stage, - first_stage_layers=self.config.num_layers_in_first_pipeline_stage, - last_stage_layers=self.config.num_layers_in_last_pipeline_stage, + layer_type_list, layer_offset, logical_layer_offset = ( + select_pipeline_segment_with_logical_offset( + parsed.main_pattern or '', + self.pg_collection.pp, + vp_stage, + first_stage_layers=self.config.num_layers_in_first_pipeline_stage, + last_stage_layers=self.config.num_layers_in_last_pipeline_stage, + ) ) # Determine if MTP is needed (based on pattern parsing) @@ -256,6 +259,7 @@ def __init__( pre_process=self.pre_process, layer_type_list=layer_type_list, pp_layer_offset=layer_offset, + logical_layer_offset=logical_layer_offset, post_process=self.post_process, dtype=config.params_dtype, pg_collection=self.pg_collection, @@ -324,103 +328,36 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: assert len(input_tensor) == 1, 'input_tensor should 
only be length 1 for gpt/bert' self.decoder.set_input_tensor(input_tensor[0]) - def preprocess_for_fine_grained_offloading(self): - """Preprocess for fine-grained activation offloading.""" - off_interface.init_chunk_handler( - vp_size=self.config.virtual_pipeline_model_parallel_size, - vp_stage=self.vp_stage, - min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, - ) - if self.disable_param_offloading: - for param in self.decoder.parameters(): - off_interface.mark_not_offloadable(param) - if self.mtp_process: - for param in self.mtp.parameters(): - off_interface.mark_not_offloadable(param) - if self.post_process: - for param in self.output_layer.parameters(): - off_interface.mark_not_offloadable(param) - self.disable_param_offloading = False - - def _should_call_local_cudagraph(self, *args, **kwargs): - """ - Check if we should call the local cudagraph path. - """ - if ( - not self.training - and hasattr(self, 'cudagraph_manager') - and ( - kwargs.get('inference_context') is not None - or kwargs.get('inference_params') is not None - ) - and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope - ): - if kwargs['inference_context'].is_static_batching(): - using_cuda_graph = kwargs['inference_context'].is_decode_only() - else: - using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step() - - if using_cuda_graph: - return True - return False - - def __call__(self, *args, **kwargs): - if self._should_call_local_cudagraph(*args, **kwargs): - return super().__call__(*args, **kwargs)[0] - return super().__call__(*args, **kwargs) - - def create_mcore_cudagraph_manager(self, config): - """ - Create the cudagraph manager for the full iteration inference scope - """ - if CudaGraphScope.full_iteration_inference in config.cuda_graph_scope: - from megatron.core.transformer.cuda_graphs import CudaGraphManager - - self.cudagraph_manager = CudaGraphManager(config) - - def forward( + def _preprocess( self, input_ids: Tensor, 
position_ids: Tensor, - attention_mask: Tensor, decoder_input: Tensor = None, - labels: Tensor = None, inference_context: BaseInferenceContext = None, - runtime_gather_output: Optional[bool] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None, + packed_seq_params: PackedSeqParams = None, padding_mask: Optional[Tensor] = None, - is_spec_decode: Optional[bool] = None, - ) -> Tensor: - """Forward function of the Hybrid model. This function passes the input tensors - through the embedding layer, and then the decoder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units + ): + """Preprocess inputs for HybridStack or combined-1F1B scheduling. + + Mirrors ``GPTModel._preprocess`` so the eager forward and the + EP-overlap ``PreProcessNode`` see the same embedding / rotary / + padding-mask code. Returns the canonical 6-tuple ``(decoder_input, + rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset, + padding_mask)`` — slots HybridModel does not compute (rotary cos/sin) + come back as ``None``. """ - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - - if self.config.fine_grained_activation_offloading: - self.preprocess_for_fine_grained_offloading() - - inference_context = deprecate_inference_params(inference_context, inference_params) - in_inference_mode = inference_context is not None and not self.training - if in_inference_mode: - assert runtime_gather_output, "Inference must always gather TP logits" - - # Decoder embedding. + # If decoder_input is provided, input_ids and position_ids are ignored; + # otherwise apply the embedding layer to get decoder_input. if decoder_input is not None: pass elif self.pre_process: + # Decoder embedding. 
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - # Clear the outputs for padding tokens when using dynamic batching with - # quantization scales to avoid corrupting amax calculations + # Clear the outputs for padding tokens when using dynamic batching + # with quantization scales to avoid corrupting amax calculations. if ( in_inference_mode and inference_context.is_dynamic_batching() @@ -428,8 +365,8 @@ def forward( ): decoder_input[inference_context.padding_slice] = 0.0 else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor + # Intermediate stage of pipeline parallelism — the decoder will get + # hidden_states from encoder.input_tensor. decoder_input = None rotary_pos_emb = None @@ -445,45 +382,100 @@ def forward( rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) - # YarnRotaryEmbedding.forward returns (emb, mscale); discard mscale here + # YarnRotaryEmbedding.forward returns (emb, mscale); discard mscale here. rotary_pos_emb, _ = self.rotary_pos_emb( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) - # Wrap decoder_input to allow the decoder (HybridStack) to delete the - # reference held by this caller function, enabling early garbage collection - # for inference. + # ``sequence_len_offset`` is only needed for flash-decode / local-cudagraph + # static-batching inference; otherwise leave it as ``None``. 
+ if ( + in_inference_mode + and ( + ( + self.config.cuda_graph_impl == "local" + and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope + ) + or self.config.flash_decode + ) + and inference_context.is_static_batching() + ): + current_batch_size = input_ids.shape[0] + import torch + + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device='cuda', + ) + else: + sequence_len_offset = None + + # Wrap decoder_input so the decoder (HybridStack) can drop its caller's + # reference for early garbage collection during inference. if in_inference_mode: decoder_input = WrappedTensor(decoder_input) - # The following assert will currently fail when running inference. - # Commented out for now. - # TODO (duncan/rwaleffe): (1) confirm that the externally-generated - # attention mask is not needed and is ignored by the model in - # inference mode, (2) reduce the size of the externally-generated - # attention mask to prevent CPU OOM (as we did for training), (3) - # force the attention mask passed to the model in inference mode to - # be None, so this assert will succeed. - # assert attention_mask is None, "The attention mask is ignored and should be set to None" + return decoder_input, rotary_pos_emb, None, None, sequence_len_offset, padding_mask - # Run decoder. 
- hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=packed_seq_params, - padding_mask=padding_mask, + def preprocess_for_fine_grained_offloading(self): + """Preprocess for fine-grained activation offloading.""" + off_interface.init_chunk_handler( + vp_size=self.config.virtual_pipeline_model_parallel_size, + vp_stage=self.vp_stage, + min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, ) + if self.disable_param_offloading: + for param in self.decoder.parameters(): + off_interface.mark_not_offloadable(param) + if self.mtp_process: + for param in self.mtp.parameters(): + off_interface.mark_not_offloadable(param) + if self.post_process: + for param in self.output_layer.parameters(): + off_interface.mark_not_offloadable(param) + self.disable_param_offloading = False + + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + is_spec_decode=None, + ): + """Postprocess HybridStack hidden states into logits or language-model loss. + + Mirrors ``GPTModel._postprocess`` so the eager forward and the EP-overlap + ``PostProcessNode`` produce the same logits / loss / MTP outputs. + ``mtp_in_postprocess`` lets the EP-overlap path skip the inline MTP block + (it schedules MTP as separate layer nodes); the eager forward leaves it + ``True`` so the regular MTP forward runs here. 
+ """ + in_inference_mode = inference_context is not None and not self.training + if in_inference_mode: + assert runtime_gather_output, "Inference must always gather TP logits" output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - # Check if speculative decoding is active. When it is, MTP must be - # computed *after* verification so that it is conditioned on verified - # tokens rather than stale speculative tokens from the previous step. + # Speculative decoding: when active, MTP must run *after* verification so + # it conditions on verified tokens rather than stale speculative ones. if is_spec_decode is None: is_spec_decode = ( in_inference_mode @@ -491,8 +483,11 @@ def forward( and inference_context.num_speculative_tokens > 0 ) - mtp_forward_ran = self.mtp_process and not (in_inference_mode or is_spec_decode) - if mtp_forward_ran: + # MTP forward inline (skipped when the EP-overlap plan schedules MTP + # separately, when running inference, or when speculative decoding is + # active). ``self.mtp_process`` guards against models built without an + # MTP block. + if mtp_in_postprocess and self.mtp_process and not (in_inference_mode or is_spec_decode): hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -510,8 +505,11 @@ def forward( if self.config.mtp_num_layers is not None and self.mtp_process: assert self.config.mtp_num_layers > 0 if in_inference_mode or is_spec_decode: + # Cache decoder hidden states for serial MTP computation after + # speculative token verification. self._decoder_hidden_states_cache = hidden_states else: + # In training/eval, fold MTP loss into hidden_states. 
hidden_states = process_mtp_loss( hidden_states=hidden_states, labels=labels, @@ -526,15 +524,16 @@ def forward( packed_seq_params=packed_seq_params, scale_logits_fn=self._scale_logits if self.config.use_mup else None, ) + sequence_parallel_override = False if in_inference_mode and inference_context.config.materialize_only_last_token_logits: if inference_context.is_static_batching(): hidden_states = hidden_states[-1:, :, :] else: if self.output_layer.sequence_parallel: - # Perform the sequence parallel gather here instead of after the output layer - # because we need to slice the last token logits from the full view of the - # packed logits across all requests. + # Perform the sequence-parallel gather here instead of after + # the output layer so we can slice the last-token logits from + # the full view of the packed logits across all requests. hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.pg_collection.tp ) @@ -551,7 +550,6 @@ def forward( ) logits = self._scale_logits(logits) - # Restore sequence parallel execution to the output layer if necessary. 
if sequence_parallel_override: assert ( in_inference_mode @@ -561,9 +559,184 @@ def forward( self.output_layer.sequence_parallel = True if labels is None: - # [s b h] => [b s h] return logits.transpose(0, 1).contiguous() loss = self.compute_language_model_loss(labels, logits) - return loss + + def build_schedule_plan( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + ): + """Build the HybridModel combined-1F1B schedule plan.""" + if self.config.fine_grained_activation_offloading: + self.preprocess_for_fine_grained_offloading() + + from .model_chunk_schedule_plan import HybridStackModelChunkSchedulePlan + + return HybridStackModelChunkSchedulePlan( + self, + input_ids, + position_ids, + attention_mask, + decoder_input, + labels, + packed_seq_params, + extra_block_kwargs, + runtime_gather_output, + loss_mask, + padding_mask, + ) + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Return a Transformer-compatible sharded state dict for HybridModel.""" + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + output_layer_extra_state_key = f'{prefix}output_layer._extra_state' + + # Match GPTModel checkpoint compatibility: old GPT checkpoints do not include + # output layer extra state, and the TE extra state should be empty. 
+ output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None) + assert not ( + output_extra_state and output_extra_state.data + ), f'Expected output layer extra state to be empty, got: {output_extra_state}' + + return sharded_state_dict + + def _should_call_local_cudagraph(self, *args, **kwargs): + """ + Check if we should call the local cudagraph path. + """ + if ( + not self.training + and hasattr(self, 'cudagraph_manager') + and ( + kwargs.get('inference_context') is not None + or kwargs.get('inference_params') is not None + ) + and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope + ): + if kwargs['inference_context'].is_static_batching(): + using_cuda_graph = kwargs['inference_context'].is_decode_only() + else: + using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step() + + if using_cuda_graph: + return True + return False + + def __call__(self, *args, **kwargs): + if self._should_call_local_cudagraph(*args, **kwargs): + return super().__call__(*args, **kwargs)[0] + return super().__call__(*args, **kwargs) + + def create_mcore_cudagraph_manager(self, config): + """ + Create the cudagraph manager for the full iteration inference scope + """ + if CudaGraphScope.full_iteration_inference in config.cuda_graph_scope: + from megatron.core.transformer.cuda_graphs import CudaGraphManager + + self.cudagraph_manager = CudaGraphManager(config) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + padding_mask: Optional[Tensor] = None, + is_spec_decode: Optional[bool] = None, + ) -> Tensor: + """Forward function of the Hybrid model. 
This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + """ + if self.config.fine_grained_activation_offloading: + self.preprocess_for_fine_grained_offloading() + + inference_context = deprecate_inference_params(inference_context, inference_params) + + in_inference_mode = inference_context is not None and not self.training + if in_inference_mode: + assert runtime_gather_output, "Inference must always gather TP logits" + + # Mirror GPTModel.forward: delegate the embedding / rotary computation and + # the output-layer / MTP / loss computation to the same hooks the + # EP-overlap PreProcessNode / PostProcessNode call. Keeps the eager and + # combined-1F1B paths on the same code. + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = self._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + + # The following assert will currently fail when running inference. + # Commented out for now. + # TODO (duncan/rwaleffe): (1) confirm that the externally-generated + # attention mask is not needed and is ignored by the model in + # inference mode, (2) reduce the size of the externally-generated + # attention mask to prevent CPU OOM (as we did for training), (3) + # force the attention mask passed to the model in inference mode to + # be None, so this assert will succeed. + # assert attention_mask is None, "The attention mask is ignored and should be set to None" + + # Run decoder. 
+ hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=True, + loss_mask=loss_mask, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + inference_context=inference_context, + is_spec_decode=is_spec_decode, + ) diff --git a/megatron/core/models/hybrid/model_chunk_schedule_plan.py b/megatron/core/models/hybrid/model_chunk_schedule_plan.py new file mode 100644 index 00000000000..195895eee02 --- /dev/null +++ b/megatron/core/models/hybrid/model_chunk_schedule_plan.py @@ -0,0 +1,138 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Schedule-plan classes for HybridStack-based decoders. + +These extend the GPT-side ``TransformerLayerSchedulePlan`` / +``TransformerModelChunkSchedulePlan`` with the per-layer ``layer_type`` symbol +that HybridStack assigns to each entry of its ``layer_type_list`` (including +bracketed groups like ``[*-]``). The base classes remain GPT-only; this module +adds the hybrid-specific dispatch into ``build_hybrid_stack_callables`` and +uses ``HybridStackNode`` so the schedule node's free-input policy can diverge +from the GPT default. The pre/post-process nodes from +``core.models.common.utils`` are reused as-is — they already call +``model._preprocess`` / ``model._postprocess`` which work on a HybridModel. 
+""" + +from contextlib import nullcontext + +from megatron.core.models.common.model_chunk_schedule_plan import ( + TransformerLayerSchedulePlan, + TransformerModelChunkSchedulePlan, +) + + +class HybridStackSchedulePlan(TransformerLayerSchedulePlan): + """Per-layer schedule plan for HybridStack decoders. + + Adds the ``layer_type`` extra-arg propagation; routes through + ``build_hybrid_stack_callables`` when ``layer_type`` is set (i.e. the layer + is a HybridStack entry, possibly a bracketed group); falls back to the GPT + path for plain TransformerLayer / MTP layers when ``layer_type`` is None. + """ + + def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_args=None): + if extra_args is None: + extra_args = {} + self.layer_type = extra_args.get("layer_type", None) + super().__init__(layer, event, chunk_state, comp_stream, comm_stream, extra_args) + + def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args): + if self.layer_type is None: + return super()._build_callable_nodes(event, comp_stream, comm_stream, extra_args) + + # Hybrid grouped path. Imports are local because hybrid pulls in TE / SSM + # extensions that we don't want to load when only the GPT path is used. 
+ from megatron.core.models.hybrid.fine_grained_callables import ( + HybridStackNode, + build_hybrid_stack_callables, + ) + from megatron.core.pipeline_parallel.utils import NoopScheduleNode + + fwd_callables, bwd_dw_callable_map, is_moe, num_local_experts = ( + build_hybrid_stack_callables(self.layer, layer_type=self.layer_type) + ) + + extra_args["config"] = self.layer.config + extra_args["is_moe"] = is_moe + extra_args["num_local_experts"] = num_local_experts + extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute + extra_args["is_mtp"] = False + + def create_node(stream, module, name): + bwd_dw_callables = bwd_dw_callable_map.get(name, None) + node_extra_args = dict(extra_args) + if bwd_dw_callables is None: + node_extra_args["delay_wgrad_compute"] = False + return HybridStackNode( + stream, + event, + self.layer_state, + self.chunk_state, + module, + name=name, + bwd_dw_callables=bwd_dw_callables, + extra_args=node_extra_args, + ) + + ( + pre_dispatch_module, + moe_dispatch_module, + mlp_module, + moe_combine_module, + mtp_post_process_module, + ) = fwd_callables + + self.pre_dispatch_computation = create_node( + comp_stream, pre_dispatch_module, "pre_dispatch_computation" + ) + self.mlp = create_node(comp_stream, mlp_module, "mlp") + if is_moe: + self.moe_dispatch = create_node(comm_stream, moe_dispatch_module, "moe_dispatch") + self.moe_combine = create_node(comm_stream, moe_combine_module, "moe_combine") + else: + self.moe_dispatch = NoopScheduleNode() + self.moe_combine = NoopScheduleNode() + + # HybridStack groups never carry an MTP terminal, so mtp_post_process is + # always a no-op here. 
+ self.mtp_post_process = NoopScheduleNode() + + def get_fp8_context(self): + """Return an FP8 context only for plain transformer layers.""" + # Grouped hybrid layers (and inferred-layer-type entries that point at + # a HybridStack rather than a plain TransformerLayer) don't have a + # ``layer_number`` we can hand to ``get_fp8_context``; the inner layers + # manage their own per-layer fp8 context inside the hybrid callables. + if self.layer_type is not None or not hasattr(self.layer, "layer_number"): + return nullcontext() + return super().get_fp8_context() + + +class HybridStackModelChunkSchedulePlan(TransformerModelChunkSchedulePlan): + """Model-chunk schedule plan that builds ``HybridStackSchedulePlan`` layer plans. + + Threads HybridStack's ``layer_type_list[layer_idx]`` symbol into each + layer plan's ``extra_args`` so the per-layer plan can dispatch grouped + layers correctly. Ordinary GPT/MTP layers (no ``layer_type_list``) + default to ``layer_type=None`` and follow the GPT path. The pre/post + process nodes inherit from the GPT base class — they already dispatch + on ``model._preprocess`` / ``model._postprocess`` which a HybridModel + implements. + """ + + LAYER_SCHEDULE_PLAN_CLASS = HybridStackSchedulePlan + + def __init__(self, model, *args, **kwargs): + """Initialize the hybrid chunk plan after validating cuda graph support.""" + assert model.config.cuda_graph_impl == "none", ( + "EP A2A overlap with grouped HybridStack patterns (e.g. '[*E]') does not " + "support cuda graphs yet. Set cuda_graph_impl='none' or use an ungrouped pattern." 
+ ) + super().__init__(model, *args, **kwargs) + + def _extra_args_for_layer(self, module, layer_idx, num_layers): + extra_args = super()._extra_args_for_layer(module, layer_idx, num_layers) + extra_args["layer_type"] = ( + module.layer_type_list[layer_idx] if hasattr(module, "layer_type_list") else None + ) + return extra_args diff --git a/megatron/core/pipeline_parallel/combined_1f1b.py b/megatron/core/pipeline_parallel/combined_1f1b.py index f4f222ad2a1..dea54dd8749 100644 --- a/megatron/core/pipeline_parallel/combined_1f1b.py +++ b/megatron/core/pipeline_parallel/combined_1f1b.py @@ -343,11 +343,9 @@ def forward_backward_step(): unwrapped_model = get_attr_wrapped_model( f_model, "build_schedule_plan", return_model_obj=True ) - from megatron.core.models.gpt.gpt_model import GPTModel - - assert isinstance(unwrapped_model, GPTModel), ( - "The final unwrapped model must be a GPTModel instance " - "since only GPTModel is supported for EP A2A overlapping." + assert hasattr(unwrapped_model, "build_schedule_plan"), ( + "The final unwrapped model must implement build_schedule_plan " + "to support EP A2A overlapping." ) f_schedule_plan, loss_func = forward_step_func( data_iterator, unwrapped_model, return_schedule_plan=True diff --git a/megatron/core/ssm/mamba_layer.py b/megatron/core/ssm/mamba_layer.py index 17903cebf3b..2a49fcbbade 100644 --- a/megatron/core/ssm/mamba_layer.py +++ b/megatron/core/ssm/mamba_layer.py @@ -151,6 +151,17 @@ def forward( return hidden_states + def backward_dw(self): + """Compute weight gradients for the layer's linear projections. + + Delegates to the mixer; lets the hybrid EP-overlap schedule plan + register a Mamba pre-layer's wgrad alongside attention/GDN pre-layers + so the schedule node iterates a uniform set of callables. No-op when + the linears in the spec do not support delayed wgrad. 
+ """ + if hasattr(self.mixer, "backward_dw"): + self.mixer.backward_dw() + def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None ) -> ShardedStateDict: diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 727c6ef5fd6..a051164522c 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -1227,6 +1227,21 @@ def mamba_state_shapes_per_request(self) -> Tuple[Tuple[int], Tuple[int]]: ssm_states_shape = (self.nheads_local_tp, self.headdim, self.d_state) return (conv_states_shape, ssm_states_shape) + def backward_dw(self): + """Compute weight gradients for the linear layers wrapped by this mixer. + + Mirrors ``GatedDeltaNet.backward_dw``. The selective-scan kernel is a + single autograd function whose wgrad runs in the regular backward pass, + so only the input/output projections need delayed wgrad here. Each + ``backward_dw`` call is a no-op unless the underlying linear is built + from a TE primitive that supports delayed wgrad; if the spec uses + non-TE linears, ``backward_dw`` simply does nothing. + """ + if hasattr(self.in_proj, "backward_dw"): + self.in_proj.backward_dw() + if hasattr(self.out_proj, "backward_dw"): + self.out_proj.backward_dw() + def _get_states_from_cache(self, inference_context, batch_size, *, inference_params=None): """Initializes or retrieves the SSM state tensors from the cache. 
diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 6539ee36105..b9535474161 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -199,12 +199,22 @@ def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None): self.cuda_graph_backward_dw_wrapper = None def init_backward_dw_wrapper(self): - """Initialize the backward_dw_wrapper.""" - from megatron.core.models.gpt.fine_grained_callables import _BackwardDWWrapper + """Initialize ``self.backward_dw_wrapper`` for delayed-wgrad scheduling. + + The wrapper coordinates the per-layer wgrad callables (attention + wgrad, optional shared-expert wgrad) with cuda-graph replay scope so + captured components are not re-run eagerly. The method is defined on + ``GraphableMegatronModule`` so any graphable subclass can opt in; + ``_BackwardDWWrapper`` itself currently asserts the underlying layer + is a ``TransformerLayer``, so MambaLayer-derived modules implement + ``backward_dw`` directly and skip this helper. + """ + from megatron.core.models.common.utils import _BackwardDWWrapper config = getattr(self, 'config', None) assert config is not None, ( - "TransformerLayer must be initialized before calling " "`init_backward_dw_wrapper`." + "Module must be fully constructed (config set) before calling " + "`init_backward_dw_wrapper`." ) self.backward_dw_wrapper = _BackwardDWWrapper(self) diff --git a/pretrain_hybrid.py b/pretrain_hybrid.py index 9ab52ed11ab..4f3c16d6ed3 100644 --- a/pretrain_hybrid.py +++ b/pretrain_hybrid.py @@ -209,12 +209,13 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optio return loss, num_tokens, report -def forward_step(data_iterator, model: HybridModel): +def forward_step(data_iterator, model: HybridModel, return_schedule_plan: bool = False): """Forward training step. 
Args: data_iterator : Input data iterator model (HybridModel): The Model + return_schedule_plan (bool): Whether to return the schedule plan instead of output tensor. """ timers = get_timers() @@ -253,14 +254,28 @@ def forward_step(data_iterator, model: HybridModel): timers('batch-generator').stop() with stimer: - output_tensor = model( - tokens, - position_ids, - attention_mask, - labels=labels, - packed_seq_params=packed_seq_params, - loss_mask=loss_mask - ) + if return_schedule_plan: + args = get_args() + assert args.overlap_moe_expert_parallel_comm, ( + "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + ) + output_tensor = model.build_schedule_plan( + tokens, + position_ids, + attention_mask, + labels=labels, + packed_seq_params=packed_seq_params, + loss_mask=loss_mask, + ) + else: + output_tensor = model( + tokens, + position_ids, + attention_mask, + labels=labels, + packed_seq_params=packed_seq_params, + loss_mask=loss_mask + ) # [ModelOpt]: model is needed to access ModelOpt distillation losses return output_tensor, partial(loss_func, loss_mask, model=model) diff --git a/tests/unit_tests/models/test_hybrid_model.py b/tests/unit_tests/models/test_hybrid_model.py index 98a53da0314..43c18e3adf9 100644 --- a/tests/unit_tests/models/test_hybrid_model.py +++ b/tests/unit_tests/models/test_hybrid_model.py @@ -198,6 +198,29 @@ def test_save_load(self, tmp_path): self.model.load_state_dict(torch.load(path)) + def test_grouped_sharded_state_dict_uses_transformer_checkpoint_keys(self): + """Grouped HybridModel checkpoints should be load-compatible with GPTModel keys.""" + model_config = TransformerConfig( + num_layers=2, hidden_size=256, num_attention_heads=4, use_cpu_initialization=True + ) + model = HybridModel( + config=model_config, + hybrid_stack_spec=hybrid_stack_spec, + vocab_size=100, + max_sequence_length=4, + hybrid_layer_pattern="[*-]", + ) + + sharded_state_dict = model.sharded_state_dict() + sharded_keys = {value.key for value 
in sharded_state_dict.values() if hasattr(value, "key")} + + assert "decoder.layers.0.self_attention.linear_qkv.weight" in sharded_keys + assert "decoder.layers.0.mlp.linear_fc1.weight" in sharded_keys + assert "decoder.layers.1.mlp.linear_fc1.weight" not in sharded_keys + assert "decoder.final_layernorm.weight" in sharded_keys + assert "decoder.final_norm.weight" not in sharded_keys + assert "output_layer._extra_state" not in sharded_state_dict + def test_layer_numbers(self): """ The layer numbers should start at one (for the embedding # layer) and go up diff --git a/tests/unit_tests/ssm/test_hybrid_block.py b/tests/unit_tests/ssm/test_hybrid_block.py index 08bf7f2bc28..1ac11d2ed4f 100644 --- a/tests/unit_tests/ssm/test_hybrid_block.py +++ b/tests/unit_tests/ssm/test_hybrid_block.py @@ -3,8 +3,13 @@ import pytest import torch +from megatron.core.models.hybrid.fine_grained_callables import build_hybrid_stack_callables from megatron.core.models.hybrid.hybrid_block import HybridStack -from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols, validate_segment_layers +from megatron.core.models.hybrid.hybrid_layer_allocation import ( + Symbols, + get_layer_type_list_physical_count, + validate_segment_layers, +) from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.gated_delta_net import GatedDeltaNet @@ -27,8 +32,10 @@ def setup_method(self, method): Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - def get_pg_collection(self): - return ProcessGroupCollection.use_mpu_process_groups(required_pgs=['tp', 'pp', 'cp']) + def get_pg_collection(self, required_pgs=None): + if required_pgs is None: + required_pgs = ['tp', 'pp', 'cp'] + return ProcessGroupCollection.use_mpu_process_groups(required_pgs=required_pgs) def get_mamba_block(self, layer_pattern): layer_type_list = validate_segment_layers(layer_pattern) @@ -81,6 
+88,60 @@ def get_dsa_mamba_block(self, layer_pattern): pg_collection=self.get_pg_collection(), ) + def get_attention_mlp_block(self, layer_pattern): + layer_type_list = validate_segment_layers(layer_pattern) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + hidden_dropout=0.0, + attention_dropout=0.0, + use_cpu_initialization=True, + ) + return HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + + def get_attention_moe_block(self, layer_pattern): + layer_type_list = validate_segment_layers(layer_pattern) + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + ffn_hidden_size=256, + num_moe_experts=8, + expert_model_parallel_size=1, + moe_router_topk=2, + moe_grouped_gemm=True, + moe_token_dispatcher_type="alltoall", + hidden_dropout=0.0, + attention_dropout=0.0, + use_cpu_initialization=True, + ) + return HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection( + required_pgs=[ + 'tp', + 'pp', + 'cp', + 'tp_cp', + 'tp_dp_cp', + 'ep', + 'expt_tp', + 'tp_ep', + 'expt_dp', + ] + ), + ) + def teardown_method(self, method): Utils.destroy_model_parallel() @@ -118,6 +179,102 @@ def test_layer_types(self): assert isinstance(layers[2], TransformerLayer) assert isinstance(layers[2].mlp, MLP) + def test_group_layer_type_builds_nested_hybrid_stack(self): + """Bracketed groups build an inner HybridStack with physical layer numbering.""" + layer_type_list = validate_segment_layers("M[M*]-") + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + 
use_cpu_initialization=True, + ) + block = HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + assert isinstance(block.layers[0], MambaLayer) + assert isinstance(block.layers[1], HybridStack) + assert isinstance(block.layers[1].layers[0], MambaLayer) + assert isinstance(block.layers[1].layers[1], TransformerLayer) + assert isinstance(block.layers[2], TransformerLayer) + assert [layer.layer_number for layer in block.layers[1].layers] == [2, 3] + assert block.layers[2].layer_number == 4 + + def test_group_sharded_state_dict_uses_logical_layer_keys(self): + """Grouped attention+MLP layers share one Transformer-compatible checkpoint key.""" + layer_type_list = validate_segment_layers("[*-]") + transformer_config = TransformerConfig( + hidden_size=256, + num_layers=get_layer_type_list_physical_count(layer_type_list), + num_attention_heads=4, + use_cpu_initialization=True, + ) + block = HybridStack( + transformer_config, + hybrid_stack_spec.submodules, + layer_type_list=layer_type_list, + pp_layer_offset=0, + logical_layer_offset=0, + pg_collection=self.get_pg_collection(), + ) + + sharded_state_dict = block.sharded_state_dict(prefix="decoder.") + sharded_keys = {value.key for value in sharded_state_dict.values() if hasattr(value, "key")} + + assert "decoder.layers.0.self_attention.linear_qkv.weight" in sharded_keys + assert "decoder.layers.0.mlp.linear_fc1.weight" in sharded_keys + assert "decoder.layers.1.mlp.linear_fc1.weight" not in sharded_keys + assert "decoder.final_layernorm.weight" in sharded_keys + assert "decoder.final_norm.weight" not in sharded_keys + + def test_group_forward_matches_equivalent_flat_layers(self): + """A bracket group is only a scheduling/checkpoint boundary, not new math.""" + flat_block = self.get_attention_mlp_block("*-") + group_block = self.get_attention_mlp_block("[*-]") + + 
group_block.layers[0].layers[0].load_state_dict(flat_block.layers[0].state_dict()) + group_block.layers[0].layers[1].load_state_dict(flat_block.layers[1].state_dict()) + group_block.final_norm.load_state_dict(flat_block.final_norm.state_dict()) + + flat_block.cuda().eval() + group_block.cuda().eval() + sequence_length = 16 + micro_batch_size = 2 + hidden_states = torch.randn( + sequence_length, micro_batch_size, flat_block.config.hidden_size, device="cuda" + ) + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool, device="cuda" + ) + + with torch.no_grad(): + flat_output = flat_block(hidden_states.clone(), attention_mask=attention_mask) + group_output = group_block(hidden_states.clone(), attention_mask=attention_mask) + + torch.testing.assert_close(group_output, flat_output, rtol=0, atol=0) + + def test_group_overlap_callables_keep_ep_moe_split_visible(self): + """EP-overlap scheduling still sees dispatch/experts/combine inside a group.""" + block = self.get_attention_moe_block("[*E]") + + forward_callables, bwd_dw_callable_map, is_moe, num_local_experts = ( + build_hybrid_stack_callables(block.layers[0], layer_type=block.layer_type_list[0]) + ) + + pre_dispatch, dispatch, experts, combine, mtp_post_process = forward_callables + assert callable(pre_dispatch) + assert callable(dispatch) + assert callable(experts) + assert callable(combine) + assert mtp_post_process is None + assert is_moe + assert num_local_experts == 8 + assert "pre_dispatch_computation" in bwd_dw_callable_map + assert "mlp" in bwd_dw_callable_map + def test_invalid_layer_types_cause_failure(self): invalid_symbol = '+' assert invalid_symbol not in Symbols.VALID_LAYERS # sanity check. 
diff --git a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py index fe0d7c2dc1e..4b721609804 100644 --- a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py +++ b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py @@ -15,6 +15,7 @@ parse_hybrid_pattern, pattern_from_ratios, select_pipeline_segment, + select_pipeline_segment_with_logical_offset, validate_segment_layers, ) @@ -71,6 +72,8 @@ def test_valid_patterns(self): """Test that valid segment patterns produce the correct layer type lists.""" test_cases = [ ("M*-M*-M*-", ['M', '*', '-', 'M', '*', '-', 'M', '*', '-']), + ("M[M*]-", ['M', ('M', '*'), '-']), + ("[M*E]", [('M', '*', 'E')]), ("MMMMMMMMM", ['M'] * 9), ("MM*-MM*-", ['M', 'M', '*', '-', 'M', 'M', '*', '-']), ("E", ['E']), @@ -98,6 +101,10 @@ def test_invalid_symbols_cause_failure(self): validate_segment_layers("M|M") # pipe not valid in a segment with pytest.raises(ValueError): validate_segment_layers("M/M") # MTP separator not valid in a segment + with pytest.raises(ValueError): + validate_segment_layers("M[[M]]") # nested groups are not valid + with pytest.raises(ValueError): + validate_segment_layers("M[EM]") # MoE must be last in a group with pytest.raises(ValueError): # Not allowed to have both standard Attention and MLA/DSA validate_segment_layers("MDM*-") @@ -110,6 +117,8 @@ def test_simple_patterns(self): assert get_hybrid_total_layer_count("M*M*") == 4 assert get_hybrid_total_layer_count("MMMM") == 4 assert get_hybrid_total_layer_count("M") == 1 + assert get_hybrid_total_layer_count("[M*E]") == 3 + assert get_hybrid_total_layer_count("M[M*]-") == 4 def test_with_pipe_separators(self): assert get_hybrid_total_layer_count("M-M-|M-M*-") == 9 @@ -155,6 +164,8 @@ def test_main_pattern_only(self): """Test patterns without MTP (no / separator).""" test_cases = [ ("M*M*", "M*M*"), + ("[M*E]", "[M*E]"), + ("M[M*]-", "M[M*]-"), ("MMMM", "MMMM"), ("*M*M", "*M*M"), ("MM-*", "MM-*"), @@ -231,11 
+242,22 @@ def test_invalid_symbols_in_main_pattern(self): "M*X*", # X is not valid "MaMM", # a is not valid "M*M*1", # 1 is not valid + "M[M*]X", # X is not valid after a group ] for pattern in invalid_patterns: with pytest.raises(ValueError, match="not a valid layer symbol"): parse_hybrid_pattern(pattern) + def test_invalid_group_syntax(self): + with pytest.raises(ValueError, match="without a matching"): + parse_hybrid_pattern("M[M*") + with pytest.raises(ValueError, match="not supported"): + parse_hybrid_pattern("M[M[*]]") + with pytest.raises(ValueError, match="cannot be empty"): + parse_hybrid_pattern("M[]") + with pytest.raises(ValueError, match="must be the last"): + parse_hybrid_pattern("M[EM]") + def test_invalid_symbols_in_mtp_pattern(self): """Test that invalid symbols in MTP pattern raise ValueError.""" # Single MTP depth with invalid symbol - should raise "not a valid layer symbol" @@ -350,6 +372,9 @@ def test_with_pipes_and_mtp(self): def test_moe_pattern(self): assert get_hybrid_layer_counts("MEME") == {'*': 0, 'D': 0, 'G': 0, 'M': 2, '-': 0, 'E': 2} + def test_group_pattern(self): + assert get_hybrid_layer_counts("M[M*]E") == {'*': 1, 'D': 0, 'G': 0, 'M': 2, '-': 0, 'E': 1} + def test_mtp_with_attention(self): # MTP pattern "*M" repeated 3 depths -> 3 attn + 3 mamba from MTP assert get_hybrid_layer_counts("MMMM/*M/*M/*M") == { @@ -414,6 +439,21 @@ def test_four_segments(self, mock_log): assert layer_types == expected_layers, f"Failed for vp_stage={vp_stage}" assert offset == expected_offset, f"Failed for vp_stage={vp_stage}" + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_group_segment_offsets(self, mock_log): + layer_types, offset = select_pipeline_segment("[M*E]|M-", pp_group=None, vp_stage=1) + assert layer_types == ['M', '-'] + assert offset == 3 + + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_group_segment_logical_offsets(self, 
mock_log): + layer_types, physical_offset, logical_offset = select_pipeline_segment_with_logical_offset( + "[*-][*-]|[*E][*E]", pp_group=None, vp_stage=1 + ) + assert layer_types == [('*', 'E'), ('*', 'E')] + assert physical_offset == 4 + assert logical_offset == 2 + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') def test_empty_segment(self, mock_log): """Empty segments are allowed for pipeline balancing.""" @@ -688,3 +728,12 @@ def test_all_mamba(self): assert mamba_map == {0: 0, 1: 1, 2: 2} assert mlp_map == {} assert moe_map == {} + + def test_grouped_layers_are_flattened(self): + maps = get_layer_maps_from_layer_type_list([("M", "*", "E"), "M"]) + attention_map, mamba_map, moe_map = operator.itemgetter( + Symbols.ATTENTION, Symbols.MAMBA, Symbols.MOE + )(maps) + assert attention_map == {1: 0} + assert mamba_map == {0: 0, 3: 1} + assert moe_map == {2: 0} diff --git a/tests/unit_tests/transformer/test_submodule_callables.py b/tests/unit_tests/transformer/test_submodule_callables.py index 7b41b3ca197..f4e6cdc0c89 100644 --- a/tests/unit_tests/transformer/test_submodule_callables.py +++ b/tests/unit_tests/transformer/test_submodule_callables.py @@ -2,7 +2,7 @@ import pytest import torch -from megatron.core.models.gpt.fine_grained_callables import build_layer_callables +from megatron.core.models.common.fine_grained_callables import build_layer_callables from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_with_transformer_engine_submodules, ) @@ -65,7 +65,7 @@ def run_model_submodules_with_capture(model, input_tensors, microbatches): output_tensors = [] # get callables - callables, dw = build_layer_callables(model) + callables, dw, _is_moe, _num_local_experts = build_layer_callables(model) attn, dispatch, moe, combine, post_process = callables assert post_process is None dummy_model = DummyState()