[Upstream] Update megatron version to dev branch (Feb 13) and rebase modifications#13
Merged
Merged
Conversation
4ac0f2f to
07220d2
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR rebases from radixark Megatron fork miles-20260218 and resolve conflicts.
Upgrade Megatron from Dec 17 (3714d81) to Feb 13 (1dcf0da)
This PR has been reviewed by Yueming.
The desription is generated by claude code and reviewed by me.
dp_reshardablefix should be removed after [DO NOT MERGE] Upstream feb 26 #15Rebase Diff Report: miles-main vs rebase-miles-main
Base branches:
miles-main(rdxa/miles-main) diverges fromrdxa/devat commit3714d81d4(older base)rebase-miles-mainrebases the same feature set ontordxa/devcommit1dcf0dafaCommit Mapping
[1/8]fix: misc compatibility fixes for PyTorch and TE[1/8]samerdxa/dev[2/8]support partial checkpoint loading[2/8]samerdxa/dev[3/8]post-attention and post-MLP layernorm[3/8]same[4/8]MLA RoPE triton kernel fix[4/8]same[5/8]detach output layer params for RL[5/8]support MTP training in RL[6/8]support MTP training in RL[7/8]support rollout routing replay (R3)[6/8]R3 + bypass for MTP layers[fix]bypass r3 for mtp layer[8/8]INT4 fake QAT for MoE[7/8]same[8/8]fix: CUDA IPC incompatibility from Megatron bump[9/8]fix: dp_reshardable checkpoint backward compatDetailed Differences
[3/8] feat: add post-attention and post-MLP layernorm support
Files changed:
gpt_layer_specs.py,transformer_layer.py,transformer_config.py,arguments.pygpt_layer_specs.py— structural adaptationget_transformer_layer_spec_for_backend()helper — only one insertion point neededget_transformer_layer_spec_for_backend()helper and inlined spec construction into two separate branches. Functionally equivalent.transformer_layer.py— identicalBoth versions add the same +24 lines: dataclass fields, module build in
__init__, and layernorm application after self-attention and MLP outputs. No difference.Note: An earlier version of the rebase had a bug (duplicate
recompute_pre_mlp_layernormblock) which was already fixed before this report.transformer_config.py— minor differencepost_self_attn_layernorm,post_mlp_layernorm, anduse_gated_attentionfieldspost_self_attn_layernormandpost_mlp_layernormonlyuse_gated_attentionhas no consumers in either codebase; dropped during rebase.arguments.py— different mechanismgroup.add_argument()calls (--post-self-attn-layernorm,--post-mlp-layernorm,--use-gated-attention) + 2 explicitkw_args[...] = args.xxxassignments incore_transformer_config_from_args()post_self_attn_layernormandpost_mlp_layernormfrom theexcludelist inArgumentGroupFactory(TransformerConfig, exclude=exclude)+ keeps the 2 explicitkw_argsassignments (redundant but harmless)ArgumentGroupFactoryto auto-generate CLI args fromTransformerConfigdataclass fields. Adding manualadd_argument()for the same fields causes an argparse conflict error (conflicting option string). Removing them from theexcludelist lets the auto-generation handle it. The explicitkw_argsassignments are redundant (the auto-loop incore_transformer_config_from_argsalready does this) but kept for clarity.[5/8] feat: support MTP training in RL (rebased from miles-main [6/8], with [5/8] detach logic folded in)
miles-main had two separate commits:
[5/8]detach output layer params — edits existingcompute_output_layer_and_language_model_loss()inlanguage_module.py(1 file, 9 insertions, 1 deletion)[6/8]support MTP training in RL — addsmtp_kwargsinterface, changes MTP label/loss_mask flow ingpt_model.py+multi_token_prediction.pyIn range-diff, miles-main
[5/8]is dropped as a standalone commit, and miles-main[6/8]is rewritten into rebase[5/8]with an expanded commit message that includes detach + MTP behavior.Commit message vs actual behavior (miles-main)
[5/8]message says "detach output layer params for RL training", but code only changes the non-fused branch inlanguage_module.py(functional_callwith detached module params). The fusedlinear_cross_entropypath still uses the caller-providedweightas-is.[6/8]message says "support MTP training in RL". It addsmtp_kwargsand MTP flow changes, and callscompute_output_layer_and_language_model_loss(...)from MTP path, but fused-path detach is still not fully enforced there.[5/8]message explicitly combines these concerns (detach + mtp_kwargs + label/loss_mask roll), and implementation detaches both fused and non-fused MTP output-layer paths ingpt_model.py.Detach implementation — where it lives
compute_output_layer_and_language_model_loss()inlanguage_module.py+ MTP callsite ingpt_model.pygpt_model.py_postprocesslanguage_module.pyfunctional_calldetach in non-fused branch)functional_callwith all params detached +col_linear_kwargs={'weight': output_weight.detach()}functional_callwith all params detached +weight=output_weight.detach()in kwargsweightparameter comes from caller asself.shared_embedding_or_output_weight()— NOT detachedweight=output_weight.detach()— correctly detachedKey difference: In miles-main, the fused path (
linear_cross_entropy) receives theweightparameter from the method signature, which is passed by the caller ingpt_model.pyasself.shared_embedding_or_output_weight()without.detach(). This means the fused path does not block MTP gradient from flowing back to the output layer. The non-fused path is correct in both versions.rebase-miles-main fixes this by detaching weight in both paths at the call site.
mtp_kwargsinterface — identicalBoth versions add the same
mtp_kwargs: Optional[dict] = {}parameter, the samemtp_labelssourcing frommtp_kwargs['mtp_labels'], and the sameloss_maskroll logic.multi_token_prediction.py— identicalBoth versions make the same changes:
position_idsNone check,decoder_input.detach(),keep_graph=False, and_checkpointed_forwardrewrite to support non-tensor arguments.Gradient flow (both versions intend the same behavior)
[6/8] feat: support rollout routing replay (R3) and bypass for MTP layers (was miles-main [7/8] + [fix])
miles-main had two separate commits:
[7/8]support rollout routing replay — adds R3 integration (2 files, 6 lines)[fix]bypass r3 for mtp layer — addsis_mtpbypass (4 files, 13 lines)rebase-miles-main merges both into a single
[6/8], and additionally replaces rdxa/dev's built-inRouterReplayclass.API — identical
from miles.utils.replay_base import routing_replay_managerrouting_replay_manager.register_to_module(self, "routing_replay")routing_replay_manager.get_topk_fn(compute_topk, return_probs=True)routing_replay_manager.get_topk_fn(_compute_topk, return_probs=True)— rdxa/dev renamed internal function; functionally equivalentis_mtp)[fix]addsis_mtpflag,set_is_mtp()method, bypass logicrdxa/dev
RouterReplayhandlingrdxa/dev has its own built-in
RouterReplayclass (router_replay.py), withmoe_enable_routing_replayconfig and parameter-passing style integration. rebase-miles-main [6/8]:RouterReplayimports and usage fromrouter.pyandmoe_utils.pyrouter_replay=self.router_replayparameter passingmiles.utils.replay_base.routing_replay_manager(same as miles-main)router_replay.pyfile still exists but is dead code (no imports remain)[8/8] fix: CUDA IPC incompatibility from Megatron bump (new)
Not present on miles-main. Commit message documents a failure after rebasing to
rdxa/dev: colocated IPC weight update hitstorch.AcceleratorError: CUDA error: invalid argumentduring CUDA tensor serialization withtorch.multiprocessing.Motivation (from commit message): TMS hook behavior from upstream Megatron bump can make allocator behavior IPC-incompatible in this flow.
Code-level fix in this commit:
dynamic_context.py— disablestorch_memory_saverhook mode in this context (HAVE_TORCH_MEMORY_SAVER = False)This commit introduce this bug. NVIDIA@42986ac
[9/8] fix: dp_reshardable checkpoint backward compat in Megatron core (new)
Background: When loading a
dp_reshardablecheckpoint saved with a different DP world size, bucket counts may differ. Megatron'ssharded_param_state_dp_reshardablepads the bucket state list with{"padding": True}entries for alignment, but the loading side had two bugs:dict_utils.mergelist length mismatch —merge()raisesValueErrorwhen shard file lists have different lengths (extra padding entries from save side). Fix: detect optimizer/param_state paths via_is_optimizer_param_state_key()and truncate the longer list (x1) to match x2.distrib_optimizer.pyKeyError on['padding']— Old checkpoint entries lack thepaddingfield entirely. Fix:bucket_state_elem.get('padding', False)instead ofbucket_state_elem['padding'].mapping.pyflattened_range exception —ShardedTensor.__init__raisesCheckpointingExceptionforflattened_range. Fix: downgrade to deprecation warning (logged once) for backward compat with older checkpoint formats.Files changed:
megatron/core/dist_checkpointing/dict_utils.py—_is_optimizer_param_state_key()helper +merge()truncation logicmegatron/core/optimizer/distrib_optimizer.py—.get('padding', False)inload_parameter_state_from_dp_reshardablemegatron/core/dist_checkpointing/mapping.py—flattened_rangedeprecation warningSummary of Changes vs miles-main
gpt_layer_specs.pyarguments.pyArgumentGroupFactoryexclude list instead of manualadd_argumentgpt_model.pygpt_model.pyvslanguage_module.py_postprocess; miles-main touched existinglanguage_module.pymethod (no new method added)moe_utils.py,router.pyRouterReplayreplaced withmiles.utils.replay_baseuse_gated_attentiontransformer_config.py,arguments.pymapping.py,dynamic_context.pydict_utils.py,distrib_optimizer.py,mapping.py