From 2012e6ab87adf964404c7b9a6a287c9873af7f8b Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 25 Feb 2026 23:52:34 +0000 Subject: [PATCH 01/16] fix: create parallel_state before debug_rollout_only early return debug_rollout_only mode calls train() which needs parallel_state for rollout data preprocessing and logging. Previously parallel_state was only created after model initialization, which is skipped in debug_rollout_only mode. Move it before the early return with model=None. Co-Authored-By: Claude Opus 4.6 --- miles/backends/megatron_utils/actor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 4c430c1529..91ec391919 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -82,6 +82,7 @@ def init( logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") torch_memory_saver.memory_margin_bytes = x + self.parallel_state = create_megatron_parallel_state(model=None) if self.args.debug_rollout_only: return 0 @@ -99,8 +100,6 @@ def init( args, role ) - self.parallel_state = create_megatron_parallel_state(model=self.model) - if role == "critic": if self.args.offload_train: self.sleep() From 34ed2592035420587bf7ced6b4c34c7a2cea9b76 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Mon, 2 Mar 2026 23:47:37 -0800 Subject: [PATCH 02/16] Fix bug Co-authored-by: Yueming Yuan --- miles/backends/megatron_utils/actor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 91ec391919..9f6ef27755 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -82,8 +82,8 @@ def init( logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") torch_memory_saver.memory_margin_bytes = x - self.parallel_state = create_megatron_parallel_state(model=None) if self.args.debug_rollout_only: + self.parallel_state = create_megatron_parallel_state(model=None) return 0 if role == "critic": @@ -100,6 +100,8 @@ def init( args, role ) + self.parallel_state = create_megatron_parallel_state(model=self.model) + if role == "critic": if self.args.offload_train: self.sleep() From 68635cf2da06e76791f58f123f6beae03a0be522 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Sat, 28 Feb 2026 00:30:46 +0000 Subject: [PATCH 03/16] add save hash check ci --- miles/backends/megatron_utils/ci_utils.py | 90 ++++++++++ miles/backends/megatron_utils/model.py | 8 + miles/utils/arguments.py | 8 + tests/e2e/ckpt/test_glm47_flash_ckpt.py | 194 ++++++++++++++++++++++ tests/e2e/ckpt/test_qwen3_4B_ckpt.py | 21 ++- 5 files changed, 318 insertions(+), 3 deletions(-) create mode 100644 tests/e2e/ckpt/test_glm47_flash_ckpt.py diff --git a/miles/backends/megatron_utils/ci_utils.py b/miles/backends/megatron_utils/ci_utils.py index e6ce784ca2..78942882f4 100644 --- a/miles/backends/megatron_utils/ci_utils.py +++ b/miles/backends/megatron_utils/ci_utils.py @@ -1,12 +1,102 @@ """CI utilities for Megatron backend testing.""" +import hashlib +import json import logging +import re from collections.abc import Sequence +from pathlib import Path +import torch +from megatron.core import parallel_state as mpu from megatron.core.distributed import DistributedDataParallel as DDP logger = logging.getLogger(__name__) +_LAYER_PATTERNS = ( + re.compile(r"(?:^|\.)encoder\.layers\.(\d+)(?:\.|$)"), + re.compile(r"(?:^|\.)decoder\.layers\.(\d+)(?:\.|$)"), + re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)"), + re.compile(r"(?:^|\.)layer\.(\d+)(?:\.|$)"), +) + + +def _layer_key(name: str) -> str: + for pattern in _LAYER_PATTERNS: + match = pattern.search(name) + if match: + return f"layer_{int(match.group(1)):04d}" + return "non_layer" + + +def _hash_tensor_bytes(tensor: torch.Tensor) -> bytes: + data = tensor.detach() + if data.is_cuda: + data = data.cpu() + if not data.is_contiguous(): + data = data.contiguous() + return data.view(torch.uint8).numpy().tobytes() + + +def compute_model_hashes_by_layer(model: Sequence[DDP]) -> dict[str, str]: + """Compute per-layer SHA256 hashes over parameter bytes. + + Hash input includes parameter name, shape, dtype, and raw bytes. + """ + hashers: dict[str, hashlib._Hash] = {} + for pp_idx, model_chunk in enumerate(model): + for name, param in sorted(model_chunk.named_parameters(), key=lambda x: x[0]): + if param is None: + continue + full_name = f"pp{pp_idx}.{name}" + key = _layer_key(full_name) + hasher = hashers.setdefault(key, hashlib.sha256()) + hasher.update(full_name.encode("utf-8")) + hasher.update(str(tuple(param.shape)).encode("utf-8")) + hasher.update(str(param.dtype).encode("utf-8")) + hasher.update(_hash_tensor_bytes(param)) + return {k: v.hexdigest() for k, v in sorted(hashers.items(), key=lambda x: x[0])} + + +def _hash_file_path(base_dir: str | Path, iteration: int) -> Path: + tp_rank = mpu.get_tensor_model_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + dp_rank = mpu.get_data_parallel_rank(with_context_parallel=True) + cp_rank = mpu.get_context_parallel_rank() + base = Path(base_dir) + iter_dir = base if base.name.startswith("iter_") else base / f"iter_{int(iteration):07d}" + return iter_dir / f"model_hash_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}_cp{cp_rank}.json" + + +def save_model_hashes(args, model: Sequence[DDP], iteration: int, hashes: dict[str, str]) -> None: + if not args.ci_test or not args.ci_save_model_hash: + return + path = _hash_file_path(args.save, iteration) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(hashes, f, indent=2, sort_keys=True) + logger.info(f"[CI hash] Saved model hashes to {path}") + + +def check_model_hashes(args, model: Sequence[DDP], iteration: int) -> None: + if not args.ci_test or not args.ci_check_model_hash: + return + path = _hash_file_path(args.load, iteration) + if not path.is_file(): + raise AssertionError(f"[CI hash] Hash file missing: {path}") + with path.open("r", encoding="utf-8") as f: + expected = json.load(f) + actual = compute_model_hashes_by_layer(model) + if actual != expected: + missing = sorted(set(expected) - set(actual)) + extra = sorted(set(actual) - set(expected)) + mismatched = sorted(k for k in expected.keys() & actual.keys() if expected[k] != actual[k]) + raise AssertionError( + "[CI hash] Model hash mismatch after load. " + f"missing={missing[:5]}, extra={extra[:5]}, mismatched={mismatched[:5]}" + ) + logger.info(f"[CI hash] Model hashes match for iteration {iteration}.") + def check_mtp_only_grad(model: Sequence[DDP], step_id: int) -> None: """Check that only MTP parameters have non-zero gradients. diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index b6cd49d9a1..e08b15c207 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -29,6 +29,7 @@ from ..training_utils.loss import loss_function from ..training_utils.parallel import ParallelState from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora +from .ci_utils import check_model_hashes, compute_model_hashes_by_layer, save_model_hashes from .lora_utils import is_lora_enabled, is_lora_model from .model_provider import get_model_provider_func from .parallel import get_packed_seq_params @@ -686,6 +687,9 @@ def save( opt_param_scheduler (OptimizerParamScheduler): LR/WD scheduler. """ args = get_args() + hashes = None + if args.ci_test and args.ci_save_model_hash: + hashes = compute_model_hashes_by_layer(model) if should_disable_forward_pre_hook(args): disable_forward_pre_hook(model) @@ -703,6 +707,8 @@ def save( preprocess_common_state_dict_fn=None, ) + if hashes is not None: + save_model_hashes(args, model, iteration, hashes) if should_disable_forward_pre_hook(args): enable_forward_pre_hook(model) @@ -799,6 +805,8 @@ def initialize_model_and_optimizer( ) clear_memory() + check_model_hashes(args, model, iteration) + opt_param_scheduler.step(increment=iteration * args.global_batch_size) return model, optimizer, opt_param_scheduler, iteration diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 14ae4ff7b8..1385cde788 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1433,6 +1433,14 @@ def add_ci_arguments(parser): type=str, default=None, ) + parser.add_argument( + "--ci-save-model-hash", + action="store_true", + ) + parser.add_argument( + "--ci-check-model-hash", + action="store_true", + ) return parser def add_user_provided_function_arguments(parser): diff --git a/tests/e2e/ckpt/test_glm47_flash_ckpt.py b/tests/e2e/ckpt/test_glm47_flash_ckpt.py new file mode 100644 index 0000000000..72ee0127d0 --- /dev/null +++ b/tests/e2e/ckpt/test_glm47_flash_ckpt.py @@ -0,0 +1,194 @@ +import os +from argparse import ArgumentParser + +import miles.utils.external_utils.command_utils as U + + +ENABLE_EVAL = bool(int(os.environ.get("MILES_TEST_ENABLE_EVAL", "1"))) +TIGHT_HOST_MEMORY = bool(int(os.environ.get("MILES_TEST_TIGHT_HOST_MEMORY", "1"))) +USE_DEEPEP = bool(int(os.environ.get("MILES_TEST_USE_DEEPEP", "0"))) + +MODEL_NAME = "GLM-4.7-Flash" +MODEL_TYPE = "glm4.7-flash" +NUM_GPUS = 8 + + +parser = ArgumentParser() +parser.add_argument("--async-save", action="store_true", help="Whether to test async save/load.") + + +def _get_latest_checkpointed_iteration() -> int: + latest_path = f"/root/models/{MODEL_NAME}_miles/latest_checkpointed_iteration.txt" + with open(latest_path, encoding="utf-8") as f: + latest_text = f.read().strip() + if not latest_text.isdigit(): + raise ValueError(f"Invalid latest checkpoint value: {latest_text}") + return int(latest_text) + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + # GLM-4.7-Flash requires a newer transformers version. + U.exec_command( + "pip install git+https://github.com/huggingface/transformers.git@" + "76732b4e7120808ff989edbd16401f61fa6a0afa --break-system-packages" + ) + U.exec_command(f"hf download zai-org/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"rm -rf /root/models/{MODEL_NAME}_miles") + U.hf_download_dataset("zhuzilin/dapo-math-17k") + U.hf_download_dataset("zhuzilin/aime-2024") + + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(mode: str = "", ckpt_step: int | None = None): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " + if mode == "save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + elif mode == "async_save": + ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " + ckpt_args += "--save-interval 2 " + ckpt_args += "--async-save " + ckpt_args += "--use-persistent-ckpt-worker " + elif mode == "load": + ckpt_args += f"--load /root/models/{MODEL_NAME}_miles " + ckpt_args += f"--ckpt-step {ckpt_step} " + + rollout_args = ( + "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type deepscaler " + "--num-rollout 3 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 1 " + "--global-batch-size 32 " + "--balance-data " + ) + + eval_args = "" + if ENABLE_EVAL: + eval_args = ( + "--eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 2048 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 4 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + f"--max-tokens-per-gpu {2048 if TIGHT_HOST_MEMORY else 32768} " + ) + + grpo_args = ( + "--advantage-estimator grpo " + f"{'' if TIGHT_HOST_MEMORY else '--use-kl-loss '}" + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-rollout-routing-replay " + "--use-miles-router " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 4 " + f"--sglang-mem-fraction-static {0.7 if TIGHT_HOST_MEMORY else 0.8} " + "--sglang-speculative-algorithm EAGLE " + "--sglang-speculative-num-steps 2 " + "--sglang-speculative-eagle-topk 1 " + "--sglang-speculative-num-draft-tokens 3 " + ) + + if USE_DEEPEP: + sglang_args += "--sglang-moe-a2a-backend deepep --sglang-deepep-mode auto " + + mtp_args = "--enable-mtp-training --mtp-loss-scaling-factor 0.2 " + + ci_args = "--ci-test " + if mode in {"save", "async_save"}: + ci_args += "--ci-save-model-hash " + if mode == "load": + ci_args += "--ci-check-model-hash " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + ) + + if USE_DEEPEP: + misc_args += "--moe-token-dispatcher-type flex --moe-enable-deepep " + else: + misc_args += "--moe-token-dispatcher-type alltoall " + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{mtp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + extra_env_vars={ + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", + "MILES_TEST_R3_THRESHOLD": "1.0", + }, + ) + + +if __name__ == "__main__": + args = parser.parse_args() + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute("save" if not args.async_save else "async_save") + latest_step = _get_latest_checkpointed_iteration() + execute("load", ckpt_step=latest_step) diff --git a/tests/e2e/ckpt/test_qwen3_4B_ckpt.py b/tests/e2e/ckpt/test_qwen3_4B_ckpt.py index 0df4492e10..e1aceb4067 100644 --- a/tests/e2e/ckpt/test_qwen3_4B_ckpt.py +++ b/tests/e2e/ckpt/test_qwen3_4B_ckpt.py @@ -16,6 +16,15 @@ parser.add_argument("--async-save", action="store_true", help="Whether to test async save/load.") +def _get_latest_checkpointed_iteration() -> int: + latest_path = f"/root/models/{MODEL_NAME}_miles/latest_checkpointed_iteration.txt" + with open(latest_path, encoding="utf-8") as f: + latest_text = f.read().strip() + if not latest_text.isdigit(): + raise ValueError(f"Invalid latest checkpoint value: {latest_text}") + return int(latest_text) + + def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") @@ -28,7 +37,7 @@ def prepare(): ) -def execute(mode: str = ""): +def execute(mode: str = "", ckpt_step: int | None = None): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /root/models/{MODEL_NAME}_torch_dist " if mode == "save": ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " @@ -37,9 +46,10 @@ def execute(mode: str = ""): ckpt_args += f"--save /root/models/{MODEL_NAME}_miles " ckpt_args += "--save-interval 2 " ckpt_args += "--async-save " + ckpt_args += "--use-persistent-ckpt-worker " elif mode == "load": ckpt_args += f"--load /root/models/{MODEL_NAME}_miles " - ckpt_args += "--ckpt-step 1 " + ckpt_args += f"--ckpt-step {ckpt_step} " rollout_args = ( "--prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl " @@ -93,6 +103,10 @@ def execute(mode: str = ""): sglang_args = "--rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.8 --sglang-cuda-graph-bs 1 2 4 8 16 " ci_args = "--ci-test " + if mode in {"save", "async_save"}: + ci_args += "--ci-save-model-hash " + if mode == "load": + ci_args += "--ci-check-model-hash " misc_args = ( # default dropout in megatron is 0.1 @@ -135,4 +149,5 @@ def execute(mode: str = ""): for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): os.environ.pop(proxy_var, None) execute("save" if not args.async_save else "async_save") - execute("load") + latest_step = _get_latest_checkpointed_iteration() + execute("load", ckpt_step=latest_step) From 91ac19e1626df47580561723b568168546745056 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Sat, 28 Feb 2026 02:09:37 +0000 Subject: [PATCH 04/16] update pr --- .github/workflows/pr-test.yml | 53 +++++++++++++++++++++++++++++--- .github/workflows/pr-test.yml.j2 | 2 ++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 915f4399d9..4e28059cbd 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -16,6 +16,8 @@ + + name: PR Test on: @@ -806,7 +808,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}] + info: [{"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}] defaults: run: working-directory: ${{ github.workspace }} @@ -1045,11 +1047,54 @@ jobs: rm -rf /tmp/ray/* 2>/dev/null || true sleep 3 + - name: Resolve dependency refs + id: resolve-refs + shell: bash + env: + PR_BODY: ${{ github.event.pull_request.body || '' }} + INPUT_MEGATRON_PR: ${{ github.event.inputs.ci_megatron_pr || '' }} + INPUT_SGLANG_PR: ${{ github.event.inputs.ci_sglang_pr || '' }} + run: | + # Priority: workflow_dispatch input > PR description > default + MEGATRON_PR="${INPUT_MEGATRON_PR}" + SGLANG_PR="${INPUT_SGLANG_PR}" + + # Parse PR description for "ci-megatron-pr:" and "ci-sglang-pr:" + if [ -n "$PR_BODY" ]; then + PR_MEGATRON_PR=$(echo "$PR_BODY" | grep -oP '(?<=ci-megatron-pr:\s)\S+' || true) + PR_SGLANG_PR=$(echo "$PR_BODY" | grep -oP '(?<=ci-sglang-pr:\s)\S+' || true) + [ -z "$MEGATRON_PR" ] && [ -n "$PR_MEGATRON_PR" ] && MEGATRON_PR="$PR_MEGATRON_PR" + [ -z "$SGLANG_PR" ] && [ -n "$PR_SGLANG_PR" ] && SGLANG_PR="$PR_SGLANG_PR" + fi + + # Defaults + [ -z "$MEGATRON_PR" ] && MEGATRON_PR="miles-main" + [ -z "$SGLANG_PR" ] && SGLANG_PR="sglang-miles" + + # Convert "#N" PR syntax to git fetch ref: "pull/N/head" + resolve_fetch_ref() { + local ref="$1" + if [[ "$ref" =~ ^#([0-9]+)$ ]]; then + echo "pull/${BASH_REMATCH[1]}/head" + else + echo "$ref" + fi + } + MEGATRON_FETCH=$(resolve_fetch_ref "$MEGATRON_PR") + SGLANG_FETCH=$(resolve_fetch_ref "$SGLANG_PR") + + echo "ci_megatron_pr=$MEGATRON_FETCH" >> $GITHUB_OUTPUT + echo "ci_sglang_pr=$SGLANG_FETCH" >> $GITHUB_OUTPUT + echo "Resolved: megatron=$MEGATRON_PR -> fetch=$MEGATRON_FETCH, sglang=$SGLANG_PR -> fetch=$SGLANG_FETCH" + - name: Install shell: bash + env: + MEGATRON_PR: ${{ steps.resolve-refs.outputs.ci_megatron_pr }} + SGLANG_PR: ${{ steps.resolve-refs.outputs.ci_sglang_pr }} run: | - cd /sgl-workspace/sglang && git fetch origin sglang-miles && git checkout FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages - cd /root/Megatron-LM && git reset --hard HEAD && git log --oneline -1 && git apply $GITHUB_WORKSPACE/docker/patch/dev/megatron.patch && pip install -e . --no-deps --break-system-packages + cd /sgl-workspace/sglang && git reset --hard HEAD && git clean -fd && git fetch origin "$SGLANG_PR" && git checkout -f FETCH_HEAD && git log --oneline -1 && pip install -e python --no-deps --break-system-packages + cd /root/Megatron-LM && git reset --hard HEAD && git clean -fd && git fetch origin "$MEGATRON_PR" && git checkout -f FETCH_HEAD && git log --oneline -1 && pip install -e . --no-deps --break-system-packages cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages - name: Execute @@ -1078,7 +1123,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}] + info: [{"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_vl_4B_fsdp.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_fsdp_distributed.py"}, {"num_gpus": 8, "test_file": "e2e/fsdp/test_qwen3_0.6B_megatron_fsdp_align.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_30B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "e2e/megatron/test_mimo_7B_mtp_only_grad.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "e2e/megatron/test_glm47_flash_r3_mtp.py"}, {"num_gpus": 8, "test_file": "e2e/lora/test_lora_qwen2.5_0.5B.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 8, "test_file": "e2e/short/test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 8, "test_file": "e2e/precision/test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_qwen3_4B_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py"}, {"num_gpus": 8, "test_file": "e2e/ckpt/test_glm47_flash_ckpt.py --async-save"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 8, "test_file": "e2e/long/test_qwen2.5_0.5B_gsm8k_async.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index b07b7ee007..948b00b466 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -32,6 +32,8 @@ <% set ckpt_tests = [ {'test_file': 'e2e/ckpt/test_qwen3_4B_ckpt.py', 'num_gpus': 8}, {'test_file': 'e2e/ckpt/test_qwen3_4B_ckpt.py --async-save', 'num_gpus': 8}, + {'test_file': 'e2e/ckpt/test_glm47_flash_ckpt.py', 'num_gpus': 8}, + {'test_file': 'e2e/ckpt/test_glm47_flash_ckpt.py --async-save', 'num_gpus': 8}, ] %> <% set long_tests = [ From 2b7d667c6d7cfa5dc2c49e7913e0ab65aa849cc0 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Sat, 28 Feb 2026 03:38:32 +0000 Subject: [PATCH 05/16] fix lora test --- .../backends/megatron_utils/test_lora_model_branches.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/fast/backends/megatron_utils/test_lora_model_branches.py b/tests/fast/backends/megatron_utils/test_lora_model_branches.py index dc555bbdbd..496dbe5dae 100644 --- a/tests/fast/backends/megatron_utils/test_lora_model_branches.py +++ b/tests/fast/backends/megatron_utils/test_lora_model_branches.py @@ -170,6 +170,7 @@ def test_lora_raw_mode_skips_bridge(self, mock_lora_setup, mock_get_model, mock_ class TestSaveLoRaBranch: + @patch(f"{_MODEL_MODULE}.save_model_hashes") @patch(f"{_MODEL_MODULE}.enable_forward_pre_hook") @patch(f"{_MODEL_MODULE}.disable_forward_pre_hook") @patch(f"{_MODEL_MODULE}.should_disable_forward_pre_hook", return_value=False) @@ -177,7 +178,7 @@ class TestSaveLoRaBranch: @patch(f"{_MODEL_MODULE}.save_checkpoint_with_lora") @patch(f"{_MODEL_MODULE}.is_lora_model", return_value=True) def test_lora_model_calls_lora_save( - self, mock_is_lora, mock_save_lora, mock_get_args, mock_should, mock_disable, mock_enable + self, mock_is_lora, mock_save_lora, mock_get_args, mock_should, mock_disable, mock_enable, mock_save_hashes ): from miles.backends.megatron_utils.model import save @@ -186,6 +187,7 @@ def test_lora_model_calls_lora_save( mock_save_lora.assert_called_once() + @patch(f"{_MODEL_MODULE}.save_model_hashes") @patch(f"{_MODEL_MODULE}.enable_forward_pre_hook") @patch(f"{_MODEL_MODULE}.disable_forward_pre_hook") @patch(f"{_MODEL_MODULE}.should_disable_forward_pre_hook", return_value=False) @@ -193,7 +195,7 @@ def test_lora_model_calls_lora_save( @patch(f"{_MODEL_MODULE}.save_checkpoint") @patch(f"{_MODEL_MODULE}.is_lora_model", return_value=False) def test_non_lora_model_calls_regular_save( - self, mock_is_lora, mock_save_ckpt, mock_get_args, mock_should, mock_disable, mock_enable + self, mock_is_lora, mock_save_ckpt, mock_get_args, mock_should, mock_disable, mock_enable, mock_save_hashes ): from miles.backends.megatron_utils.model import save From ba2752427a3d0ec28a43628d1af8a67868420cae Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 25 Feb 2026 23:53:35 +0000 Subject: [PATCH 06/16] refactor: use partition_stride for TP all-gather instead of hard-coded fc1/fc2 logic Megatron-LM PR #2708 fixed partition_stride to correctly report stride=2 for linear_fc1 (GLU/SwiGLU interleaved [gate, up]) and stride=1 for linear_fc2. Replace the old hard-coded fc1 chunk reordering and fc2 partition_dim workaround with generic stride-aware gathering. Add _check_partition_stride() asserts to validate expected stride values for linear_fc1 (must be 2) and linear_fc2 (must be 1). Co-Authored-By: Claude Opus 4.6 --- .../megatron_utils/update_weight/common.py | 66 +++++++++++-------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/miles/backends/megatron_utils/update_weight/common.py b/miles/backends/megatron_utils/update_weight/common.py index 85fe76a1b8..a1aca6c5f6 100644 --- a/miles/backends/megatron_utils/update_weight/common.py +++ b/miles/backends/megatron_utils/update_weight/common.py @@ -1,4 +1,5 @@ import inspect +import logging import re from argparse import Namespace from collections.abc import Iterator, Sequence @@ -11,11 +12,37 @@ from miles.backends.megatron_utils.misc_utils import strip_param_name_prefix from miles.utils.types import ParamInfo +logger = logging.getLogger(__name__) + + +def _gather_with_stride( + param_partitions: list[torch.Tensor], partition_dim: int, partition_stride: int +) -> torch.Tensor: + """Gather partitions respecting partition_stride (strided/interleaved TP sharding).""" + if partition_stride == 1: + return torch.cat(param_partitions, dim=partition_dim) + # Interleaved (strided) partitioning, e.g. linear_fc1.weight under GLU/SwiGLU + chunks_per_rank = [p.chunk(partition_stride, dim=partition_dim) for p in param_partitions] + interleaved = [chunks_per_rank[r][s] for s in range(partition_stride) for r in range(len(param_partitions))] + return torch.cat(interleaved, dim=partition_dim) + + +def _check_partition_stride(name: str, partition_stride: int) -> None: + """Validate partition_stride values for known parameter patterns. + + After Megatron-LM PR #2708, linear_fc1 correctly reports partition_stride=2 + (GLU/SwiGLU interleaved [gate, up]) and linear_fc2 reports partition_stride=1. + """ + if "linear_fc1.weight" in name: + assert partition_stride == 2, f"Expected partition_stride=2 for {name} (GLU/SwiGLU), got {partition_stride}" + elif "linear_fc2.weight" in name: + assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}" + def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: """ All-gather TP-sharded param to full tensor. expert_bias→param, non-TP/duplicated→param.data. - Uses expert-TP for ".experts.", else regular-TP. linear_fc1 rechunked (GLU), linear_fc2 dim fix. + Uses expert-TP for ".experts.", else regular-TP. Handles strided partitioning via partition_stride. """ if "expert_bias" in name: return param @@ -34,17 +61,10 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] dist.all_gather(param_partitions, param.data, group=tp_group) partition_dim = param.partition_dim - assert param.partition_stride == 1, "partition_stride != 1 is not supported" - # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? - # TODO: check only GLU is used. - if "linear_fc1.weight" in name: - param_partitions = [p.chunk(2, dim=0) for p in param_partitions] - param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] - # this is bug in megatron's grouped moe. - if "linear_fc2.weight" in name: - if partition_dim == 0: - partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + partition_stride = param.partition_stride + + _check_partition_stride(name, partition_stride) + param = _gather_with_stride(param_partitions, partition_dim, partition_stride) return param @@ -63,10 +83,10 @@ def all_gather_params_async( for info, param in param_infos_and_params: # Prepare async all_gather if "expert_bias" in info.name: - gather_tasks.append((info, param, None, None, None)) + gather_tasks.append((info, param, None, None, None, None)) handles.append(None) elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated": - gather_tasks.append((info, param.data, None, None, None)) + gather_tasks.append((info, param.data, None, None, None, None)) handles.append(None) else: # Start async all_gather @@ -79,7 +99,7 @@ def all_gather_params_async( param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] handle = dist.all_gather(param_partitions, param.data, group=tp_group, async_op=True) - gather_tasks.append((info, None, handle, param_partitions, param.partition_dim)) + gather_tasks.append((info, None, handle, param_partitions, param.partition_dim, param.partition_stride)) handles.append(handle) # Phase 2: Wait for ALL async operations to complete at once @@ -90,23 +110,13 @@ def all_gather_params_async( # Phase 3: Process all results after all communications are done gathered_params = [] - for info, direct_param, handle, param_partitions, partition_dim in gather_tasks: + for info, direct_param, handle, param_partitions, partition_dim, partition_stride in gather_tasks: if handle is None: # No all_gather needed param = direct_param else: - # Process the gathered partitions (same logic as original all_gather_param) - assert partition_dim is not None, "partition_stride != 1 is not supported" - # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? - # TODO: check only GLU is used. - if "linear_fc1.weight" in info.name: - param_partitions = [p.chunk(2, dim=0) for p in param_partitions] - param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] - # this is bug in megatron's grouped moe. - if "linear_fc2.weight" in info.name: - if partition_dim == 0: - partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + _check_partition_stride(info.name, partition_stride) + param = _gather_with_stride(param_partitions, partition_dim, partition_stride) gathered_params.append(param) From a82fa838ab3c00c3e50f58a817f25c9d1d55abdd Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 25 Feb 2026 23:55:23 +0000 Subject: [PATCH 07/16] fix: update norm_epsilon to layernorm_epsilon for new Megatron param name Megatron-LM renamed the config field from norm_epsilon to layernorm_epsilon. Update the HF config validation mapping accordingly. Co-Authored-By: Claude Opus 4.6 --- miles/utils/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 1385cde788..bc214da434 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1865,7 +1865,7 @@ def equal(x, y): ("num_hidden_layers", "num_layers", equal), ("intermediate_size", "ffn_hidden_size", equal), ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y), - ("rms_norm_eps", "norm_epsilon", equal), + ("rms_norm_eps", "layernorm_epsilon", equal), ("rope_theta", "rotary_base", equal), ]: if hasattr(hf_config, hf_config_name): From 87d452a8fecb3f56fa579fafcab56d1e42581994 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Thu, 26 Feb 2026 01:36:28 +0000 Subject: [PATCH 08/16] skip megatron build args --- .../backends/megatron_utils/model_provider.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 5f54503979..0e70fa0fd0 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -63,8 +63,13 @@ def get_model_provider_func( if getattr(args, "custom_model_provider_path", None): def wrapped_model_provider( - pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None + pre_process: bool = True, + post_process: bool = True, + vp_stage: int | None = None, + config: TransformerConfig | None = None, + pg_collection=None, ) -> GPTModel: + assert config is None, "miles builds the config from args, so it expects config to be None" custom_model_provider = load_function(args.custom_model_provider_path) # Check if the custom provider supports vp_stage parameter has_vp_stage = "vp_stage" in inspect.signature(custom_model_provider).parameters @@ -93,9 +98,26 @@ def wrapped_model_provider( provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel provider.finalize() - return provider.provide - def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: + def wrapped_bridge_provider( + pre_process: bool = True, + post_process: bool = True, + vp_stage: int | None = None, + config: TransformerConfig | None = None, + pg_collection=None, + ) -> GPTModel: + assert config is None, "miles builds the config from args, so it expects config to be None" + return provider.provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + + return wrapped_bridge_provider + + def model_provider( + pre_process: bool = True, + post_process: bool = True, + vp_stage: int | None = None, + config: TransformerConfig | None = None, + pg_collection=None, + ) -> GPTModel: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. @@ -111,7 +133,8 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage use_te = args.transformer_impl == "transformer_engine" # Experimental loading arguments from yaml - config: TransformerConfig = core_transformer_config_from_args(args) + assert config is None, "miles builds the config from args, so it expects config to be None" + config = core_transformer_config_from_args(args) if args.spec is not None: transformer_layer_spec = import_module(args.spec) From d88d3f5fb83a27299800e7bb1b7bbf3f48521108 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Thu, 26 Feb 2026 05:11:31 +0000 Subject: [PATCH 09/16] update partition_stride checker impl --- .../megatron_utils/update_weight/common.py | 18 +++++++++++++----- .../update_weight/hf_weight_iterator_direct.py | 7 +++++-- .../update_weight_from_distributed.py | 4 ++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/miles/backends/megatron_utils/update_weight/common.py b/miles/backends/megatron_utils/update_weight/common.py index a1aca6c5f6..a0ba566812 100644 --- a/miles/backends/megatron_utils/update_weight/common.py +++ b/miles/backends/megatron_utils/update_weight/common.py @@ -27,19 +27,26 @@ def _gather_with_stride( return torch.cat(interleaved, dim=partition_dim) -def _check_partition_stride(name: str, partition_stride: int) -> None: +def _check_partition_stride(args: Namespace, name: str, partition_stride: int) -> int: """Validate partition_stride values for known parameter patterns. After Megatron-LM PR #2708, linear_fc1 correctly reports partition_stride=2 (GLU/SwiGLU interleaved [gate, up]) and linear_fc2 reports partition_stride=1. """ if "linear_fc1.weight" in name: - assert partition_stride == 2, f"Expected partition_stride=2 for {name} (GLU/SwiGLU), got {partition_stride}" + if args.moe_grouped_gemm and partition_stride != 2 and args.swiglu: + # Megatron bug: TEGroupedLinaer does not set partition_stride=2 for linear_fc1 + partition_stride = 2 + if args.swiglu: + assert ( + partition_stride == 2 + ), f"Expected partition_stride=2 for {name} (GLU/SwiGLU), got {partition_stride}" elif "linear_fc2.weight" in name: assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}" + return partition_stride -def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: +def all_gather_param(args: Namespace, name: str, param: torch.nn.Parameter) -> torch.Tensor: """ All-gather TP-sharded param to full tensor. expert_bias→param, non-TP/duplicated→param.data. Uses expert-TP for ".experts.", else regular-TP. Handles strided partitioning via partition_stride. @@ -63,12 +70,13 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: partition_dim = param.partition_dim partition_stride = param.partition_stride - _check_partition_stride(name, partition_stride) + partition_stride = _check_partition_stride(args, name, partition_stride) param = _gather_with_stride(param_partitions, partition_dim, partition_stride) return param def all_gather_params_async( + args: Namespace, param_infos_and_params: list[tuple[ParamInfo, torch.Tensor]], ) -> list[torch.Tensor]: """ @@ -115,7 +123,7 @@ def all_gather_params_async( # No all_gather needed param = direct_param else: - _check_partition_stride(info.name, partition_stride) + partition_stride = _check_partition_stride(args, info.name, partition_stride) param = _gather_with_stride(param_partitions, partition_dim, partition_stride) gathered_params.append(param) diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py index af2250dc1b..ecdba3c8c7 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py @@ -27,7 +27,9 @@ def get_hf_weight_chunks(self, megatron_local_weights): for megatron_local_param_infos in tqdm( self.megatron_local_param_info_buckets, disable=rank != 0, desc="Update weights" ): - megatron_full_params = _get_megatron_full_params(megatron_local_param_infos, megatron_local_weights) + megatron_full_params = _get_megatron_full_params( + self.args, megatron_local_param_infos, megatron_local_weights + ) hf_named_tensors = self._convert_to_hf_named_tensors(megatron_full_params, megatron_local_param_infos) yield hf_named_tensors del megatron_full_params @@ -42,6 +44,7 @@ def _convert_to_hf_named_tensors(self, megatron_full_params: Sequence[torch.Tens def _get_megatron_full_params( + args: Namespace, megatron_local_param_infos: Sequence[ParamInfo], megatron_local_weights, ) -> Sequence[torch.Tensor]: @@ -100,7 +103,7 @@ def _get_megatron_full_params( setattr(param, key, value) # Batch async all_gather for all parameters - gathered_params = all_gather_params_async(list(zip(megatron_local_param_infos, params, strict=False))) + gathered_params = all_gather_params_async(args, list(zip(megatron_local_param_infos, params, strict=False))) return gathered_params diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index caf6ae54f1..c166b195f4 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -147,7 +147,7 @@ def _update_weight_from_distributed( Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers. Returns updated bytes on source, None on non-source. """ - param = all_gather_param(name, param) + param = all_gather_param(self.args, name, param) if not self._is_pp_src_rank: return @@ -170,7 +170,7 @@ def _update_expert_weight_from_distributed( """ Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size. """ - param = all_gather_param(name, param) + param = all_gather_param(self.args, name, param) param_size = param.numel() * param.element_size() if ( From 3b7d5eb2e24a2b06346605f5deaea1ba484b448c Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Thu, 26 Feb 2026 05:48:05 +0000 Subject: [PATCH 10/16] fix: add missing args parameter to all_gather_param call in convert_to_hf Co-Authored-By: Claude Opus 4.6 --- tools/convert_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/convert_to_hf.py b/tools/convert_to_hf.py index b84d68e665..a383473046 100644 --- a/tools/convert_to_hf.py +++ b/tools/convert_to_hf.py @@ -76,7 +76,7 @@ def main(args): for key, value in info.attrs.items(): setattr(param, key, value) - param = update_weight_utils.all_gather_param(info.name, param) + param = update_weight_utils.all_gather_param(args, info.name, param) param = update_weight_utils.remove_padding(info.name, param, vocab_size) # use torch.distributed if is_save_rank: From a29165cd65223da8a58d1038cbfa3117d0a11335 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Fri, 27 Feb 2026 03:56:28 +0000 Subject: [PATCH 11/16] add new fast test for megatron transformer input --- tests/fast/test_megatron_cli_flags.py | 50 +++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/fast/test_megatron_cli_flags.py diff --git a/tests/fast/test_megatron_cli_flags.py b/tests/fast/test_megatron_cli_flags.py new file mode 100644 index 0000000000..3b58eafcb4 --- /dev/null +++ b/tests/fast/test_megatron_cli_flags.py @@ -0,0 +1,50 @@ +import sys + +import pytest + + +def test_post_layernorm_flags_propagate_to_megatron(monkeypatch): + pytest.importorskip("megatron.training.arguments") + + import torch + from megatron.training.arguments import core_transformer_config_from_args + + import miles.backends.megatron_utils.arguments as megatron_arguments + import miles.utils.arguments as miles_arguments + + monkeypatch.setattr(miles_arguments, "miles_validate_args", lambda args: None) + monkeypatch.setattr(megatron_arguments, "validate_args", lambda args: None) + + argv = [ + "pytest", + "--train-backend", + "megatron", + "--rollout-batch-size", + "1", + "--num-layers", + "1", + "--hidden-size", + "8", + "--num-attention-heads", + "1", + "--post-self-attn-layernorm", + "--post-mlp-layernorm", + ] + monkeypatch.setattr(sys, "argv", argv) + + args = miles_arguments.parse_args() + + assert args.post_self_attn_layernorm is True + assert args.post_mlp_layernorm is True + + if args.bf16: + args.params_dtype = torch.bfloat16 + elif args.fp16: + args.params_dtype = torch.float16 + else: + args.params_dtype = torch.float32 + + config = core_transformer_config_from_args(args) + + assert config.post_self_attn_layernorm is True + assert config.post_mlp_layernorm is True From 0ccaa995e05d22d0ca664e9b12cb5743cdc59a6d Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 2 Mar 2026 02:55:35 +0000 Subject: [PATCH 12/16] add fall backs --- .../megatron_utils/update_weight/common.py | 27 ++++++++++--------- miles/utils/arguments.py | 6 ++++- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/miles/backends/megatron_utils/update_weight/common.py b/miles/backends/megatron_utils/update_weight/common.py index a0ba566812..c729a36e68 100644 --- a/miles/backends/megatron_utils/update_weight/common.py +++ b/miles/backends/megatron_utils/update_weight/common.py @@ -27,23 +27,22 @@ def _gather_with_stride( return torch.cat(interleaved, dim=partition_dim) -def _check_partition_stride(args: Namespace, name: str, partition_stride: int) -> int: +def _check_and_fix_partition(args: Namespace, name: str, partition_stride: int, partition_dim: int) -> tuple[int, int]: """Validate partition_stride values for known parameter patterns. After Megatron-LM PR #2708, linear_fc1 correctly reports partition_stride=2 - (GLU/SwiGLU interleaved [gate, up]) and linear_fc2 reports partition_stride=1. + (GLU/SwiGLU interleaved [gate, up]), so assert partition_stride==2 is removed. + But TEGroupedLinear still does not set partition_stride/partition_dim correctly for grouped moe gemm """ - if "linear_fc1.weight" in name: - if args.moe_grouped_gemm and partition_stride != 2 and args.swiglu: - # Megatron bug: TEGroupedLinaer does not set partition_stride=2 for linear_fc1 - partition_stride = 2 - if args.swiglu: - assert ( - partition_stride == 2 - ), f"Expected partition_stride=2 for {name} (GLU/SwiGLU), got {partition_stride}" + if "linear_fc1.weight" in name and args.swiglu: + partition_stride = 2 elif "linear_fc2.weight" in name: assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}" - return partition_stride + if partition_dim == 0: + partition_dim = 1 + else: + assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}" + return partition_stride, partition_dim def all_gather_param(args: Namespace, name: str, param: torch.nn.Parameter) -> torch.Tensor: @@ -70,7 +69,7 @@ def all_gather_param(args: Namespace, name: str, param: torch.nn.Parameter) -> t partition_dim = param.partition_dim partition_stride = param.partition_stride - partition_stride = _check_partition_stride(args, name, partition_stride) + partition_stride, partition_dim = _check_and_fix_partition(args, name, partition_stride, partition_dim) param = _gather_with_stride(param_partitions, partition_dim, partition_stride) return param @@ -123,7 +122,9 @@ def all_gather_params_async( # No all_gather needed param = direct_param else: - partition_stride = _check_partition_stride(args, info.name, partition_stride) + partition_stride, partition_dim = _check_and_fix_partition( + args, info.name, partition_stride, partition_dim + ) param = _gather_with_stride(param_partitions, partition_dim, partition_stride) gathered_params.append(param) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index bc214da434..3d4d349c18 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1865,7 +1865,11 @@ def equal(x, y): ("num_hidden_layers", "num_layers", equal), ("intermediate_size", "ffn_hidden_size", equal), ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y), - ("rms_norm_eps", "layernorm_epsilon", equal), + ( + "rms_norm_eps", + "norm_epsilon" if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1" else "layernorm_epsilon", + equal, + ), ("rope_theta", "rotary_base", equal), ]: if hasattr(hf_config, hf_config_name): From e30f37c79696965d6477d004d79ef0c1aff746f0 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 2 Mar 2026 03:44:22 +0000 Subject: [PATCH 13/16] use dp reshardable in arguments --- miles/backends/megatron_utils/arguments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index 24496011b1..e533256182 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -17,8 +17,8 @@ def set_default_megatron_args(args): if args.seq_length is None: args.seq_length = 4096 args.max_position_embeddings = args.seq_length - # TODO: revisit this when megatron(dev) have solved the optimizer-cpu-offload ckpt saving bug - args.dist_ckpt_save_pre_mcore_014 = True + # Notice(Jiajun): new megatron has removed this argument and use dp_reshardable instead of fully_shard + # args.dist_ckpt_save_pre_mcore_014 = True # compatible for megatron if hasattr(args, "rope_type") and args.rope_type is None: args.rope_type = "yarn" if args.multi_latent_attention else "rope" From 8db6954f3376384b499025db2d74d32df0526ec4 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:42:32 -0800 Subject: [PATCH 14/16] Update arguments.py --- miles/backends/megatron_utils/arguments.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index e533256182..e1c7ba8f5d 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -1,4 +1,5 @@ import logging +import os from megatron.training.arguments import parse_args, validate_args from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding @@ -18,7 +19,8 @@ def set_default_megatron_args(args): args.seq_length = 4096 args.max_position_embeddings = args.seq_length # Notice(Jiajun): new megatron has removed this argument and use dp_reshardable instead of fully_shard - # args.dist_ckpt_save_pre_mcore_014 = True + if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1": + args.dist_ckpt_save_pre_mcore_014 = True # compatible for megatron if hasattr(args, "rope_type") and args.rope_type is None: args.rope_type = "yarn" if args.multi_latent_attention else "rope" From 15eef51939fef3b573cd563c0e19931c0f1a1675 Mon Sep 17 00:00:00 2001 From: Jiajun Li <48857426+guapisolo@users.noreply.github.com> Date: Wed, 4 Mar 2026 14:38:41 -0800 Subject: [PATCH 15/16] add EOF space --- .github/workflows/pr-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index fd8df2f335..ff599964da 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -1290,4 +1290,4 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} \ No newline at end of file + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} From e8854fcb248da16e23601e188bfe43f726fbebbb Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 4 Mar 2026 23:00:50 +0000 Subject: [PATCH 16/16] upgrade fla==0.4.1 to fix qwen next bugs on b200 --- docker/Dockerfile.dev | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 4a62cceb36..4eae94528a 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -54,7 +54,7 @@ RUN pip install /tmp/wheels/flash_attn_3-*.whl && \ RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -RUN pip install flash-linear-attention==0.4.0 +RUN pip install flash-linear-attention==0.4.1 RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \