From 3f593b76425eef46cc844f5cdffa5faa5edfa845 Mon Sep 17 00:00:00 2001 From: Ahmad Kiswani Date: Mon, 5 Jan 2026 18:37:22 +0200 Subject: [PATCH 1/2] Fix stale parameter index map in DistributedOptimizer During mixed-precision training, model_param_group_index_map becomes stale after parameters are reordered by dtype. The map is initialized based on original parameter order, but _build_model_and_main_param_groups reorders parameters (FP32 first, then FP16/BF16) for checkpoint consistency. The map was never updated to reflect this reordering. This caused get_parameter_state_dp_zero to access wrong parameters when saving checkpoints, resulting in size mismatch errors. Solution: Rebuild model_param_group_index_map after parameter reordering to keep it synchronized with optimizer.param_groups. Fixes https://github.com/NVIDIA/Megatron-LM/issues/2777 --- megatron/core/optimizer/distrib_optimizer.py | 17 ++++ tests/unit_tests/test_optimizer.py | 99 ++++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 6e093f96f7e..c0cf5e992d6 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -604,6 +604,23 @@ def __init__( self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges] self.optimizer.load_state_dict(self.optimizer.state_dict()) + # Rebuild model_param_group_index_map to reflect parameter reordering. + # The _build_model_and_main_param_groups method reorders parameters by dtype + # (FP32 first, then FP16/BF16), so we need to update the mapping to match + # the new positions in optimizer.param_groups. + for group_index, group_range in enumerate(self.opt_group_ranges): + param_order = 0 + # First, add FP32 params (in the same order as they appear in group_range["params"]) + for model_param in group_range["params"]: + if model_param.type() == 'torch.cuda.FloatTensor': + self.model_param_group_index_map[model_param] = (group_index, param_order) + param_order += 1 + # Then, add FP16/BF16 params (in the same order as they appear in group_range["params"]) + for model_param in group_range["params"]: + if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: + self.model_param_group_index_map[model_param] = (group_index, param_order) + param_order += 1 + def _get_model_param_range_map(self, param: torch.nn.Parameter): """ Given a model param, get the index sub-range of the param that this diff --git a/tests/unit_tests/test_optimizer.py b/tests/unit_tests/test_optimizer.py index 1c64ad86c52..13ca216ff20 100644 --- a/tests/unit_tests/test_optimizer.py +++ b/tests/unit_tests/test_optimizer.py @@ -750,3 +750,102 @@ def test_get_megatron_optimizer_custom_process_groups_validation(): use_gloo_process_groups=True, # Should be False when using custom groups pg_collection=pg_collection_complete, ) + + +@pytest.mark.parametrize("use_distributed_optimizer", [True]) +def test_mixed_precision_param_index_map(use_distributed_optimizer: bool): + """ + Test that model_param_group_index_map stays synchronized after parameter reordering. + + This test addresses issue #2777 where the index map becomes stale after + _build_model_and_main_param_groups reorders parameters by dtype (FP32 first, + then FP16/BF16). The test creates a model with mixed precision parameters + and verifies that checkpoint operations work correctly. + """ + world = int(os.getenv('WORLD_SIZE', '1')) + rank = int(os.getenv('RANK', '0')) + + # Setup distributed environment + _init_distributed(world, rank) + Utils.initialize_model_parallel() + + # Create a model with mixed precision parameters + # We'll manually set some parameters to FP32 and others to BF16 + class MixedPrecisionModel(nn.Module): + def __init__(self): + super().__init__() + # First layer in BF16 + self.fc1 = nn.Linear(100, 50, bias=False, dtype=torch.bfloat16, device='cuda') + # Second layer in FP32 (simulating manual precision promotion) + self.fc2 = nn.Linear(50, 30, bias=False, dtype=torch.float32, device='cuda') + # Third layer in BF16 + self.fc3 = nn.Linear(30, 10, bias=False, dtype=torch.bfloat16, device='cuda') + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x.float()) + x = self.fc3(x.bfloat16()) + return x + + model = MixedPrecisionModel() + model.requires_grad_(True) + + # Wrap with DDP + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) + + # Create optimizer with distributed optimizer enabled + optimizer_config = OptimizerConfig( + optimizer='adam', bf16=True, use_distributed_optimizer=use_distributed_optimizer + ) + optim = get_megatron_optimizer(optimizer_config, [model]) + + # Access the underlying distributed optimizer + if use_distributed_optimizer: + dist_optim = optim.optimizer + + # Verify that model_param_group_index_map is correctly synchronized + # After the fix, the map should reflect the reordered parameters + for model_param in dist_optim.model_param_group_index_map.keys(): + group_index, group_order = dist_optim.model_param_group_index_map[model_param] + + # Verify the index points to a valid parameter + assert group_index < len( + dist_optim.optimizer.param_groups + ), f"group_index {group_index} out of range" + assert group_order < len( + dist_optim.optimizer.param_groups[group_index]["params"] + ), f"group_order {group_order} out of range for group {group_index}" + + # Get the corresponding optimizer parameter + opt_param = dist_optim.optimizer.param_groups[group_index]["params"][group_order] + + # Verify the sizes match (this would fail before the fix) + model_param_range = dist_optim._get_model_param_range_map(model_param) + param_range = model_param_range["param"] + assert param_range.size == opt_param.numel(), ( + f"Size mismatch: model param range size {param_range.size} " + f"!= optimizer param size {opt_param.numel()}" + ) + + # Run a forward/backward pass to populate optimizer state + input_data = torch.randn(8, 100, dtype=torch.bfloat16, device='cuda') + output = model(input_data) + loss = output.sum() + loss.backward() + optim.step() + + # Test get_parameter_state_dp_zero (the function that was failing in issue #2777) + # This should work without size mismatch errors + try: + state_dict = dist_optim.get_parameter_state_dp_zero() + # Verify state_dict was created successfully + if rank == 0 or state_dict is not None: + assert state_dict is not None, "Failed to get parameter state" + assert 'buckets_coalesced' in state_dict, "Missing expected keys in state dict" + except RuntimeError as e: + pytest.fail(f"get_parameter_state_dp_zero failed with error: {e}") + + _deinit_distributed() From 537fa1ae174073e4825cbf7df9dba3098609b695 Mon Sep 17 00:00:00 2001 From: Ahmad Kiswani Date: Mon, 12 Jan 2026 17:03:30 +0200 Subject: [PATCH 2/2] Fix shard_model_param in FP32 Parameter shards should never participate in autograd, inconsistent with FP16/BF16 --- megatron/core/optimizer/distrib_optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index c0cf5e992d6..903ae4a19d8 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -415,7 +415,9 @@ def _build_model_and_main_param_groups( # fp32 params. elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1)[param_range.start : param_range.end] + shard_model_param = model_param.detach().view(-1)[ + param_range.start : param_range.end + ] model_fp32_params_this_group.append(model_param) shard_fp32_params_this_group.append(shard_model_param) tensor_parallel.copy_tensor_model_parallel_attributes(