Add grouped HybridStack overlap support#1
Conversation
| for item_type, item_layer in pre_layers: | ||
| with _get_inner_quant_context(item_layer): | ||
| if item_type == LayerSymbols.MAMBA: | ||
| hidden_states = _apply_mamba_layer(item_layer, node, hidden_states) |
There was a problem hiding this comment.
After removing cuda graph from these 2 methods, can we remove _apply_mamba_layer and _apply_attention_layer method and directly call layer.forward()?
There was a problem hiding this comment.
Both helpers are deleted and inlined at the single call site in pre_dispatch_computation (d2daf2925). Mamba uses item_layer(...) directly. Attention had to keep item_layer._forward_attention(...) instead of __call__, with this comment at the call site:
# Use _forward_attention rather than __call__: an attention half-layer has
# mlp=IdentityOp / mlp_bda=IdentityFuncOp by default, and TransformerLayer's
# __call__ would route through _forward_mlp + mlp_bda, double-applying the
# post-attention residual.If we wanted to literally call __call__ for the attention case too, the attention layer spec would need mlp_bda set to something that returns mlp_output_with_bias[0] instead of the tuple. Happy to take that as a follow-up, but kept this round scoped to inlining.
hybrid/fine_grained_callables.py: remove _SharedExpertBackwardDWWrapper. The wrapper only existed to call mlp.backward_dw(routed_experts=False, shared_experts=True) under a separate scheduling slot; the same wgrad work runs correctly when MoELayer.shared_experts is registered as a sibling callable in backward_dw["mlp"], which the schedule node already iterates. Skip the shared callable when shared_expert_overlap is enabled (its wgrad is folded into the dispatcher overlap path). gpt/fine_grained_callables.py + hybrid/hybrid_block.py: revert the final_layernorm/final_norm dual lookup in PostProcessNode; instead expose final_layernorm as a property on HybridStack that returns final_norm. GPT's PostProcessNode now finds the final norm under the same attribute name it always used, without GPT-side changes. The registered submodule stays final_norm so existing hybrid checkpoint keys are unchanged. hybrid/fine_grained_callables.py: inline _apply_mamba_layer and _apply_attention_layer at their single call sites in pre_dispatch_computation. The attention case still uses item_layer._forward_attention(...) rather than item_layer(...) because attention half-layers have mlp=IdentityOp and mlp_bda=IdentityFuncOp, so __call__ would route through _forward_mlp + mlp_bda and double-apply the post-attention residual. Comment added at the call site. hybrid/hybrid_layer_allocation.py: reject bracketed groups inside MTP patterns. Each MTP depth is itself a fused unit, so wrapping its symbols in '[...]' has no defined meaning and breaks downstream construction. Raises ValueError with a clear message. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Restore TransformerLayerSchedulePlan and TransformerModelChunkSchedulePlan in
core/models/common/model_chunk_schedule_plan.py to their pre-PR shape
(GPT/MTP only, no layer_type concept) and drop the backward-compat aliases.
The hybrid plan now lives in core/models/hybrid/model_chunk_schedule_plan.py
as HybridStackSchedulePlan / HybridStackModelChunkSchedulePlan, subclassed
from the GPT classes.
To make the subclass small, the GPT base classes grow three extension points:
- LAYER_SCHEDULE_PLAN_CLASS picks the per-layer plan class.
- PRE_PROCESS_NODE_CLASS / POST_PROCESS_NODE_CLASS pick the embedding /
output-layer node classes.
- _extra_args_for_layer is the hook subclasses override to thread per-layer
metadata into the layer plan constructor.
Defaults fall back to the GPT classes so non-hybrid callers see no behavior
change.
HybridModel.forward now delegates to the existing _preprocess / _postprocess
methods instead of inlining the embedding / rotary / output-layer / loss
code, so the eager forward and the EP-overlap PreProcessNode / PostProcessNode
read the same chunk_state slots.
HybridStackNode (in hybrid/fine_grained_callables.py) is the schedule node
class used for hybrid layer plans. Subclassed from TransformerLayerNode so
the runtime backbone is shared, but the free-input policy is now resolved
through a method (_resolve_free_input) that subclasses override. The hybrid
class currently delegates to should_free_input but exists so any
hybrid-specific policy can be added surgically without touching GPT.
HybridPreProcessNode and HybridPostProcessNode (in
hybrid/model_chunk_schedule_plan.py) mirror the GPT counterparts but take a
HybridModel rather than a GPTModel, so the hybrid schedule plan does not
have to pass a HybridModel under a GPT-named attribute.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The MTP forward block in _postprocess was only gated on the kwarg mtp_in_postprocess and the inference / spec-decode flags; the eager forward path (which now delegates to _postprocess with mtp_in_postprocess=True) hits AttributeError: 'HybridModel' object has no attribute 'mtp' on models built without an MTP block. The eager forward had this guard inline before its body was lifted into _postprocess; reinstate it on the method itself so both the eager and EP-overlap paths skip the call when no MTP block is configured. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move the model-agnostic schedule-plan pieces — weak_method, should_free_input, TransformerLayerState, TransformerLayerNode, _BackwardDWWrapper, PreProcessNode, PostProcessNode — out of core/models/gpt/fine_grained_callables.py and into a new core/models/common/utils.py. The common chunk-schedule-plan module and the hybrid path now import from common/utils.py instead of crossing into gpt/. The GPT-specific TransformerLayer / MTPLayer callable builders stay in gpt/fine_grained_callables.py and re-export the moved names so existing imports keep working. Pre/PostProcessNode now use a generic ``model`` attribute (was ``gpt_model``); they call ``model._preprocess`` / ``model._postprocess`` which works for any model that exposes those methods. With the rename HybridModel is just another consumer of the same nodes — the dedicated HybridPreProcessNode / HybridPostProcessNode subclasses are no longer needed and are deleted, along with the PRE_PROCESS_NODE_CLASS / POST_PROCESS_NODE_CLASS overrides on HybridStackModelChunkSchedulePlan. transformer/module.py also imports _BackwardDWWrapper from common/utils.py now; the previous import path crossed from megatron/core/transformer into megatron/core/models/gpt/, which is the cross-cut the reviewer flagged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
31cf8c8 to
da436ac
Compare
| """ | ||
|
|
||
| @staticmethod | ||
| def _resolve_free_input(name, is_moe, config, num_local_experts): |
There was a problem hiding this comment.
The current code works fine if pre_dispatch_computation starts with attn layer and ends with attn layer. We need to doublecheck:
- if
pre_dispatch_computationstarts withssm/GDN, can we release the input tensor ofpre_dispatch_computationafterpre_dispatch_computation's forward pass . - if
pre_dispatch_computationends withssm/GDN, can we release the input tensor ofdispatchafterdispatch's forward pass.
Note: whether the input tensor can be released is determined by:
- whether the input tensor is needed by later submodule's forward, like the input tensor is a residual.
- whether the input tensor is added to
saved_for_backwardand can only be released after its backward pass finishes.
There was a problem hiding this comment.
Good catch. The current HybridStackNode._resolve_free_input delegates to the GPT default, which returns False for the pre_dispatch_computation slot — that's conservative but correct in every case I can reason about right now: the slot's input is the previous logical layer's output, and since hybrid groups always end the pre-dispatch phase by feeding hidden_states into MoE preprocess (which detaches the residual on node.layer_state.residual), nothing downstream of the slot actually retains a reference back to the slot's input.
The two cases you flag:
-
Pre-dispatch starts with SSM/GDN: the slot's input goes into Mamba/GDN forward, which produces a new
hidden_states(Mamba does residual + mixer; GDN's_forward_attentionreturns the post-attn-bda hidden_states). Mamba/GDN backward needs the post-projection activations, not the original input — the input projection's saved-tensors handle that. So in principle the slot input could be freed. -
Pre-dispatch ends with SSM/GDN (no terminal MoE): same reasoning for the final SSM/GDN layer's input. The dispatch slot is a no-op when there's no MoE terminal, so freeing the dispatch input is moot here.
I'd rather leave the policy at False for now and address this in a separate change once we've checked each backward path in detail (ideally with a memory measurement to confirm the saving). The seam to specialize is already there — _resolve_free_input is just a method override on HybridStackNode. Let me know if you'd prefer I make the change now and do the audit, or carry it as a follow-up.
There was a problem hiding this comment.
Sounds good, we can get back here as a follow-up.
There was a problem hiding this comment.
Filed mentally as a follow-up; will revisit when we audit memory peaks under SSM/GDN-only pre-dispatch.
Three follow-up cleanups on the schedule-plan plumbing: * Rename the per-layer placeholder class TransformerLayerState to LayerState in core/models/common/utils.py (the name lives in common now and the Transformer prefix was misleading). * Move build_mtp_layer_callables and the build_layer_callables dispatcher out of core/models/gpt/fine_grained_callables.py and into the new core/models/common/fine_grained_callables.py. MTP is shared between GPTModel and HybridModel; the dispatcher's job is to dispatch on layer type and is naturally common too. Only build_transformer_layer_callables stays in gpt/ since it depends on GPT's MoE wiring. * Drop the gpt/ re-export block (the __all__ list and the bulk import from common.utils). Callers now import the moved names directly from core/models/common/utils.py and core/models/common/fine_grained_callables .py. Gpt/fine_grained_callables.py is GPT-only again. Restore the explanatory comments inside HybridModel._preprocess and _postprocess that the earlier forward-refactor lost — they were originally inline in forward() and got dropped when the body moved into the helper methods. The new comments explain decoder-input handling, rotary cos/sin discard, sequence-parallel gather rationale, and the speculative-decoding ordering constraint. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The slot covers more than attention: in the GPT path it runs attention + pre-MLP layernorm + router + dispatch preprocess; in the hybrid grouped path it loops over Mamba / attention / GDN sub-layers and ends with the MoE dispatch preprocess. The 'attn' name is misleading on both sides and diverges from the hybrid forward callable that has long been called pre_dispatch_computation. Rename the slot consistently: * TransformerLayerSchedulePlan: class attribute attn -> pre_dispatch_ computation, release_state, and run() updated. backward_dw key 'attn' -> 'pre_dispatch_computation' so it matches the slot name the schedule node passes when looking up the bwd_dw callables map. * HybridStackSchedulePlan: same self.pre_dispatch_computation slot. * GPTModel build_transformer_layer_callables and the MTP builder: forward_funcs first entry is now pre_dispatch_func; backward_dw uses the renamed key. * Hybrid build_hybrid_stack_callables: backward_dw uses the renamed key. No behavior change; the slot still receives the same callables and runs on the same stream. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MambaLayer used to be skipped in the hybrid EP-overlap pre-layer backward-dw registration loop because MambaLayer is not a TransformerLayer and the standard _BackwardDWWrapper(self) constructor asserts isinstance (layer, TransformerLayer). That meant Mamba pre-layers in a grouped HybridStack ran their wgrad inline in the regular backward pass while attention / GDN pre-layers got delayed wgrad — different scheduling for identical-shaped slots. Mirror GatedDeltaNet.backward_dw on Mamba: MambaMixer.backward_dw calls in_proj.backward_dw and out_proj.backward_dw (no-op when the spec uses non-TE linears that lack delayed wgrad); MambaLayer.backward_dw delegates to its mixer. In the hybrid pre-layer registration, add a MAMBA branch that appends the layer directly (not the wrapper, since MambaLayer is not a TransformerLayer). The schedule node already iterates the list and calls .backward_dw() on each, so the direct registration just works. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tch_computation Follow-up to the schedule-plan slot rename: the test assertion that the hybrid grouped-overlap callables expose a backward-dw entry for the pre-dispatch slot was still checking the old 'attn' key; update it to 'pre_dispatch_computation'. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three follow-up cleanups from review of common/fine_grained_callables.py: * build_layer_callables now dispatches HybridStack alongside TransformerLayer and MultiTokenPredictionLayer, so an mtp_model_layer that happens to be a HybridStack (or a future generic decoder layer) goes through the dispatcher rather than the TransformerLayer-only entrypoint. * The dispatcher now returns (forward_funcs, backward_dw, is_moe, num_local_experts) for every layer type (the build function already knows the layer, so the caller doesn't have to re-derive these). build_mtp_layer _callables calls build_layer_callables on its inner layer so the same pass-through works recursively. * Rename the local variable attn_forward (and the helper submodule_mtp_attn _forward) inside build_mtp_layer_callables to pre_dispatch_forward / submodule_mtp_pre_dispatch_forward — the schedule slot was renamed in 363f4bd but the local names inside the MTP wrapper were missed. Also tighten the assert message in GraphableMegatronModule.init_backward_dw _wrapper: the wrapper is no longer hard-bound to TransformerLayer in documentation tone (MambaLayer.backward_dw exists now; Mamba just doesn't use the init_backward_dw_wrapper path because _BackwardDWWrapper still asserts TransformerLayer in __init__). The assertion checks config, not type, and the message now says so. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Round-2 renamed the schedule slot 'attn' to 'pre_dispatch_computation' and round-3 caught the MTP-wrapper local-variable parallel, but left peer docstrings and the GPT-side nested function un-renamed. Sweep the remaining references so a reader of the docs sees the same slot name the code uses: * common/model_chunk_schedule_plan.py: update _build_callable_nodes docstring, the run() overlap diagram (attn_fwd/attn_bwd -> pre_dispatch_fwd/pre_dispatch_bwd), and the two post-overlap comments. * common/utils.py: update should_free_input, TransformerLayerNode, and _BackwardDWWrapper docstrings; update the inline [attn, fake, mlp, fake] dense-layer list comment. * hybrid/fine_grained_callables.py: update the HybridStackNode docstring to reference the pre_dispatch_computation slot. * gpt/fine_grained_callables.py: rename submodule_attn_forward -> submodule_pre_dispatch_forward (def + binding), rewrite the build_transformer_layer_callables 5-callables enumeration and Returns block to match the real slot names and backward_dw dict keys, fix the 'attnention' typo, and disambiguate the dispatch_forward 'attn submodule' comment. No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…-queue d2daf29 ('Cleanup hybrid EP-overlap callables...') moved the MoE shared-experts wgrad from the pre_dispatch_computation slot into the mlp slot's bwd_dw_callables list. The commit message argued the relocation was equivalent because TransformerLayerNode.backward_dw iterates the list and calls .backward_dw() on each entry. That equivalence does not hold: TE's delay-wgrad model only puts to the wgrad queue inside the autograd backward (dgrad). Specifically, transformer_engine/.../linear.py calls ctx.wgrad_store.put(...) from inside the autograd Function's backward method, not from forward. So module.backward_dw() requires that module's autograd backward to have already run. For shared experts the relevant autograd backward lives in the pre_dispatch_computation slot (its forward is part of _run_moe_preprocess). The schedule plan calls mlp.backward_dw() before pre_dispatch_computation.backward(), so the shared-experts wgrad queue is still empty when the mlp slot tries to pop it, producing 'RuntimeError: Pop empty queue' from TE's _common.py. Restore the original wiring: register a _SharedExpertBackwardDWWrapper in pre_bwd_dw so the shared-experts wgrad fires after the pre_dispatch slot's autograd backward. The wrapper delegates to mlp.backward_dw(routed_experts=False, shared_experts=True) so the MoELayer-level guards (use_shared_expert / shared_expert_overlap / moe_latent_size + fc1_latent_proj) remain authoritative. Reproduction: DeepSeek-V3-Lite-deter-Hybrid with --hybrid-layer-pattern '*-*-|*-*-|*-*-|*-*-|*E*E|*E*E|*E*E|*E*E', PP=4 EP=2, --overlap-moe-expert-parallel-comm --delay-wgrad-compute, 20 train iterations. Bisect identified d2daf29 as the first failing commit; 55d1e34 ran cleanly with bitwise-identical loss to the eager baseline. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…uped pre-dispatch The grouped HybridStack pre_dispatch_computation closure calls TransformerLayer._forward_attention directly for ATTENTION / DSA / GDN pre-layers (rather than __call__, to avoid double-applying the post-attention residual through mlp_bda + _forward_mlp). The downside: _forward_attention returns the raw bias_dropout_add output, which can be a view tensor produced by the fused/JIT BDA kernel. The full forward() path at transformer_layer.py:895 normalizes this via make_viewless_tensor before returning; the shortcut here did not. When the view's non-canonical strides cross the schedule-node boundary into a downstream cuBLAS matmul (the terminal MLP's pre_mlp_layernorm linear, the next attention's linear_qkv, or the MoE router/dispatcher), the matmul kernel's algorithm-selection heuristics can pick a different algo based on stride layout. Different processes invoking the same model with the same seed can land in different deterministic clusters and produce ~1e-5 bit drift on the forward output of grouped layers containing GDN/attention pre-layers — even with the full Paul Gibbons deterministic recipe applied. Reproduction (before this commit): --hybrid-layer-pattern '*-*-|*-*-|*-*-|*-*-|[G-][G-]|[G-][G-]|[G-][G-]|[G-][G-]' Full deterministic recipe: --deterministic-mode true, NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, MAMBA_DETERMINISTIC=1, CAUSAL_CONV1D_DETERMINISTIC=1, TRITON_ENABLE_AUTOTUNE=0, NCCL_ALGO=Ring, --attention-backend flash. Two reruns: iter 1 loss differs by ~1e-5; subsequent iters drift in a bounded 1e-5..1e-4 band. With this commit: iter 1 is bitwise across reruns. Verified on draco 8 H100 GPUs PP=4 EP=2. The Mamba branch is unaffected: it uses __call__, which already routes through the full forward() that ends with make_viewless_tensor. Only the _forward_attention shortcut needed this companion fix. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/claude review |
Summary
This MR adds grouped HybridStack support and extends EP-overlap scheduling/checkpoint compatibility to grouped hybrid layers.
Key changes:
[*-],[*E],M[M*]-.HybridStackinstances.TransformerSchedulePlanusage toHybridStackSchedulePlan, preserving backward-compatible aliases.HybridModel.sharded_state_dict()match GPT behavior by dropping emptyoutput_layer._extra_state.Checkpoint Compatibility
Grouped HybridStack checkpoint keys now canonicalize to logical Transformer layer keys. For example,
[*-]maps both attention and MLP weights to:instead of using separate physical layer ids.
final_normis also sharded asfinal_layernormfor Transformer compatibility.Verified both directions:
Cross-load status:
Validation
unit tests:
Result:
Also ran:
git diff --check--dist-ckpt-strictness raise_unexpectedNotes
The checkpoint cross-load validation intentionally uses model-weight / finetune-style loading with optimizer and RNG state skipped. Full non-finetune resume across Transformer vs grouped HybridModel is still expected to hit config arg mismatches because Transformer uses logical
num_layers=16, while grouped Hybrid uses physical symbols grouped into 16 logical layers.