Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ RUN pip install /tmp/wheels/flash_attn_3-*.whl && \

RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps

RUN pip install flash-linear-attention==0.4.0
RUN pip install flash-linear-attention==0.4.1
RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/

RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \
Expand Down
6 changes: 4 additions & 2 deletions miles/backends/megatron_utils/arguments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

from megatron.training.arguments import parse_args, validate_args
from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
Expand All @@ -17,8 +18,9 @@ def set_default_megatron_args(args):
if args.seq_length is None:
args.seq_length = 4096
args.max_position_embeddings = args.seq_length
# TODO: revisit this when megatron(dev) have solved the optimizer-cpu-offload ckpt saving bug
args.dist_ckpt_save_pre_mcore_014 = True
# NOTE(Jiajun): newer Megatron has removed this argument and uses dp_reshardable instead of fully_shard
if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1":
args.dist_ckpt_save_pre_mcore_014 = True
# compatible for megatron
if hasattr(args, "rope_type") and args.rope_type is None:
args.rope_type = "yarn" if args.multi_latent_attention else "rope"
Expand Down
31 changes: 27 additions & 4 deletions miles/backends/megatron_utils/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ def get_model_provider_func(
if getattr(args, "custom_model_provider_path", None):

def wrapped_model_provider(
pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None
pre_process: bool = True,
post_process: bool = True,
vp_stage: int | None = None,
config: TransformerConfig | None = None,
pg_collection=None,
) -> GPTModel:
assert config is None, "miles builds the config from args, so it expects config to be None"
custom_model_provider = load_function(args.custom_model_provider_path)
# Check if the custom provider supports vp_stage parameter
has_vp_stage = "vp_stage" in inspect.signature(custom_model_provider).parameters
Expand Down Expand Up @@ -93,9 +98,26 @@ def wrapped_model_provider(
provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size
provider.sequence_parallel = args.sequence_parallel
provider.finalize()
return provider.provide

def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel:
def wrapped_bridge_provider(
pre_process: bool = True,
post_process: bool = True,
vp_stage: int | None = None,
config: TransformerConfig | None = None,
pg_collection=None,
) -> GPTModel:
assert config is None, "miles builds the config from args, so it expects config to be None"
return provider.provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)

return wrapped_bridge_provider

def model_provider(
pre_process: bool = True,
post_process: bool = True,
vp_stage: int | None = None,
config: TransformerConfig | None = None,
pg_collection=None,
) -> GPTModel:
"""Builds the model.

If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
Expand All @@ -111,7 +133,8 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage
use_te = args.transformer_impl == "transformer_engine"

# Experimental loading arguments from yaml
config: TransformerConfig = core_transformer_config_from_args(args)
assert config is None, "miles builds the config from args, so it expects config to be None"
config = core_transformer_config_from_args(args)

if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
Expand Down
77 changes: 48 additions & 29 deletions miles/backends/megatron_utils/update_weight/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import logging
import re
from argparse import Namespace
from collections.abc import Iterator, Sequence
Expand All @@ -11,11 +12,43 @@
from miles.backends.megatron_utils.misc_utils import strip_param_name_prefix
from miles.utils.types import ParamInfo

logger = logging.getLogger(__name__)

def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor:

def _gather_with_stride(
param_partitions: list[torch.Tensor], partition_dim: int, partition_stride: int
) -> torch.Tensor:
"""Gather partitions respecting partition_stride (strided/interleaved TP sharding)."""
if partition_stride == 1:
return torch.cat(param_partitions, dim=partition_dim)
# Interleaved (strided) partitioning, e.g. linear_fc1.weight under GLU/SwiGLU
chunks_per_rank = [p.chunk(partition_stride, dim=partition_dim) for p in param_partitions]
interleaved = [chunks_per_rank[r][s] for s in range(partition_stride) for r in range(len(param_partitions))]
return torch.cat(interleaved, dim=partition_dim)


def _check_and_fix_partition(args: Namespace, name: str, partition_stride: int, partition_dim: int) -> tuple[int, int]:
"""Validate partition_stride values for known parameter patterns.

After Megatron-LM PR #2708, linear_fc1 correctly reports partition_stride=2
(GLU/SwiGLU interleaved [gate, up]), so assert partition_stride==2 is removed.
But TEGroupedLinear still does not set partition_stride/partition_dim correctly for grouped moe gemm
"""
if "linear_fc1.weight" in name and args.swiglu:
partition_stride = 2
elif "linear_fc2.weight" in name:
assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}"
if partition_dim == 0:
partition_dim = 1
else:
assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}"
return partition_stride, partition_dim


def all_gather_param(args: Namespace, name: str, param: torch.nn.Parameter) -> torch.Tensor:
"""
All-gather TP-sharded param to full tensor. expert_bias→param, non-TP/duplicated→param.data.
Uses expert-TP for ".experts.", else regular-TP. linear_fc1 rechunked (GLU), linear_fc2 dim fix.
Uses expert-TP for ".experts.", else regular-TP. Handles strided partitioning via partition_stride.
"""
if "expert_bias" in name:
return param
Expand All @@ -34,21 +67,15 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor:
param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)]
dist.all_gather(param_partitions, param.data, group=tp_group)
partition_dim = param.partition_dim
assert param.partition_stride == 1, "partition_stride != 1 is not supported"
# TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better?
# TODO: check only GLU is used.
if "linear_fc1.weight" in name:
param_partitions = [p.chunk(2, dim=0) for p in param_partitions]
param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions]
# this is bug in megatron's grouped moe.
if "linear_fc2.weight" in name:
if partition_dim == 0:
partition_dim = 1
param = torch.cat(param_partitions, dim=partition_dim)
partition_stride = param.partition_stride

partition_stride, partition_dim = _check_and_fix_partition(args, name, partition_stride, partition_dim)
param = _gather_with_stride(param_partitions, partition_dim, partition_stride)
return param


def all_gather_params_async(
args: Namespace,
param_infos_and_params: list[tuple[ParamInfo, torch.Tensor]],
) -> list[torch.Tensor]:
"""
Expand All @@ -63,10 +90,10 @@ def all_gather_params_async(
for info, param in param_infos_and_params:
# Prepare async all_gather
if "expert_bias" in info.name:
gather_tasks.append((info, param, None, None, None))
gather_tasks.append((info, param, None, None, None, None))
handles.append(None)
elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated":
gather_tasks.append((info, param.data, None, None, None))
gather_tasks.append((info, param.data, None, None, None, None))
handles.append(None)
else:
# Start async all_gather
Expand All @@ -79,7 +106,7 @@ def all_gather_params_async(

param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)]
handle = dist.all_gather(param_partitions, param.data, group=tp_group, async_op=True)
gather_tasks.append((info, None, handle, param_partitions, param.partition_dim))
gather_tasks.append((info, None, handle, param_partitions, param.partition_dim, param.partition_stride))
handles.append(handle)

# Phase 2: Wait for ALL async operations to complete at once
Expand All @@ -90,23 +117,15 @@ def all_gather_params_async(

# Phase 3: Process all results after all communications are done
gathered_params = []
for info, direct_param, handle, param_partitions, partition_dim in gather_tasks:
for info, direct_param, handle, param_partitions, partition_dim, partition_stride in gather_tasks:
if handle is None:
# No all_gather needed
param = direct_param
else:
# Process the gathered partitions (same logic as original all_gather_param)
assert partition_dim is not None, "partition_stride != 1 is not supported"
# TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better?
# TODO: check only GLU is used.
if "linear_fc1.weight" in info.name:
param_partitions = [p.chunk(2, dim=0) for p in param_partitions]
param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions]
# this is bug in megatron's grouped moe.
if "linear_fc2.weight" in info.name:
if partition_dim == 0:
partition_dim = 1
param = torch.cat(param_partitions, dim=partition_dim)
partition_stride, partition_dim = _check_and_fix_partition(
args, info.name, partition_stride, partition_dim
)
param = _gather_with_stride(param_partitions, partition_dim, partition_stride)

gathered_params.append(param)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def get_hf_weight_chunks(self, megatron_local_weights):
for megatron_local_param_infos in tqdm(
self.megatron_local_param_info_buckets, disable=rank != 0, desc="Update weights"
):
megatron_full_params = _get_megatron_full_params(megatron_local_param_infos, megatron_local_weights)
megatron_full_params = _get_megatron_full_params(
self.args, megatron_local_param_infos, megatron_local_weights
)
hf_named_tensors = self._convert_to_hf_named_tensors(megatron_full_params, megatron_local_param_infos)
yield hf_named_tensors
del megatron_full_params
Expand All @@ -42,6 +44,7 @@ def _convert_to_hf_named_tensors(self, megatron_full_params: Sequence[torch.Tens


def _get_megatron_full_params(
args: Namespace,
megatron_local_param_infos: Sequence[ParamInfo],
megatron_local_weights,
) -> Sequence[torch.Tensor]:
Expand Down Expand Up @@ -100,7 +103,7 @@ def _get_megatron_full_params(
setattr(param, key, value)

# Batch async all_gather for all parameters
gathered_params = all_gather_params_async(list(zip(megatron_local_param_infos, params, strict=False)))
gathered_params = all_gather_params_async(args, list(zip(megatron_local_param_infos, params, strict=False)))

return gathered_params

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _update_weight_from_distributed(
Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers.
Returns updated bytes on source, None on non-source.
"""
param = all_gather_param(name, param)
param = all_gather_param(self.args, name, param)
if not self._is_pp_src_rank:
return

Expand All @@ -170,7 +170,7 @@ def _update_expert_weight_from_distributed(
"""
Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size.
"""
param = all_gather_param(name, param)
param = all_gather_param(self.args, name, param)

param_size = param.numel() * param.element_size()
if (
Expand Down
6 changes: 5 additions & 1 deletion miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,11 @@ def equal(x, y):
("num_hidden_layers", "num_layers", equal),
("intermediate_size", "ffn_hidden_size", equal),
("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y),
("rms_norm_eps", "norm_epsilon", equal),
(
"rms_norm_eps",
"norm_epsilon" if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1" else "layernorm_epsilon",
equal,
),
("rope_theta", "rotary_base", equal),
]:
if hasattr(hf_config, hf_config_name):
Expand Down
50 changes: 50 additions & 0 deletions tests/fast/test_megatron_cli_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import sys

import pytest


def test_post_layernorm_flags_propagate_to_megatron(monkeypatch):
    """Check that the post-layernorm CLI flags reach the Megatron transformer config.

    Parses a minimal command line containing --post-self-attn-layernorm and
    --post-mlp-layernorm through miles' argument parser (with both validation hooks
    stubbed out), then builds the config via core_transformer_config_from_args and
    asserts both flags are True on the parsed args and on the config.
    Skipped when megatron.training.arguments cannot be imported.
    """
    pytest.importorskip("megatron.training.arguments")

    import torch
    from megatron.training.arguments import core_transformer_config_from_args

    import miles.backends.megatron_utils.arguments as megatron_arguments
    import miles.utils.arguments as miles_arguments

    # Replace both validators with no-ops so the minimal argv below is accepted.
    monkeypatch.setattr(miles_arguments, "miles_validate_args", lambda args: None)
    monkeypatch.setattr(megatron_arguments, "validate_args", lambda args: None)

    cli_options = [
        ("--train-backend", "megatron"),
        ("--rollout-batch-size", "1"),
        ("--num-layers", "1"),
        ("--hidden-size", "8"),
        ("--num-attention-heads", "1"),
    ]
    fake_argv = ["pytest"]
    for flag, value in cli_options:
        fake_argv += [flag, value]
    fake_argv += ["--post-self-attn-layernorm", "--post-mlp-layernorm"]
    monkeypatch.setattr(sys, "argv", fake_argv)

    args = miles_arguments.parse_args()

    assert args.post_self_attn_layernorm is True
    assert args.post_mlp_layernorm is True

    # params_dtype is set here from the precision flags before the config is built.
    args.params_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)

    config = core_transformer_config_from_args(args)

    assert config.post_self_attn_layernorm is True
    assert config.post_mlp_layernorm is True
2 changes: 1 addition & 1 deletion tools/convert_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main(args):
for key, value in info.attrs.items():
setattr(param, key, value)

param = update_weight_utils.all_gather_param(info.name, param)
param = update_weight_utils.all_gather_param(args, info.name, param)
param = update_weight_utils.remove_padding(info.name, param, vocab_size)
# use torch.distributed
if is_save_rank:
Expand Down
Loading