From f5ff88b5e5f80adf1b116b3055e53849083a574c Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Tue, 24 Feb 2026 12:08:30 -0800 Subject: [PATCH] war fix: disable registration of the other groups except for the dp_cp group Signed-off-by: Youngeun Kwon --- .../megatron_fsdp/param_and_grad_buffer.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index 3ec117ebd9e..663d9096a14 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -1613,24 +1613,28 @@ def __init__( ) # Select the communicator groups to register FSDP buffers. self.ubr_groups = [self.dist_index.get_fsdp_group(is_expert_parallel=False)] - if self.dist_index.get_fsdp_group(is_expert_parallel=True) is not None: - # Expert-DP group when using EP - self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=True)) - if self.dist_index.get_outer_fsdp_group() is not None: - # Outer/Inter-FSDP group when using hybrid FSDP - self.ubr_groups.append(self.dist_index.get_outer_fsdp_group()) - if ( - self.dist_index.get_fsdp_group( - is_expert_parallel=False, independent_all_gather=True - ) - is not None - ): - # All-gather group used when overlapping all-gather and gradient reduction. - self.ubr_groups.append( + # Currently, symmetric registration is not supported for the other groups. + # For now, buffers are only registered to the other groups when symmetric + # registration is disabled. We will support it in the future. 
+ if self.ddp_config.disable_symmetric_registration: + if self.dist_index.get_fsdp_group(is_expert_parallel=True) is not None: + # Expert-DP group when using EP + self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=True)) + if self.dist_index.get_outer_fsdp_group() is not None: + # Outer/Inter-FSDP group when using hybrid FSDP + self.ubr_groups.append(self.dist_index.get_outer_fsdp_group()) + if ( self.dist_index.get_fsdp_group( is_expert_parallel=False, independent_all_gather=True ) - ) + is not None + ): + # All-gather group used when overlapping all-gather and gradient reduction. + self.ubr_groups.append( + self.dist_index.get_fsdp_group( + is_expert_parallel=False, independent_all_gather=True + ) + ) log_single_rank( logger,