diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 4a62cceb36..4eae94528a 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -54,7 +54,7 @@ RUN pip install /tmp/wheels/flash_attn_3-*.whl && \ RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -RUN pip install flash-linear-attention==0.4.0 +RUN pip install flash-linear-attention==0.4.1 RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/ RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \ diff --git a/miles/backends/megatron_utils/arguments.py b/miles/backends/megatron_utils/arguments.py index 24496011b1..e1c7ba8f5d 100644 --- a/miles/backends/megatron_utils/arguments.py +++ b/miles/backends/megatron_utils/arguments.py @@ -1,4 +1,5 @@ import logging +import os from megatron.training.arguments import parse_args, validate_args from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding @@ -17,8 +18,9 @@ def set_default_megatron_args(args): if args.seq_length is None: args.seq_length = 4096 args.max_position_embeddings = args.seq_length - # TODO: revisit this when megatron(dev) have solved the optimizer-cpu-offload ckpt saving bug - args.dist_ckpt_save_pre_mcore_014 = True + # Notice(Jiajun): new megatron has removed this argument and use dp_reshardable instead of fully_shard + if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1": + args.dist_ckpt_save_pre_mcore_014 = True # compatible for megatron if hasattr(args, "rope_type") and args.rope_type is None: args.rope_type = "yarn" if args.multi_latent_attention else "rope" diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 5f54503979..0e70fa0fd0 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -63,8 +63,13 @@ def get_model_provider_func( if getattr(args, "custom_model_provider_path", None): def wrapped_model_provider( - pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None + pre_process: bool = True, + post_process: bool = True, + vp_stage: int | None = None, + config: TransformerConfig | None = None, + pg_collection=None, ) -> GPTModel: + assert config is None, "miles builds the config from args, so it expects config to be None" custom_model_provider = load_function(args.custom_model_provider_path) # Check if the custom provider supports vp_stage parameter has_vp_stage = "vp_stage" in inspect.signature(custom_model_provider).parameters @@ -93,9 +98,26 @@ def wrapped_model_provider( provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size provider.sequence_parallel = args.sequence_parallel provider.finalize() - return provider.provide - def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage: int | None = None) -> GPTModel: + def wrapped_bridge_provider( + pre_process: bool = True, + post_process: bool = True, + vp_stage: int | None = None, + config: TransformerConfig | None = None, + pg_collection=None, + ) -> GPTModel: + assert config is None, "miles builds the config from args, so it expects config to be None" + return provider.provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + + return wrapped_bridge_provider + + def model_provider( + pre_process: bool = True, + post_process: bool = True, + vp_stage: int | None = None, + config: TransformerConfig | None = None, + pg_collection=None, + ) -> GPTModel: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. @@ -111,7 +133,8 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage use_te = args.transformer_impl == "transformer_engine" # Experimental loading arguments from yaml - config: TransformerConfig = core_transformer_config_from_args(args) + assert config is None, "miles builds the config from args, so it expects config to be None" + config = core_transformer_config_from_args(args) if args.spec is not None: transformer_layer_spec = import_module(args.spec) diff --git a/miles/backends/megatron_utils/update_weight/common.py b/miles/backends/megatron_utils/update_weight/common.py index 85fe76a1b8..c729a36e68 100644 --- a/miles/backends/megatron_utils/update_weight/common.py +++ b/miles/backends/megatron_utils/update_weight/common.py @@ -1,4 +1,5 @@ import inspect +import logging import re from argparse import Namespace from collections.abc import Iterator, Sequence @@ -11,11 +12,43 @@ from miles.backends.megatron_utils.misc_utils import strip_param_name_prefix from miles.utils.types import ParamInfo +logger = logging.getLogger(__name__) -def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: + +def _gather_with_stride( + param_partitions: list[torch.Tensor], partition_dim: int, partition_stride: int +) -> torch.Tensor: + """Gather partitions respecting partition_stride (strided/interleaved TP sharding).""" + if partition_stride == 1: + return torch.cat(param_partitions, dim=partition_dim) + # Interleaved (strided) partitioning, e.g. linear_fc1.weight under GLU/SwiGLU + chunks_per_rank = [p.chunk(partition_stride, dim=partition_dim) for p in param_partitions] + interleaved = [chunks_per_rank[r][s] for s in range(partition_stride) for r in range(len(param_partitions))] + return torch.cat(interleaved, dim=partition_dim) + + +def _check_and_fix_partition(args: Namespace, name: str, partition_stride: int, partition_dim: int) -> tuple[int, int]: + """Validate partition_stride values for known parameter patterns. + + After Megatron-LM PR #2708, linear_fc1 correctly reports partition_stride=2 + (GLU/SwiGLU interleaved [gate, up]), so assert partition_stride==2 is removed. + But TEGroupedLinear still does not set partition_stride/partition_dim correctly for grouped moe gemm + """ + if "linear_fc1.weight" in name and args.swiglu: + partition_stride = 2 + elif "linear_fc2.weight" in name: + assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}" + if partition_dim == 0: + partition_dim = 1 + else: + assert partition_stride == 1, f"Expected partition_stride=1 for {name}, got {partition_stride}" + return partition_stride, partition_dim + + +def all_gather_param(args: Namespace, name: str, param: torch.nn.Parameter) -> torch.Tensor: """ All-gather TP-sharded param to full tensor. expert_bias→param, non-TP/duplicated→param.data. - Uses expert-TP for ".experts.", else regular-TP. linear_fc1 rechunked (GLU), linear_fc2 dim fix. + Uses expert-TP for ".experts.", else regular-TP. Handles strided partitioning via partition_stride. """ if "expert_bias" in name: return param @@ -34,21 +67,15 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] dist.all_gather(param_partitions, param.data, group=tp_group) partition_dim = param.partition_dim - assert param.partition_stride == 1, "partition_stride != 1 is not supported" - # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? - # TODO: check only GLU is used. - if "linear_fc1.weight" in name: - param_partitions = [p.chunk(2, dim=0) for p in param_partitions] - param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] - # this is bug in megatron's grouped moe. - if "linear_fc2.weight" in name: - if partition_dim == 0: - partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + partition_stride = param.partition_stride + + partition_stride, partition_dim = _check_and_fix_partition(args, name, partition_stride, partition_dim) + param = _gather_with_stride(param_partitions, partition_dim, partition_stride) return param def all_gather_params_async( + args: Namespace, param_infos_and_params: list[tuple[ParamInfo, torch.Tensor]], ) -> list[torch.Tensor]: """ @@ -63,10 +90,10 @@ def all_gather_params_async( for info, param in param_infos_and_params: # Prepare async all_gather if "expert_bias" in info.name: - gather_tasks.append((info, param, None, None, None)) + gather_tasks.append((info, param, None, None, None, None)) handles.append(None) elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated": - gather_tasks.append((info, param.data, None, None, None)) + gather_tasks.append((info, param.data, None, None, None, None)) handles.append(None) else: # Start async all_gather @@ -79,7 +106,7 @@ def all_gather_params_async( param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] handle = dist.all_gather(param_partitions, param.data, group=tp_group, async_op=True) - gather_tasks.append((info, None, handle, param_partitions, param.partition_dim)) + gather_tasks.append((info, None, handle, param_partitions, param.partition_dim, param.partition_stride)) handles.append(handle) # Phase 2: Wait for ALL async operations to complete at once @@ -90,23 +117,15 @@ def all_gather_params_async( # Phase 3: Process all results after all communications are done gathered_params = [] - for info, direct_param, handle, param_partitions, partition_dim in gather_tasks: + for info, direct_param, handle, param_partitions, partition_dim, partition_stride in gather_tasks: if handle is None: # No all_gather needed param = direct_param else: - # Process the gathered partitions (same logic as original all_gather_param) - assert partition_dim is not None, "partition_stride != 1 is not supported" - # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? - # TODO: check only GLU is used. - if "linear_fc1.weight" in info.name: - param_partitions = [p.chunk(2, dim=0) for p in param_partitions] - param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] - # this is bug in megatron's grouped moe. - if "linear_fc2.weight" in info.name: - if partition_dim == 0: - partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + partition_stride, partition_dim = _check_and_fix_partition( + args, info.name, partition_stride, partition_dim + ) + param = _gather_with_stride(param_partitions, partition_dim, partition_stride) gathered_params.append(param) diff --git a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py index af2250dc1b..ecdba3c8c7 100644 --- a/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py +++ b/miles/backends/megatron_utils/update_weight/hf_weight_iterator_direct.py @@ -27,7 +27,9 @@ def get_hf_weight_chunks(self, megatron_local_weights): for megatron_local_param_infos in tqdm( self.megatron_local_param_info_buckets, disable=rank != 0, desc="Update weights" ): - megatron_full_params = _get_megatron_full_params(megatron_local_param_infos, megatron_local_weights) + megatron_full_params = _get_megatron_full_params( + self.args, megatron_local_param_infos, megatron_local_weights + ) hf_named_tensors = self._convert_to_hf_named_tensors(megatron_full_params, megatron_local_param_infos) yield hf_named_tensors del megatron_full_params @@ -42,6 +44,7 @@ def _convert_to_hf_named_tensors(self, megatron_full_params: Sequence[torch.Tens def _get_megatron_full_params( + args: Namespace, megatron_local_param_infos: Sequence[ParamInfo], megatron_local_weights, ) -> Sequence[torch.Tensor]: @@ -100,7 +103,7 @@ def _get_megatron_full_params( setattr(param, key, value) # Batch async all_gather for all parameters - gathered_params = all_gather_params_async(list(zip(megatron_local_param_infos, params, strict=False))) + gathered_params = all_gather_params_async(args, list(zip(megatron_local_param_infos, params, strict=False))) return gathered_params diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index caf6ae54f1..c166b195f4 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -147,7 +147,7 @@ def _update_weight_from_distributed( Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers. Returns updated bytes on source, None on non-source. """ - param = all_gather_param(name, param) + param = all_gather_param(self.args, name, param) if not self._is_pp_src_rank: return @@ -170,7 +170,7 @@ def _update_expert_weight_from_distributed( """ Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size. """ - param = all_gather_param(name, param) + param = all_gather_param(self.args, name, param) param_size = param.numel() * param.element_size() if ( diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 39314715c8..202b5a3ca1 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1858,7 +1858,11 @@ def equal(x, y): ("num_hidden_layers", "num_layers", equal), ("intermediate_size", "ffn_hidden_size", equal), ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y), - ("rms_norm_eps", "norm_epsilon", equal), + ( + "rms_norm_eps", + "norm_epsilon" if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1" else "layernorm_epsilon", + equal, + ), ("rope_theta", "rotary_base", equal), ]: if hasattr(hf_config, hf_config_name): diff --git a/tests/fast/test_megatron_cli_flags.py b/tests/fast/test_megatron_cli_flags.py new file mode 100644 index 0000000000..3b58eafcb4 --- /dev/null +++ b/tests/fast/test_megatron_cli_flags.py @@ -0,0 +1,50 @@ +import sys + +import pytest + + +def test_post_layernorm_flags_propagate_to_megatron(monkeypatch): + pytest.importorskip("megatron.training.arguments") + + import torch + from megatron.training.arguments import core_transformer_config_from_args + + import miles.backends.megatron_utils.arguments as megatron_arguments + import miles.utils.arguments as miles_arguments + + monkeypatch.setattr(miles_arguments, "miles_validate_args", lambda args: None) + monkeypatch.setattr(megatron_arguments, "validate_args", lambda args: None) + + argv = [ + "pytest", + "--train-backend", + "megatron", + "--rollout-batch-size", + "1", + "--num-layers", + "1", + "--hidden-size", + "8", + "--num-attention-heads", + "1", + "--post-self-attn-layernorm", + "--post-mlp-layernorm", + ] + monkeypatch.setattr(sys, "argv", argv) + + args = miles_arguments.parse_args() + + assert args.post_self_attn_layernorm is True + assert args.post_mlp_layernorm is True + + if args.bf16: + args.params_dtype = torch.bfloat16 + elif args.fp16: + args.params_dtype = torch.float16 + else: + args.params_dtype = torch.float32 + + config = core_transformer_config_from_args(args) + + assert config.post_self_attn_layernorm is True + assert config.post_mlp_layernorm is True diff --git a/tools/convert_to_hf.py b/tools/convert_to_hf.py index b84d68e665..a383473046 100644 --- a/tools/convert_to_hf.py +++ b/tools/convert_to_hf.py @@ -76,7 +76,7 @@ def main(args): for key, value in info.attrs.items(): setattr(param, key, value) - param = update_weight_utils.all_gather_param(info.name, param) + param = update_weight_utils.all_gather_param(args, info.name, param) param = update_weight_utils.remove_padding(info.name, param, vocab_size) # use torch.distributed if is_save_rank: