From 960fd938b601203f4e8e8aac70d04e30216c6238 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Wed, 13 May 2026 02:10:53 +0000 Subject: [PATCH 1/4] Thread custom process groups through MoE grad finalization --- .../core/distributed/finalize_model_grads.py | 23 +++- megatron/core/transformer/moe/moe_layer.py | 6 +- megatron/core/transformer/moe/moe_utils.py | 19 ++- .../core/transformer/moe/shared_experts.py | 10 +- .../distributed/test_finalize_model_grads.py | 120 ++++++++++++++++++ .../transformer/moe/test_routers.py | 28 ++++ 6 files changed, 190 insertions(+), 16 deletions(-) diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index 9c31b280875..53db86b4a5c 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -328,7 +328,11 @@ def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.n module.reset_global_aux_loss_tracker() -def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig): +def _update_router_expert_bias( + model: List[torch.nn.Module], + config: TransformerConfig, + tp_dp_cp_group: Optional[torch.distributed.ProcessGroup] = None, +): """ Update the expert bias of the router for a global batch. This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks @@ -350,7 +354,10 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0) stacked_expert_bias = torch.stack(expert_bias_list, dim=0) stacked_updated_expert_bias = get_updated_expert_bias( - stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate + stacked_tokens_per_expert, + stacked_expert_bias, + config.moe_router_bias_update_rate, + tp_dp_cp_group=tp_dp_cp_group, ) for expert_bias, updated_expert_bias in zip(expert_bias_list, stacked_updated_expert_bias): @@ -448,6 +455,7 @@ def finalize_model_grads( """ config = get_model_config(model[0]) + tp_dp_cp_group = None if pg_collection is not None: assert hasattr(pg_collection, 'tp') assert hasattr(pg_collection, 'pp') @@ -466,11 +474,16 @@ def finalize_model_grads( "If you don't need pos_embd_group, you need to explicitly set it to None." ) assert hasattr(pg_collection, 'dp_cp') + if config.moe_router_enable_expert_bias: + assert hasattr(pg_collection, 'tp_dp_cp') and pg_collection.tp_dp_cp is not None, ( + "pg_collection must have tp_dp_cp when " "moe_router_enable_expert_bias is enabled." + ) tp_group = pg_collection.tp pp_group = pg_collection.pp embd_group = pg_collection.embd pos_emb_group = pg_collection.pos_embd dp_cp_group = pg_collection.dp_cp + tp_dp_cp_group = pg_collection.tp_dp_cp if config.moe_router_enable_expert_bias else None else: tp_group = parallel_state.get_tensor_model_parallel_group() pp_group = parallel_state.get_pipeline_model_parallel_group() @@ -519,7 +532,11 @@ def finalize_model_grads( config.timers('embedding-grads-all-reduce').stop() if config.moe_router_enable_expert_bias: - _update_router_expert_bias(model, config) + if pg_collection is None: + tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group( + with_context_parallel=True + ) + _update_router_expert_bias(model, config, tp_dp_cp_group=tp_dp_cp_group) reset_model_temporary_tensors(config, model) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 2ddc17a567a..0baf2c65cc5 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -8,7 +8,7 @@ import torch -from megatron.core import parallel_state, tensor_parallel, utils +from megatron.core import tensor_parallel, utils from megatron.core.extensions.transformer_engine import HAVE_TE from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule @@ -509,7 +509,7 @@ def shared_experts_compute(self, hidden_states: torch.Tensor): apply_module(self.shared_experts), False, tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), + self.tp_group, hidden_states, ) else: @@ -672,7 +672,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): custom_forward, False, tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), + self.tp_group, hidden_states, intermediate_tensors, padding_mask, diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index f258f3474ae..e1b581898ab 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1161,7 +1161,10 @@ def track_moe_metrics( def get_updated_expert_bias( - tokens_per_expert: torch.Tensor, expert_bias: torch.Tensor, expert_bias_update_rate: float + tokens_per_expert: torch.Tensor, + expert_bias: torch.Tensor, + expert_bias_update_rate: float, + tp_dp_cp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> torch.Tensor: """Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1# @@ -1169,17 +1172,21 @@ def get_updated_expert_bias( tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert. expert_bias (torch.Tensor): The bias for each expert. expert_bias_udpate_rate (float): The update rate for the expert bias. + tp_dp_cp_group (torch.distributed.ProcessGroup, optional): The group spanning the tensor, + data, and context parallel ranks that share the router expert-bias update. Returns: torch.Tensor: The updated expert bias. """ with torch.no_grad(): - # All Reduce Across TPxCPxDP group - torch.distributed.all_reduce( - tokens_per_expert, + if tp_dp_cp_group is None: # TODO(Hepteract): delete the usage of the global parallel_state. - group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True), - ) + tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group( + with_context_parallel=True + ) + + # All Reduce Across TPxCPxDP group + torch.distributed.all_reduce(tokens_per_expert, group=tp_dp_cp_group) average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1] offset = average_tokens - tokens_per_expert updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py index 61ea47955b8..a565e2ec718 100644 --- a/megatron/core/transformer/moe/shared_experts.py +++ b/megatron/core/transformer/moe/shared_experts.py @@ -229,10 +229,12 @@ def pre_forward_comm(self, input, wait_current_stream=True): self.gate_score = torch.nn.functional.sigmoid(logits) if self.config.sequence_parallel: self.cached_fc1_input = gather_from_sequence_parallel_region( - input, tensor_parallel_output_grad=True + input, tensor_parallel_output_grad=True, group=self.tp_group ) else: - self.cached_fc1_input = copy_to_tensor_model_parallel_region(input) + self.cached_fc1_input = copy_to_tensor_model_parallel_region( + input, group=self.tp_group + ) set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max) @overlap_state_check( @@ -321,11 +323,11 @@ def post_forward_comm(self): with torch.cuda.stream(self.stream): if self.config.sequence_parallel: self.cached_output = reduce_scatter_to_sequence_parallel_region( - self.cached_fc2_output + self.cached_fc2_output, group=self.tp_group ) else: self.cached_output = reduce_from_tensor_model_parallel_region( - self.cached_fc2_output + self.cached_fc2_output, group=self.tp_group ) self.cached_fc2_output = None set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max) diff --git a/tests/unit_tests/distributed/test_finalize_model_grads.py b/tests/unit_tests/distributed/test_finalize_model_grads.py index e1e2e760693..7b1b24b03eb 100644 --- a/tests/unit_tests/distributed/test_finalize_model_grads.py +++ b/tests/unit_tests/distributed/test_finalize_model_grads.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import importlib import inspect import os +from types import SimpleNamespace import pytest import torch @@ -11,14 +13,132 @@ from megatron.core.distributed.finalize_model_grads import ( _allreduce_non_tensor_model_parallel_grads, _allreduce_word_embedding_grads, + _update_router_expert_bias, + finalize_model_grads, ) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils +_MISSING = object() +_FINALIZE_MODEL_GRADS_MODULE = importlib.import_module( + "megatron.core.distributed.finalize_model_grads" +) + + +class _FinalizeModelGradsModel(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.finish_grad_sync_calls = 0 + + def finish_grad_sync(self, force_all_reduce=False): + del force_all_reduce + self.finish_grad_sync_calls += 1 + + +def _finalize_model_grads_config(): + return SimpleNamespace( + timers=None, + flextron=False, + moe_router_enable_expert_bias=True, + moe_router_load_balancing_type="none", + ) + + +def _patch_finalize_model_grads_collectives(monkeypatch): + def no_op(*args, **kwargs): + del args, kwargs + + for name in ( + "_allreduce_conditional_embedding_grads", + "_allreduce_non_tensor_model_parallel_grads", + "_allreduce_word_embedding_grads", + "_allreduce_position_embedding_grads", + "reset_model_temporary_tensors", + ): + monkeypatch.setattr(_FINALIZE_MODEL_GRADS_MODULE, name, no_op) + + +def _pg_collection(tp_dp_cp=_MISSING): + pg_collection = ProcessGroupCollection() + pg_collection.tp = object() + pg_collection.pp = object() + pg_collection.embd = None + pg_collection.pos_embd = None + pg_collection.dp_cp = object() + if tp_dp_cp is not _MISSING: + pg_collection.tp_dp_cp = tp_dp_cp + return pg_collection + + +def test_update_router_expert_bias_uses_explicit_group(monkeypatch): + class RouterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.local_tokens_per_expert = torch.tensor([1.0, 3.0]) + self.expert_bias = torch.zeros(2) + + group = object() + router = RouterModule() + config = type("Config", (), {"moe_router_bias_update_rate": 0.25})() + calls = [] + + def fake_get_updated_expert_bias( + tokens_per_expert, expert_bias, expert_bias_update_rate, tp_dp_cp_group=None + ): + calls.append((expert_bias_update_rate, tp_dp_cp_group)) + return expert_bias + 1.0 + + monkeypatch.setattr( + _FINALIZE_MODEL_GRADS_MODULE, "get_updated_expert_bias", fake_get_updated_expert_bias + ) + + _update_router_expert_bias([torch.nn.Sequential(router)], config, tp_dp_cp_group=group) + + assert calls == [(0.25, group)] + torch.testing.assert_close(router.expert_bias, torch.ones(2)) + + +def test_finalize_model_grads_uses_pg_collection_tp_dp_cp(monkeypatch): + _patch_finalize_model_grads_collectives(monkeypatch) + + group = object() + pg_collection = _pg_collection(tp_dp_cp=group) + + calls = [] + + def fake_update_router_expert_bias(model, config, tp_dp_cp_group=None): + calls.append((model, config, tp_dp_cp_group)) + + monkeypatch.setattr( + _FINALIZE_MODEL_GRADS_MODULE, "_update_router_expert_bias", fake_update_router_expert_bias + ) + + config = _finalize_model_grads_config() + model = _FinalizeModelGradsModel(config) + finalize_model_grads([model], pg_collection=pg_collection) + + assert model.finish_grad_sync_calls == 1 + assert calls == [([model], config, group)] + + +def test_finalize_model_grads_requires_tp_dp_cp_for_explicit_groups(monkeypatch): + _patch_finalize_model_grads_collectives(monkeypatch) + + config = _finalize_model_grads_config() + model = _FinalizeModelGradsModel(config) + + for pg_collection in (_pg_collection(), _pg_collection(tp_dp_cp=None)): + with pytest.raises(AssertionError, match="tp_dp_cp"): + finalize_model_grads([model], pg_collection=pg_collection) + assert model.finish_grad_sync_calls == 0 + + class TestAllReduceLNGrads: def init_model(self, share_embeddings_and_output_weights: bool = False): diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index a03766d668f..e2d21720c93 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -26,6 +26,34 @@ HAVE_ROUTER_FUSION = False +def test_get_updated_expert_bias_uses_explicit_group(monkeypatch): + group = object() + all_reduce_groups = [] + + def fake_all_reduce(tensor, group=None): + all_reduce_groups.append(group) + + def unexpected_default_group(**kwargs): + raise AssertionError("expected explicit tp_dp_cp_group") + + monkeypatch.setattr(torch.distributed, "all_reduce", fake_all_reduce) + monkeypatch.setattr( + "megatron.core.transformer.moe.moe_utils.parallel_state." + "get_tensor_and_data_parallel_group", + unexpected_default_group, + ) + + tokens_per_expert = torch.tensor([[1.0, 3.0]]) + expert_bias = torch.zeros_like(tokens_per_expert) + + updated_bias = get_updated_expert_bias( + tokens_per_expert, expert_bias, expert_bias_update_rate=0.1, tp_dp_cp_group=group + ) + + assert all_reduce_groups == [group] + torch.testing.assert_close(updated_bias, torch.tensor([[0.1, -0.1]])) + + class TestTop2Router: def setup_method(self, method): Utils.initialize_model_parallel(1, 1) From 53f2fa5e9de1d7359a18408b8fcc66ef00376253 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Wed, 13 May 2026 02:29:33 +0000 Subject: [PATCH 2/4] Use real distributed tests for MoE grad finalization --- .../distributed/test_finalize_model_grads.py | 154 +++++++----------- .../transformer/moe/test_routers.py | 28 ---- 2 files changed, 57 insertions(+), 125 deletions(-) diff --git a/tests/unit_tests/distributed/test_finalize_model_grads.py b/tests/unit_tests/distributed/test_finalize_model_grads.py index 7b1b24b03eb..9b2d7aa504a 100644 --- a/tests/unit_tests/distributed/test_finalize_model_grads.py +++ b/tests/unit_tests/distributed/test_finalize_model_grads.py @@ -1,9 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import importlib import inspect import os -from types import SimpleNamespace import pytest import torch @@ -13,7 +10,6 @@ from megatron.core.distributed.finalize_model_grads import ( _allreduce_non_tensor_model_parallel_grads, _allreduce_word_embedding_grads, - _update_router_expert_bias, finalize_model_grads, ) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec @@ -24,16 +20,14 @@ from tests.unit_tests.test_utilities import Utils -_MISSING = object() -_FINALIZE_MODEL_GRADS_MODULE = importlib.import_module( - "megatron.core.distributed.finalize_model_grads" -) - - -class _FinalizeModelGradsModel(torch.nn.Module): - def __init__(self, config): +class _RouterExpertBiasModel(torch.nn.Module): + def __init__(self, config, local_tokens_per_expert): super().__init__() self.config = config + self.ddp_config = DistributedDataParallelConfig() + self.router = torch.nn.Module() + self.router.register_buffer("local_tokens_per_expert", local_tokens_per_expert) + self.router.register_buffer("expert_bias", torch.zeros_like(local_tokens_per_expert)) self.finish_grad_sync_calls = 0 def finish_grad_sync(self, force_all_reduce=False): @@ -41,102 +35,68 @@ def finish_grad_sync(self, force_all_reduce=False): self.finish_grad_sync_calls += 1 -def _finalize_model_grads_config(): - return SimpleNamespace( - timers=None, - flextron=False, +def _router_expert_bias_config(): + return TransformerConfig( + num_layers=1, + hidden_size=8, + num_attention_heads=1, + use_cpu_initialization=True, moe_router_enable_expert_bias=True, + moe_router_score_function="sigmoid", + moe_router_bias_update_rate=0.25, moe_router_load_balancing_type="none", ) -def _patch_finalize_model_grads_collectives(monkeypatch): - def no_op(*args, **kwargs): - del args, kwargs - - for name in ( - "_allreduce_conditional_embedding_grads", - "_allreduce_non_tensor_model_parallel_grads", - "_allreduce_word_embedding_grads", - "_allreduce_position_embedding_grads", - "reset_model_temporary_tensors", - ): - monkeypatch.setattr(_FINALIZE_MODEL_GRADS_MODULE, name, no_op) - - -def _pg_collection(tp_dp_cp=_MISSING): - pg_collection = ProcessGroupCollection() - pg_collection.tp = object() - pg_collection.pp = object() - pg_collection.embd = None - pg_collection.pos_embd = None - pg_collection.dp_cp = object() - if tp_dp_cp is not _MISSING: - pg_collection.tp_dp_cp = tp_dp_cp - return pg_collection - - -def test_update_router_expert_bias_uses_explicit_group(monkeypatch): - class RouterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.local_tokens_per_expert = torch.tensor([1.0, 3.0]) - self.expert_bias = torch.zeros(2) - - group = object() - router = RouterModule() - config = type("Config", (), {"moe_router_bias_update_rate": 0.25})() - calls = [] - - def fake_get_updated_expert_bias( - tokens_per_expert, expert_bias, expert_bias_update_rate, tp_dp_cp_group=None - ): - calls.append((expert_bias_update_rate, tp_dp_cp_group)) - return expert_bias + 1.0 - - monkeypatch.setattr( - _FINALIZE_MODEL_GRADS_MODULE, "get_updated_expert_bias", fake_get_updated_expert_bias - ) - - _update_router_expert_bias([torch.nn.Sequential(router)], config, tp_dp_cp_group=group) - - assert calls == [(0.25, group)] - torch.testing.assert_close(router.expert_bias, torch.ones(2)) - - -def test_finalize_model_grads_uses_pg_collection_tp_dp_cp(monkeypatch): - _patch_finalize_model_grads_collectives(monkeypatch) - - group = object() - pg_collection = _pg_collection(tp_dp_cp=group) - - calls = [] +def _router_bias_pg_collection(include_tp_dp_cp=True): + required_pgs = ['tp', 'pp', 'embd', 'pos_embd', 'dp_cp'] + if include_tp_dp_cp: + required_pgs.append('tp_dp_cp') + return ProcessGroupCollection.use_mpu_process_groups(required_pgs) - def fake_update_router_expert_bias(model, config, tp_dp_cp_group=None): - calls.append((model, config, tp_dp_cp_group)) - monkeypatch.setattr( - _FINALIZE_MODEL_GRADS_MODULE, "_update_router_expert_bias", fake_update_router_expert_bias - ) - - config = _finalize_model_grads_config() - model = _FinalizeModelGradsModel(config) - finalize_model_grads([model], pg_collection=pg_collection) - - assert model.finish_grad_sync_calls == 1 - assert calls == [([model], config, group)] +class TestFinalizeModelGradsMoEExpertBias: + def setup_method(self, method): + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + Utils.destroy_model_parallel() + Utils.initialize_model_parallel(tensor_model_parallel_size=min(2, Utils.world_size)) + def teardown_method(self, method): + Utils.destroy_model_parallel() -def test_finalize_model_grads_requires_tp_dp_cp_for_explicit_groups(monkeypatch): - _patch_finalize_model_grads_collectives(monkeypatch) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_finalize_model_grads_updates_router_expert_bias_with_custom_group(self): + config = _router_expert_bias_config() + device = torch.device("cuda", torch.cuda.current_device()) + local_tokens = torch.tensor([float(torch.distributed.get_rank() + 1), 0.0], device=device) + model = _RouterExpertBiasModel(config, local_tokens) - config = _finalize_model_grads_config() - model = _FinalizeModelGradsModel(config) + finalize_model_grads([model], pg_collection=_router_bias_pg_collection()) - for pg_collection in (_pg_collection(), _pg_collection(tp_dp_cp=None)): - with pytest.raises(AssertionError, match="tp_dp_cp"): - finalize_model_grads([model], pg_collection=pg_collection) - assert model.finish_grad_sync_calls == 0 + expected_bias = torch.tensor([-0.25, 0.25], device=device) + torch.testing.assert_close(model.router.expert_bias, expected_bias) + torch.testing.assert_close( + model.router.local_tokens_per_expert, torch.zeros_like(local_tokens) + ) + assert model.finish_grad_sync_calls == 1 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_finalize_model_grads_requires_custom_group_before_grad_sync(self): + config = _router_expert_bias_config() + device = torch.device("cuda", torch.cuda.current_device()) + pg_collections = [ + _router_bias_pg_collection(include_tp_dp_cp=False), + _router_bias_pg_collection(), + ] + pg_collections[1].tp_dp_cp = None + + for pg_collection in pg_collections: + model = _RouterExpertBiasModel(config, torch.tensor([1.0, 0.0], device=device)) + with pytest.raises(AssertionError, match="tp_dp_cp"): + finalize_model_grads([model], pg_collection=pg_collection) + assert model.finish_grad_sync_calls == 0 class TestAllReduceLNGrads: diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index e2d21720c93..a03766d668f 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -26,34 +26,6 @@ HAVE_ROUTER_FUSION = False -def test_get_updated_expert_bias_uses_explicit_group(monkeypatch): - group = object() - all_reduce_groups = [] - - def fake_all_reduce(tensor, group=None): - all_reduce_groups.append(group) - - def unexpected_default_group(**kwargs): - raise AssertionError("expected explicit tp_dp_cp_group") - - monkeypatch.setattr(torch.distributed, "all_reduce", fake_all_reduce) - monkeypatch.setattr( - "megatron.core.transformer.moe.moe_utils.parallel_state." - "get_tensor_and_data_parallel_group", - unexpected_default_group, - ) - - tokens_per_expert = torch.tensor([[1.0, 3.0]]) - expert_bias = torch.zeros_like(tokens_per_expert) - - updated_bias = get_updated_expert_bias( - tokens_per_expert, expert_bias, expert_bias_update_rate=0.1, tp_dp_cp_group=group - ) - - assert all_reduce_groups == [group] - torch.testing.assert_close(updated_bias, torch.tensor([[0.1, -0.1]])) - - class TestTop2Router: def setup_method(self, method): Utils.initialize_model_parallel(1, 1) From 95aeafe76f491e93d569488c51695daf2ea88da4 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Wed, 13 May 2026 03:34:43 +0000 Subject: [PATCH 3/4] Avoid global parallel state in MoE grad tests --- .../distributed/test_finalize_model_grads.py | 38 ++++++++++++++----- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/unit_tests/distributed/test_finalize_model_grads.py b/tests/unit_tests/distributed/test_finalize_model_grads.py index 9b2d7aa504a..ee535c29baf 100644 --- a/tests/unit_tests/distributed/test_finalize_model_grads.py +++ b/tests/unit_tests/distributed/test_finalize_model_grads.py @@ -4,6 +4,7 @@ import pytest import torch +import torch.distributed as dist from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallelConfig @@ -48,11 +49,20 @@ def _router_expert_bias_config(): ) -def _router_bias_pg_collection(include_tp_dp_cp=True): - required_pgs = ['tp', 'pp', 'embd', 'pos_embd', 'dp_cp'] - if include_tp_dp_cp: - required_pgs.append('tp_dp_cp') - return ProcessGroupCollection.use_mpu_process_groups(required_pgs) +_NO_TP_DP_CP = object() + + +def _router_bias_pg_collection(tp_dp_cp=_NO_TP_DP_CP): + kwargs = { + 'tp': dist.group.WORLD, + 'pp': dist.group.WORLD, + 'embd': None, + 'pos_embd': None, + 'dp_cp': dist.group.WORLD, + } + if tp_dp_cp is not _NO_TP_DP_CP: + kwargs['tp_dp_cp'] = tp_dp_cp + return ProcessGroupCollection(**kwargs) class TestFinalizeModelGradsMoEExpertBias: @@ -61,21 +71,28 @@ def setup_method(self, method): os.environ.pop('NVTE_FLASH_ATTN', None) os.environ.pop('NVTE_UNFUSED_ATTN', None) Utils.destroy_model_parallel() - Utils.initialize_model_parallel(tensor_model_parallel_size=min(2, Utils.world_size)) + Utils.initialize_distributed() + parallel_state.destroy_model_parallel() def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_finalize_model_grads_updates_router_expert_bias_with_custom_group(self): + assert not parallel_state.model_parallel_is_initialized() + config = _router_expert_bias_config() device = torch.device("cuda", torch.cuda.current_device()) - local_tokens = torch.tensor([float(torch.distributed.get_rank() + 1), 0.0], device=device) + local_tokens = torch.tensor( + [0.0, 2.0] if dist.get_rank() == 0 else [0.0, 0.0], device=device + ) model = _RouterExpertBiasModel(config, local_tokens) - finalize_model_grads([model], pg_collection=_router_bias_pg_collection()) + finalize_model_grads( + [model], pg_collection=_router_bias_pg_collection(tp_dp_cp=dist.group.WORLD) + ) - expected_bias = torch.tensor([-0.25, 0.25], device=device) + expected_bias = torch.tensor([0.25, -0.25], device=device) torch.testing.assert_close(model.router.expert_bias, expected_bias) torch.testing.assert_close( model.router.local_tokens_per_expert, torch.zeros_like(local_tokens) @@ -84,11 +101,12 @@ def test_finalize_model_grads_updates_router_expert_bias_with_custom_group(self) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_finalize_model_grads_requires_custom_group_before_grad_sync(self): + assert not parallel_state.model_parallel_is_initialized() config = _router_expert_bias_config() device = torch.device("cuda", torch.cuda.current_device()) pg_collections = [ - _router_bias_pg_collection(include_tp_dp_cp=False), _router_bias_pg_collection(), + _router_bias_pg_collection(tp_dp_cp=dist.group.WORLD), ] pg_collections[1].tp_dp_cp = None From 5406627a2633bd9c4c5bb7f7af8d898e4688151e Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Wed, 13 May 2026 16:18:50 +0000 Subject: [PATCH 4/4] Simplify router expert-bias group selection --- megatron/core/distributed/finalize_model_grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index 53db86b4a5c..7d9179d1c50 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -478,12 +478,12 @@ def finalize_model_grads( assert hasattr(pg_collection, 'tp_dp_cp') and pg_collection.tp_dp_cp is not None, ( "pg_collection must have tp_dp_cp when " "moe_router_enable_expert_bias is enabled." ) + tp_dp_cp_group = pg_collection.tp_dp_cp tp_group = pg_collection.tp pp_group = pg_collection.pp embd_group = pg_collection.embd pos_emb_group = pg_collection.pos_embd dp_cp_group = pg_collection.dp_cp - tp_dp_cp_group = pg_collection.tp_dp_cp if config.moe_router_enable_expert_bias else None else: tp_group = parallel_state.get_tensor_model_parallel_group() pp_group = parallel_state.get_pipeline_model_parallel_group()