Skip to content

[Docker] Megatron version bump to Feb 13 and upgrade fla==0.4.1#643

Merged
guapisolo merged 17 commits into
radixark:mainfrom
guapisolo:feat/megatron-bump
Mar 4, 2026
Merged

[Docker] Megatron version bump to Feb 13 and upgrade fla==0.4.1#643
guapisolo merged 17 commits into
radixark:mainfrom
guapisolo:feat/megatron-bump

Conversation

@guapisolo
Copy link
Copy Markdown
Collaborator

@guapisolo guapisolo commented Feb 26, 2026

NOTICE: SOMEONE NOTICE THERE IS PROBLEM TO RUN THIS PR WITH GLM-5. IF PROBLEM EXIST, try downgrade megatron, turn on DEPRECATED_MEGATRON_COMPATIBLE and try.

ci-megatron-pr: upstream/1dcf0dafa

PR related: radixark/Megatron-LM#13

There is some critical megatron changes between Dec 18, 2025 and Feb 13, 2026.

If there is any problem with new megatron, you can revert it to the old version Rebased Megatron from Dec 18 and add DEPRECATED_MEGATRON_COMPATIBLE to align with old version code.

Adapt to upstream Megatron-LM breaking changes

1. TransformerConfig auto-registration

Megatron-LM now auto-generates CLI arguments directly from TransformerConfig dataclass fields (#2896), and passes a pre-built config + pg_collection into model_provider (#2608).

model_provider.py — Accept the new config and pg_collection kwargs in all three provider paths (custom, bridge, default). Assert config is None since Miles builds its own config from args via core_transformer_config_from_args.

arguments.py — The old CLI arg --norm-epsilon was manually mapped to TransformerConfig.layernorm_epsilon. With auto-registration, this mapping is gone — the field is now exposed directly as args.layernorm_epsilon. Update the HF config validation accordingly.

2. Fix TP weight gather to use partition_stride

Background. In tensor-parallel SwiGLU/GLU models, each TP rank's linear_fc1.weight stores interleaved [gate, up] blocks — the correct partition_stride is 2. Megatron-LM previously hard-coded stride=1 for all parameters; our old code compensated with manual chunk(2) → reorder → cat for fc1 and a partition_dim swap workaround for fc2's grouped-MoE bug.

Megatron-LM #2708 partially fixed this: linear_fc1 now correctly set partition_stride=2 (when gated_linear_unit=True) and linear_fc2 set partition_stride=1. But when --moe-grouped-gemm is set, partition_stride is still 1.

So the solution is to remove the original assertion partition_stride == 1

3. Remove old Megatron ckpt format dependents fully_sharded_model_space to dp_reshardable.

Related bug fix can be found at radixark/Megatron-LM#13

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @guapisolo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request updates the integration with Megatron-LM to ensure compatibility with recent changes, specifically regarding how tensor parallel sharding is handled for certain layers. It refactors the parameter gathering mechanism to support strided partitioning, adjusts model provider function signatures to accommodate new configuration parameters, and includes a minor correction to an argument mapping for Hugging Face configurations. These changes collectively aim to maintain correct behavior and functionality with an updated Megatron version.

Highlights

  • Megatron Parallel State Initialization: The initialization of the Megatron parallel state has been moved to an earlier point in the actor's init method, ensuring it's set up before certain debug checks.
  • Model Provider Signature Update: The wrapped_model_provider and model_provider functions now accept additional config and pg_collection parameters, with an assertion added to ensure config is None as Miles manages its construction.
  • Strided Tensor Parallel Sharding Support: New helper functions _gather_with_stride and _check_partition_stride were introduced to correctly handle strided (interleaved) tensor parallel sharding, particularly for linear_fc1.weight in GLU/SwiGLU layers, and the all_gather_param and all_gather_params_async functions were refactored to utilize this new logic.
  • Hugging Face Config Mapping Correction: A mapping in the _get_hf_config_mapping function was corrected, changing norm_epsilon to layernorm_epsilon for rms_norm_eps to align with Hugging Face configurations.
Changelog
  • miles/backends/megatron_utils/actor.py
    • Moved the create_megatron_parallel_state call to an earlier point in the init method.
    • Removed a redundant create_megatron_parallel_state call later in the init method.
  • miles/backends/megatron_utils/model_provider.py
    • Added config: TransformerConfig | None = None and pg_collection=None parameters to wrapped_model_provider.
    • Added an assertion assert config is None within wrapped_model_provider.
    • Introduced a new wrapped_bridge_provider function to handle the return from provider.provide.
    • Added config: TransformerConfig | None = None and pg_collection=None parameters to model_provider.
    • Added an assertion assert config is None within model_provider.
  • miles/backends/megatron_utils/update_weight/common.py
    • Imported the logging module and added a logger instance.
    • Implemented _gather_with_stride to gather partitions respecting partition_stride.
    • Implemented _check_partition_stride to validate partition_stride for known parameter patterns.
    • Updated the docstring for all_gather_param to reflect handling of strided partitioning.
    • Refactored all_gather_param to use _check_partition_stride and _gather_with_stride, removing previous ad-hoc handling.
    • Modified all_gather_params_async to pass param.partition_stride in gather_tasks.
    • Refactored the processing logic in all_gather_params_async to use _check_partition_stride and _gather_with_stride.
  • miles/utils/arguments.py
    • Updated the mapping for rms_norm_eps from norm_epsilon to layernorm_epsilon.
Activity
  • No human activity has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request appears to be updating the Megatron integration to a newer version. The changes include updating model provider function signatures, refactoring weight gathering logic to properly support strided tensor parallelism via partition_stride, and adjusting argument names for validation. Overall, these changes align the codebase with upstream updates.

However, I've identified a critical regression in miles/backends/megatron_utils/actor.py where the refactoring of parallel_state initialization can lead to incorrect behavior or crashes when using virtual pipeline parallelism. My review includes a specific comment and code suggestion to address this issue.

(self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer(
args, role
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The parallel_state should be re-initialized here after the model is created. The initial creation at line 83 with model=None is sufficient for the debug_rollout_only path, but for the main training path, parallel_state requires information from the model's configuration, especially when virtual pipeline parallelism (vpp_size > 1) is used.

Removing this line causes the training process to use an incomplete parallel_state, which can lead to an AssertionError or incorrect behavior during training. Re-adding this line ensures the parallel state is always correctly configured.

Suggested change
self.parallel_state = create_megatron_parallel_state(model=self.model)

@guapisolo guapisolo marked this pull request as draft February 26, 2026 02:19
@guapisolo guapisolo marked this pull request as ready for review February 26, 2026 05:23
@guapisolo guapisolo force-pushed the feat/megatron-bump branch 4 times, most recently from b81daf9 to 6b7856d Compare February 28, 2026 06:31
debug_rollout_only mode calls train() which needs parallel_state for
rollout data preprocessing and logging. Previously parallel_state was
only created after model initialization, which is skipped in
debug_rollout_only mode. Move it before the early return with model=None.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@guapisolo guapisolo force-pushed the feat/megatron-bump branch 4 times, most recently from 772a011 to de4a130 Compare March 2, 2026 03:44
guapisolo and others added 11 commits March 2, 2026 23:47
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
…d fc1/fc2 logic

Megatron-LM PR #2708 fixed partition_stride to correctly report stride=2
for linear_fc1 (GLU/SwiGLU interleaved [gate, up]) and stride=1 for
linear_fc2. Replace the old hard-coded fc1 chunk reordering and fc2
partition_dim workaround with generic stride-aware gathering.

Add _check_partition_stride() asserts to validate expected stride values
for linear_fc1 (must be 2) and linear_fc2 (must be 1).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…name

Megatron-LM renamed the config field from norm_epsilon to
layernorm_epsilon. Update the HF config validation mapping accordingly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…o_hf

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@guapisolo guapisolo force-pushed the feat/megatron-bump branch from 6fa3ec6 to 8db6954 Compare March 3, 2026 07:53
Copy link
Copy Markdown
Collaborator

@yueming-yuan yueming-yuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as offline discussed

@guapisolo guapisolo changed the title [Docker] Megatron version bump [Docker] Megatron version bump to Feb 13 and upgrade fla==0.4.1 Mar 4, 2026
@guapisolo guapisolo merged commit 051cd15 into radixark:main Mar 4, 2026
38 of 64 checks passed
JD-ETH pushed a commit to JensenFire/miles that referenced this pull request Apr 11, 2026
…xark#643)

Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
GuanxingLu pushed a commit to GuanxingLu/miles that referenced this pull request Apr 21, 2026
…xark#643)

Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants