diff --git a/.github/oncall_schedule.json b/.github/oncall_schedule.json index 2f6e01c786c..55b734d3711 100644 --- a/.github/oncall_schedule.json +++ b/.github/oncall_schedule.json @@ -1,8 +1,4 @@ [ - { - "user": "janEbert", - "date": "2026-05-06" - }, { "user": "dimapihtar", "date": "2026-05-13" @@ -46,5 +42,9 @@ { "user": "Phlip79", "date": "2026-07-22" + }, + { + "user": "YangFei1990", + "date": "2026-07-29" } ] diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index a711f1405d1..2cb454a2027 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -165,6 +165,8 @@ def __init__( param_indices == layout.param_indices ), f"param_indices for {buffer_key} do not match between grouping and layout" + self.full_param_layout = full_param_layout + # Compute gradient scaling factors. if config.calculate_per_token_loss: assert ( diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 38d84ffcce5..eaa8b722237 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -208,6 +208,16 @@ def __init__( not self.ddp_config.reduce_scatter_with_fp32_accumulation ), "RS w/ FP32 accumulation not supported with num_distributed_optimizer_instances > 1" + reduction_collective = ( + "reduce-scatter" if self.ddp_config.use_distributed_optimizer else "all-reduce" + ) + log_single_rank( + logger, + logging.INFO, + f"Using {reduction_collective} for gradient reductions because " + f"{self.ddp_config.use_distributed_optimizer=}", + ) + global dist_reduce_scatter_func if self.ddp_config.reduce_scatter_with_fp32_accumulation: dist_reduce_scatter_func = reduce_scatter_with_fp32_accumulation @@ -322,8 +332,9 @@ def start_param_sync(self, force_sync: bool = False): async_op = self.ddp_config.overlap_param_gather and not 
force_sync if not self.ddp_config.use_distributed_optimizer: - # Layer-wise optimizer path: use all_gather for variable-size - # param gather. + # Legacy layer-wise optimizer path: use all_gather for variable-size + # param gather. Once all layerwise call sites set + # ddp_config.use_distributed_optimizer=True, this branch can be removed. # # Each rank may own a different number of params per bucket, so # layerwise_param_flat_sizes can vary across ranks. PyTorch's NCCL diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index b9f62e59547..d8793f01d67 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -271,6 +271,8 @@ def _calculate_cuda_graph_token_counts( ) # Make sure divisible by TP size cuda_graph_step_size = round_up_to_nearest_multiple(cuda_graph_step_size, tp_size) + # Ensure non-zero step size (can happen when max_tokens < num_cuda_graphs). + cuda_graph_step_size = max(cuda_graph_step_size, tp_size) # round down cuda graph max tokens to be multiple of TP size cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size @@ -367,11 +369,9 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): cuda_graph_max_tokens = max_tokens - assert cuda_graph_max_tokens == max_requests * (num_speculative_tokens + 1), ( - f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests *" - f"(num_speculative_tokens + 1) ({max_requests * (num_speculative_tokens + 1)}). " - "This is required for correctly syncing EP ranks: " - f"prefill and decode graph pools must have the same token count granularity." + assert cuda_graph_max_tokens >= max_requests * (num_speculative_tokens + 1), ( + f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must be >= max_requests * " + f"(num_speculative_tokens + 1) ({max_requests * (num_speculative_tokens + 1)})." 
) if num_cuda_graphs != -1: diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index a39ec038051..e8769f3d6e7 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -188,8 +188,12 @@ class InferenceConfig: # ================================= num_cuda_graphs: Optional[int] = None """ - Maximum number of cuda graphs to capture, where the cuda graph batch sizes range from 1 to - `max_requests`. Due to rounding, the actual number of cuda graphs may not equal this argument. + Maximum number of cuda graphs to capture. + Graph token counts are spaced from 1 up to a per-graph-type budget: + - Decode-only graphs are always bounded by `max_requests * (num_speculative_tokens + 1)`. + - Prefill/mixed graphs share that same bound by default, + or extend up to `max_tokens` when `cuda_graph_all_prefills` is set. + Due to rounding, the actual number of cuda graphs may not equal this argument. """ cuda_graph_mixed_prefill_count: Optional[int] = 16 @@ -202,6 +206,14 @@ class InferenceConfig: Whether to use CUDA graphs for non-decode steps. """ + cuda_graph_all_prefills: bool = False + """ + Whether prefill/mixed CUDA graphs should span up to `max_tokens`. + When False (default), prefill/mixed graphs are bounded by the same token limit as decode graphs: + `max_requests * (num_speculative_tokens + 1)`. + When True, prefill/mixed graph capture is extended to cover the full `max_tokens` budget. 
+ """ + static_kv_memory_pointers: bool = False """ Whether the KV cache (and Mamba states) will reside at the same memory addresses diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index f9695f6c9a5..6c2dcb47340 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -623,12 +623,21 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC and not (force_disable_non_decode_cuda_graphs) ) + # CUDA graph token budget for prefill/mixed graphs. Decode graphs are always + # capped at max_requests * (num_speculative_tokens + 1) inside the helper; this + # only widens the prefill/mixed range when `cuda_graph_all_prefills` is set. + cuda_graph_max_tokens = ( + self.max_tokens + if inference_config.cuda_graph_all_prefills + else self.max_requests * (self.num_speculative_tokens + 1) + ) + # CUDA graph config list. self.cuda_graph_batch_dimensions_list, self.cuda_graph_token_counts = ( CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=tp_size, num_cuda_graphs=inference_config.num_cuda_graphs, - cuda_graph_max_tokens=self.max_requests * (self.num_speculative_tokens + 1), + cuda_graph_max_tokens=cuda_graph_max_tokens, cuda_graph_mixed_prefill_request_count=inference_config.cuda_graph_mixed_prefill_count, max_requests=self.max_requests, max_tokens=self.max_tokens, diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py new file mode 100644 index 00000000000..e55a588ebcb --- /dev/null +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -0,0 +1,430 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Three-phase schedule for colocated MIMO training with LLM PP>1. + +Phase 1: Encoder forward + communicate for the full batch (all ranks synchronized). 
+Phase 2: LLM 1F1B pipeline with detached encoder embeddings sliced per microbatch. +Phase 3: Encoder backward for the full batch (all ranks synchronized). + +Encoder runs on all ranks (PP=1) and its TP/DP collectives require all ranks +to participate simultaneously. The 1F1B pipeline staggers ranks across PP stages, +so encoder collectives cannot run inside the pipeline. The three-phase design +separates encoder (synchronized) from LLM (pipelined) by detaching the autograd +graph at the encoder-LLM boundary. + +Shape contract: encoder input tensors are 3D ``[seq, batch, hidden]`` with +the batch dim at ``dim=1``. Encoder output embeddings are either 3D +``[seq, batch, hidden]`` (batch dim = 1) or 2D ``[seq*batch, hidden]`` +(batch dim = 0); the bridge may collapse the leading two dims. Other +layouts (e.g. ``[B, C, H, W]`` images) are not supported. + +DP-direction contract: fan-in (enc_dp > llm_dp), fan-out (enc_dp < llm_dp), +and equal-DP are all supported. The ColocatedBridgeCommunicator handles +the encoder-side reshape on both forward (fan-in: all-gather, fan-out: +narrow) and backward (fan-in: scatter, fan-out: all-gather). The schedule's +job is to hand each side its correctly-sized slice of the global batch: + + * Fan-in: data iterator yields LLM-DP-sized per-rank batches; the + schedule narrows encoder inputs to the encoder rank's smaller slot + in ``_slice_for_encoder_dp`` before encode_and_communicate. + * Fan-out: data iterator yields encoder-DP-sized per-rank batches; the + bridge narrows encoder embeddings to the LLM-DP rank's slot inside + encode_and_communicate, and ``_build_lm_microbatches`` narrows the + LLM-side passthrough fields (input_ids, labels, loss_mask, + position_ids) to the same slot so they line up with the bridge + output for the LLM forward. 
def colocated_forward_backward_with_pp(
    mimo_model,
    data_iterator,
    num_microbatches: int,
    encoder_grid: Optional["HyperCommGrid"] = None,
    llm_grid: Optional["HyperCommGrid"] = None,
    encoder_name: str = "images",
    forward_only: bool = False,
    **schedule_kwargs,
):
    """Three-phase colocated training: encoder batch -> LLM pipeline -> encoder backward.

    Args:
        mimo_model: MimoModel with colocated communicators and lm_has_pp=True.
        data_iterator: Yields dicts with input_ids, labels, etc. Exactly
            ``num_microbatches`` items are consumed up front.
        num_microbatches: Number of microbatches for the LLM pipeline.
        encoder_grid: Encoder HyperCommGrid (for DP fan-in slicing).
        llm_grid: LLM HyperCommGrid (for PP group).
        encoder_name: Modality name for the encoder (e.g., "images").
        forward_only: Skip backward passes if True. The encoder forward is
            then run with autograd disabled, since Phase 3 never runs.
        **schedule_kwargs: Passed to forward_backward_pipelining_without_interleaving.
            Must include p2p_communicator, pg_collection, seq_length, micro_batch_size.

    Returns:
        The value returned by ``forward_backward_pipelining_without_interleaving``.
    """
    # Explicit None check: do not depend on the grid object's truthiness.
    has_pp_dim = llm_grid is not None and 'pp' in llm_grid.dim_names
    pp_group = llm_grid.get_pg("pp") if has_pp_dim else None
    is_pp_first = pp_group is None or pp_group.rank() == 0

    # ── Phase 1: Encoder forward on full batch (one pass) ────────────────
    # All ranks participate (encoder is PP=1, communicate is collective).
    all_batches = [next(data_iterator) for _ in range(num_microbatches)]
    full_encoder_input = _concat_encoder_inputs(all_batches, encoder_name)
    _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid)

    # Phase 3 is skipped when forward_only, so the encoder graph would never
    # be used; avoid holding full-batch encoder activations in that case.
    with torch.set_grad_enabled(torch.is_grad_enabled() and not forward_only):
        enc_out = mimo_model.encode_and_communicate({encoder_name: full_encoder_input})

    # Detach so Phase 2 runs no encoder collectives; microbatch views accumulate
    # .grad into detached_full.grad automatically.
    detached_full = {
        key: tensor.detach().requires_grad_(not forward_only) for key, tensor in enc_out.items()
    }
    lm_data = _build_lm_microbatches(
        detached_full, all_batches, num_microbatches, encoder_grid, llm_grid
    )

    # ── Phase 2: LLM 1F1B pipeline ──────────────────────────────────────
    # Only LLM P2P communication (within PP group). No encoder collectives.
    cache_iter = iter(lm_data)

    def _lm_forward_step(data_iterator_unused, model, *args):
        cached = next(cache_iter)
        forward_kwargs = dict(
            input_ids=cached['input_ids'],
            labels=cached['labels'],
            loss_mask=cached['loss_mask'],
            position_ids=cached['position_ids'],
            encoder_embeddings=cached['encoder_embeddings'],
        )
        # Optional fields are forwarded only when the batch provided them.
        for optional_key in ('attention_mask', 'packing_kwargs'):
            if cached.get(optional_key) is not None:
                forward_kwargs[optional_key] = cached[optional_key]
        # The model also returns a loss mask; the loss uses the cached one.
        output_tensor, _ = model(**forward_kwargs)
        return output_tensor, partial(_loss_func, cached['loss_mask'])

    # Swap in a capturing finalize so the inner PP schedule does not run DDP
    # grad sync before Phase 3 has produced encoder grads. The capture also
    # records ``num_tokens`` and ``force_all_reduce`` that the inner schedule
    # would have passed; they are forwarded to the original finalize after
    # Phase 3.
    with _deferred_finalize(mimo_model.config) as (original_finalize, capture):
        losses = schedules.forward_backward_pipelining_without_interleaving(
            forward_step_func=_lm_forward_step,
            data_iterator=cache_iter,
            model=[mimo_model],
            num_microbatches=num_microbatches,
            forward_only=forward_only,
            **schedule_kwargs,
        )

    # ── Phase 3: Encoder backward (one pass, all ranks sync) ────────────
    # detached_full.grad was populated by Phase 2's per-microbatch LLM backward
    # (accumulated across microbatch view slices on PP stage 0).
    # Broadcast to PP stage 1+ then run one encoder backward for the full batch.
    if not forward_only and enc_out:
        _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first)
        for key, enc_tensor in enc_out.items():
            grad = detached_full[key].grad
            if grad is not None:
                torch.autograd.backward(enc_tensor, grad_tensors=grad)

    # Single post-Phase-3 finalize: reduces LLM grads (from Phase 2) and
    # encoder grads (from Phase 3) together, forwarding the captured
    # num_tokens / force_all_reduce from the inner schedule's call.
    if not forward_only and original_finalize is not None:
        original_finalize(
            [mimo_model],
            capture.num_tokens,
            pg_collection=schedule_kwargs.get('pg_collection'),
            force_all_reduce=capture.force_all_reduce,
        )

    return losses


# ── Helpers ──────────────────────────────────────────────────────────────


def _fan_out_slot(encoder_grid, llm_grid):
    """Return ``(scale, slot)`` for fan-out LLM-side narrowing.

    For fan-out (``llm_dp > enc_dp``) ``scale`` is ``llm_dp // enc_dp`` and
    ``slot`` is this rank's LLM-DP rank modulo ``scale``. Equal-DP, fan-in,
    and a missing grid all return ``(1, 0)`` (no narrowing).

    Raises:
        ValueError: if ``llm_dp`` is larger than but not a multiple of
            ``enc_dp``; flooring the ratio would assign wrong slots.
    """
    if encoder_grid is None or llm_grid is None:
        return 1, 0
    llm_dp_group = llm_grid.get_pg("dp")
    enc_dp = encoder_grid.get_pg("dp").size()
    llm_dp = llm_dp_group.size()
    if llm_dp <= enc_dp:
        return 1, 0
    if llm_dp % enc_dp != 0:
        raise ValueError(
            f"Fan-out requires llm_dp ({llm_dp}) to be a multiple of enc_dp ({enc_dp})."
        )
    scale = llm_dp // enc_dp
    return scale, llm_dp_group.rank() % scale
def _modality_present(batch, encoder_name):
    """Return True iff this batch carries inputs for ``encoder_name``."""
    mod_in = batch.get('modality_inputs')
    return bool(mod_in) and encoder_name in mod_in and mod_in[encoder_name] is not None


def _concat_encoder_inputs(all_batches, encoder_name):
    """Concatenate encoder inputs from all microbatches along batch dim (dim=1).

    All encoder input tensors must be 3D ``[seq, batch, hidden]``. All
    microbatches must uniformly have or lack ``modality_inputs[encoder_name]``;
    mixed batches are rejected because Phase 2 reuses one detached encoder
    output across every LLM microbatch. The set of sub-encoders and keys is
    taken from the first microbatch.

    Returns:
        ``{enc_name: {key: value}}`` where tensor values are concatenated
        along dim 1 and non-tensor values are taken from the first
        microbatch; ``{}`` when no microbatch carries the modality.

    Raises:
        ValueError: if ``all_batches`` is empty, if modality presence differs
            across microbatches, if a key is a tensor in some microbatches
            but not others, or if a tensor is not 3D.
    """
    if not all_batches:
        raise ValueError(
            "colocated_forward_backward_with_pp requires at least one microbatch."
        )
    first = all_batches[0]
    has_first = _modality_present(first, encoder_name)
    for idx, batch in enumerate(all_batches):
        if _modality_present(batch, encoder_name) != has_first:
            raise ValueError(
                f"colocated_forward_backward_with_pp requires uniform "
                f"modality_inputs across microbatches for '{encoder_name}'; "
                f"microbatch 0 has it = {has_first} but microbatch {idx} differs."
            )
    if not has_first:
        return {}

    result = {}
    for enc_name, first_fields in first['modality_inputs'][encoder_name].items():
        merged = {}
        for key in first_fields:
            vals = [b['modality_inputs'][encoder_name][enc_name][key] for b in all_batches]
            tensors = [v for v in vals if isinstance(v, torch.Tensor)]
            if not tensors:
                # Non-tensor metadata: take the first microbatch's value.
                merged[key] = vals[0]
                continue
            if len(tensors) != len(vals):
                # Dropping the non-tensor entries would silently shrink the
                # concatenated batch and misalign it with the LLM microbatches.
                raise ValueError(
                    f"encoder input '{enc_name}.{key}' is a tensor in "
                    f"{len(tensors)} of {len(vals)} microbatches; it must be a "
                    f"tensor in all of them or in none."
                )
            for tensor in tensors:
                if tensor.ndim != 3:
                    raise ValueError(
                        f"encoder input '{enc_name}.{key}' must be 3D "
                        f"[seq, batch, hidden], got shape={tuple(tensor.shape)}"
                    )
            merged[key] = torch.cat(tensors, dim=1)
        result[enc_name] = merged
    return result


def _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid):
    """Slice concatenated encoder input for fan-in (enc_dp > llm_dp).

    Mutates ``full_encoder_input`` in place: every 3D tensor value is
    replaced by a contiguous slice along dim 1 of size ``batch // scale``
    at this encoder-DP rank's slot (``enc_dp_rank % scale``), where
    ``scale = enc_dp // llm_dp``. Non-tensor values are left untouched.
    Equal-DP, fan-out, and a missing grid are no-ops.

    Raises:
        ValueError: if ``enc_dp`` is not a multiple of ``llm_dp``, if a
            tensor is not 3D, if its batch is not divisible by ``scale``,
            or if the resulting slice would be empty.
    """
    if encoder_grid is None or llm_grid is None:
        return
    enc_dp_group = encoder_grid.get_pg("dp")
    enc_dp = enc_dp_group.size()
    llm_dp = llm_grid.get_pg("dp").size()
    if enc_dp <= llm_dp:
        return
    if enc_dp % llm_dp != 0:
        raise ValueError(
            f"Encoder fan-in requires enc_dp ({enc_dp}) to be a multiple of llm_dp ({llm_dp})."
        )
    scale = enc_dp // llm_dp
    slot = enc_dp_group.rank() % scale
    for enc_name, fields in full_encoder_input.items():
        for key, tensor in fields.items():
            if not isinstance(tensor, torch.Tensor):
                continue
            if tensor.ndim != 3:
                raise ValueError(
                    f"encoder input '{enc_name}.{key}' must be 3D "
                    f"[seq, batch, hidden], got shape={tuple(tensor.shape)}"
                )
            bs = tensor.shape[1]
            if bs % scale != 0:
                raise ValueError(
                    f"Encoder fan-in: tensor batch={bs} not divisible by scale={scale}."
                )
            ss = bs // scale
            if ss == 0:
                raise ValueError(
                    f"Encoder fan-in produces zero-sized batch: "
                    f"total_batch={bs}, scale={scale}. Increase micro_batch_size."
                )
            # Re-assigning an existing key while iterating is safe (no resize).
            fields[key] = tensor[:, slot * ss : (slot + 1) * ss, :].contiguous()
def _build_lm_microbatches(
    detached_full, all_batches, num_microbatches, encoder_grid=None, llm_grid=None
):
    """Slice detached encoder output into per-microbatch views for the LLM pipeline.

    Each encoder embedding is either 3D ``[seq, batch, hidden]`` (batch dim = 1)
    or 2D ``[seq*batch, hidden]`` (batch dim = 0); other layouts are rejected.
    Every entry of ``detached_full`` is validated and sliced using its OWN
    batch dim and size, so outputs with different layouts can coexist.
    Pass-through fields (input_ids, labels, loss_mask, position_ids,
    attention_mask, packing_kwargs) are taken per microbatch from the
    corresponding ``all_batches`` entry.

    For fan-out (``llm_dp > enc_dp``) the tensor pass-through fields are
    narrowed along dim 0 to this LLM-DP rank's slot (see ``_fan_out_slot``).
    Fan-in, equal-DP, and missing grids leave them unchanged.

    Returns:
        A list of ``num_microbatches`` dicts with key ``encoder_embeddings``
        plus the pass-through keys.

    Raises:
        ValueError: if fewer than ``num_microbatches`` batches are supplied,
            if an encoder output is not 2D/3D, if a batch dim is not
            divisible by ``num_microbatches``, or if a fan-out field's batch
            is not divisible by the fan-out scale.
    """
    # Without both grids there is no fan-out; skip the helper call entirely.
    if encoder_grid is None or llm_grid is None:
        fan_out_scale, fan_out_slot = 1, 0
    else:
        fan_out_scale, fan_out_slot = _fan_out_slot(encoder_grid, llm_grid)

    if len(all_batches) < num_microbatches:
        raise ValueError(
            f"Expected {num_microbatches} microbatches, got {len(all_batches)}."
        )

    def _maybe_narrow(tensor):
        """Narrow a batch-dim-0 tensor to this LLM-DP rank's fan-out slot."""
        if fan_out_scale == 1 or not isinstance(tensor, torch.Tensor):
            return tensor
        bs = tensor.shape[0]
        if bs % fan_out_scale != 0:
            raise ValueError(
                f"Fan-out narrowing: tensor batch={bs} not divisible by "
                f"scale={fan_out_scale}."
            )
        ss = bs // fan_out_scale
        return tensor[fan_out_slot * ss : (fan_out_slot + 1) * ss].contiguous()

    def _maybe_narrow_attn(tensor, ref_batch):
        """Narrow ``attention_mask`` only when its dim-0 matches the input batch.

        Masks whose dim-0 differs from the pre-narrowing ``input_ids`` batch
        (e.g. a mask that broadcasts across batch) are returned unchanged.
        """
        if not isinstance(tensor, torch.Tensor) or not isinstance(ref_batch, torch.Tensor):
            return tensor
        if tensor.ndim < 1 or tensor.shape[0] != ref_batch.shape[0]:
            return tensor
        return _maybe_narrow(tensor)

    def _passthrough(batch_idx):
        batch = all_batches[batch_idx]
        input_ids = batch.get('input_ids')
        return {
            'input_ids': _maybe_narrow(input_ids),
            'labels': _maybe_narrow(batch.get('labels')),
            'loss_mask': _maybe_narrow(batch.get('loss_mask')),
            'position_ids': _maybe_narrow(batch.get('position_ids')),
            'attention_mask': _maybe_narrow_attn(batch.get('attention_mask'), input_ids),
            'packing_kwargs': batch.get('packing_kwargs'),
        }

    if not detached_full:
        # Text-only batch: no encoder embeddings to slice
        return [{'encoder_embeddings': {}, **_passthrough(mb)} for mb in range(num_microbatches)]

    # Validate every encoder output and record its per-microbatch size.
    mb_sizes = {}
    for key, tensor in detached_full.items():
        if tensor.ndim not in (2, 3):
            raise ValueError(
                f"encoder output must be 2D [seq*batch, hidden] or 3D "
                f"[seq, batch, hidden], got shape={tuple(tensor.shape)}"
            )
        total_batch = tensor.shape[1 if tensor.ndim == 3 else 0]
        if total_batch % num_microbatches != 0:
            raise ValueError(
                f"Encoder output batch dim ({total_batch}) must be divisible "
                f"by num_microbatches ({num_microbatches})"
            )
        mb_sizes[key] = total_batch // num_microbatches

    lm_data = []
    for mb in range(num_microbatches):
        mb_enc = {}
        for key, tensor in detached_full.items():
            start, stop = mb * mb_sizes[key], (mb + 1) * mb_sizes[key]
            # Basic slicing keeps these as views of the detached leaf, so
            # backward through them accumulates into the leaf's .grad.
            mb_enc[key] = tensor[:, start:stop, :] if tensor.ndim == 3 else tensor[start:stop, :]
        lm_data.append({'encoder_embeddings': mb_enc, **_passthrough(mb)})
    return lm_data


def _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first):
    """Broadcast encoder gradient from PP stage 0 to stage 1+ ranks.

    No-op when ``pp_group`` is None or has a single rank. On the first
    stage the existing ``.grad`` of each ``detached_full[key]`` is the
    broadcast source; on other stages a fresh buffer is received and
    installed as ``.grad``.

    Raises:
        RuntimeError: on PP stage 0 if any key has no gradient. All keys
            are checked before the first collective is issued.
    """
    if pp_group is None or pp_group.size() <= 1:
        return
    if is_pp_first:
        for key in enc_out:
            if detached_full[key].grad is None:
                raise RuntimeError(
                    f"No encoder gradient on PP stage 0 for '{key}'; "
                    f"Phase 2 LLM backward did not populate detached_full.grad."
                )
    src = dist.get_global_rank(pp_group, 0)
    for key in enc_out:
        if is_pp_first:
            dist.broadcast(detached_full[key].grad, src=src, group=pp_group)
        else:
            grad = torch.empty_like(detached_full[key])
            dist.broadcast(grad, src=src, group=pp_group)
            detached_full[key].grad = grad
def _loss_func(loss_mask, output_tensor):
    """Default loss function for the LLM pipeline.

    Returns the 3-tuple ``(local_sum, local_num_tokens, log_dict)``:
    the masked sum of ``output_tensor``, the integer count of mask
    entries, and ``{'loss_reduced': <python float>}``. The original author
    notes this matches the ``calculate_per_token_loss=True`` contract and is
    also safe for per-microbatch-mean configs (schedule behaviour not
    visible here).

    When ``output_tensor`` is None a zero loss / zero count is returned on
    ``loss_mask``'s device (falling back to ``'cuda'`` when ``loss_mask`` is
    not a tensor), so the placeholder lives where the real loss would.
    """
    if output_tensor is None:
        device = loss_mask.device if isinstance(loss_mask, torch.Tensor) else 'cuda'
        zero_loss = torch.tensor(0.0, device=device, requires_grad=True)
        zero_count = torch.tensor(0, device=device, dtype=torch.int)
        return zero_loss, zero_count, {'loss_reduced': 0.0}
    mask = loss_mask.float()
    local_sum = (output_tensor.float() * mask).sum()
    local_num_tokens = mask.sum().to(torch.int)
    # NOTE: .item() copies the scalar to host for logging.
    return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()}


class _CapturingFinalize:
    """Capture finalize args the inner PP schedule would have passed.

    Installed in place of the config's ``finalize_model_grads_func`` so the
    inner schedule's finalize call is absorbed instead of syncing grads.
    The last call's ``num_tokens`` and ``force_all_reduce`` keyword are
    stored so the caller can forward them to the original finalize after
    Phase 3.
    """

    def __init__(self):
        self.num_tokens = None
        self.force_all_reduce = False

    def __call__(self, model_list, num_tokens=None, *args, **kwargs):
        """Record ``num_tokens`` / ``force_all_reduce``; perform no reduction.

        ``num_tokens`` defaults to None so a finalize call that omits it is
        still absorbed instead of raising TypeError. Extra positional and
        keyword arguments are accepted and ignored.
        """
        self.num_tokens = num_tokens
        self.force_all_reduce = kwargs.get('force_all_reduce', False)
        return None


@contextmanager
def _deferred_finalize(config):
    """Temporarily replace ``config.finalize_model_grads_func`` with a capture.

    Yields ``(original, capture)``: the finalize callable that was installed
    on entry (may be None) and the ``_CapturingFinalize`` standing in for
    it. The original is restored on exit, including when the body raises.
    """
    original = config.finalize_model_grads_func
    capture = _CapturingFinalize()
    config.finalize_model_grads_func = capture
    try:
        yield original, capture
    finally:
        config.finalize_model_grads_func = original
def _build_rank_mappings(self):
    """Populate the rank -> grid-position lookup tables for both grids.

    Source ranks map to ``(dp_idx, tp_idx)`` by enumerating the source
    grid's ``['tp']`` groups. Destination ranks are enumerated over
    ``['tp', 'pp']`` when the destination grid has a ``pp`` dim (otherwise
    ``['tp']``), so the group index is used as the DP coordinate; the
    position inside a group is split into ``pp_idx = idx // dest_tp_size``
    and ``tp_idx = idx % dest_tp_size``.
    NOTE(review): this assumes ``get_rank_enum(['tp', 'pp'])`` orders ranks
    with tp varying fastest — confirm against HyperCommGrid.
    """
    self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {}
    self.rank_to_dest_pos: Dict[int, Tuple[int, int]] = {}
    self.rank_to_dest_pp_pos: Dict[int, Tuple[int, int, int]] = {}
    self.dest_pp_pos_to_rank: Dict[Tuple[int, int, int], int] = {}

    for src_dp, src_tp_group in enumerate(self.src_grid.get_rank_enum(['tp'])):
        for src_tp, rank in enumerate(src_tp_group):
            self.rank_to_src_pos[rank] = (src_dp, src_tp)

    # Include destination PP when enumerating destination ranks so DP
    # indices stay true DP coordinates instead of flattened (dp, pp)
    # positions. Fan-out gather groups then stay within one PP stage.
    if 'pp' in self.dest_grid.dim_names:
        dest_group_dims = ['tp', 'pp']
    else:
        dest_group_dims = ['tp']
    multi_stage = self.dest_pp_size > 1
    for dest_dp, members in enumerate(self.dest_grid.get_rank_enum(dest_group_dims)):
        for offset, rank in enumerate(members):
            stage, lane = divmod(offset, self.dest_tp_size)
            if not multi_stage:
                stage = 0
            self.rank_to_dest_pos[rank] = (dest_dp, lane)
            self.rank_to_dest_pp_pos[rank] = (dest_dp, stage, lane)
            self.dest_pp_pos_to_rank[(dest_dp, stage, lane)] = rank


def _build_fan_out_gather_groups(self) -> List[List[int]]:
    """Build dest-side fan-out gather groups, preserving destination PP stage.

    For each source DP index, each destination PP stage, and each
    destination TP index (in that nesting order), one group is emitted
    containing the ranks of the ``scale`` sibling destination DP indices
    ``[src_dp * scale, (src_dp + 1) * scale)`` at that (pp, tp) position.
    """
    return [
        [
            self.dest_pp_pos_to_rank[(dest_dp, stage, lane)]
            for dest_dp in range(src_dp * self.scale, (src_dp + 1) * self.scale)
        ]
        for src_dp in range(self.src_dp_size)
        for stage in range(self.dest_pp_size)
        for lane in range(self.dest_tp_size)
    ]
@classmethod
def _colocated(
    cls,
    modality_module_names: List[str],
    module_to_grid_map: Optional[Dict[str, 'HyperCommGrid']] = None,
) -> 'RankRole':
    """Colocated layout: every module on every rank.

    Builds one ``ModuleStageInfo`` per modality module plus the language
    module. When a module has a grid in ``module_to_grid_map`` and that grid
    has a ``pp`` dim, first/last-stage flags come from this rank's position
    in the grid's pp group. Otherwise (no map, no entry, or no ``pp`` dim)
    the module is marked as both first and last stage.
    """
    grids = module_to_grid_map if module_to_grid_map else {}

    def _stage_info(module_name):
        grid = grids.get(module_name)
        if grid is None or 'pp' not in grid.dim_names:
            return ModuleStageInfo(is_first_stage=True, is_last_stage=True)
        pp_group = grid.get_pg('pp')
        pp_rank = pp_group.rank()
        pp_size = pp_group.size()
        return ModuleStageInfo(
            is_first_stage=(pp_rank == 0), is_last_stage=(pp_rank == pp_size - 1)
        )

    module_names = [*modality_module_names, MIMO_LANGUAGE_MODULE_KEY]
    return cls(
        modules={module_name: _stage_info(module_name) for module_name in module_names},
        mode=ModuleLayout.COLOCATED,
    )
@@ -362,6 +368,20 @@ def forward( input_tensors = getattr(self, 'input_tensors', None) if self.role.mode == ModuleLayout.COLOCATED: + if self.lm_has_pp and input_tensors is not None: + # PP>1 non-first stage: hidden states from P2P + lm_result = self._forward_language_module( + input_ids, + position_ids, + attention_mask, + labels, + {MIMO_LANGUAGE_MODULE_KEY: input_tensors}, + ) + # Unwrap dict for P2P (schedule uses plain tensors, not dicts) + if isinstance(lm_result, dict): + lm_result = lm_result[MIMO_LANGUAGE_MODULE_KEY] + return lm_result, loss_mask + return self._forward_all_modules( input_ids, position_ids, @@ -370,6 +390,7 @@ def forward( labels, modality_inputs, packing_kwargs, + encoder_embeddings=encoder_embeddings, ) if self.role.mode == ModuleLayout.NON_COLOCATED: @@ -519,7 +540,12 @@ def _build_colocated_communicators(self): ) def destroy(self) -> None: - """Release process groups owned by this MimoModel.""" + """Release process groups owned by this MimoModel. + + NCCL caps concurrent communicators, so long-lived or + repeatedly-rebuilt models leak subgroups without explicit + destroy. Tests should call this before ``destroy_all_grids()``. 
+ """ for comm in self.colocated_comms.values(): comm.destroy() self.colocated_comms.clear() @@ -535,6 +561,22 @@ def _apply_colocated_comms(self, modality_embeddings): ) return modality_embeddings + def encode_and_communicate(self, modality_inputs): + """Run encoder forward + colocated TP/DP transform (collective).""" + modality_embeddings = {} + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) + return modality_embeddings + def _forward_all_modules( self, input_ids: torch.Tensor, @@ -544,6 +586,7 @@ def _forward_all_modules( labels: Optional[torch.Tensor], modality_inputs: Optional[Dict[str, Dict[str, Any]]], packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass when all modules are on all ranks (no multi-module PP). @@ -560,26 +603,12 @@ def _forward_all_modules( packed_seq_params.qkv_format = 'thd' logger.debug(f"Packed sequence parameters: {packed_seq_params}") - # 1. 
Process each modality to get embeddings - modality_embeddings = {} - - for modality_name, submodule in self.modality_submodules.items(): - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - modality_embeddings[modality_name] = embeddings - logger.debug( - f"Generated embeddings for {modality_name} with shape {embeddings.shape}" - ) - - # Apply colocated communication if configured (no-op when colocated_comms is empty) - if self.colocated_comms: - modality_embeddings = self._apply_colocated_comms(modality_embeddings) + if encoder_embeddings is not None: + # PP>1 path: encoder forward + communicate already ran in Phase 1; + # reuse the precomputed embeddings for every LLM microbatch. + modality_embeddings = encoder_embeddings + else: + modality_embeddings = self.encode_and_communicate(modality_inputs) # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index c6d3e41aed5..ebdd42effe2 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -850,7 +850,7 @@ def _get_megatron_emerging_optimizer( config, pg_collection, init_state_fn_list=list(init_fns), - model_chunks=model_chunks if config.overlap_param_gather else None, + model_chunks=model_chunks, ) return ChainedOptimizer(results) diff --git a/megatron/core/optimizer/layer_wise_optimizer.py b/megatron/core/optimizer/layer_wise_optimizer.py index d0f64010bad..f60934cee26 100644 --- a/megatron/core/optimizer/layer_wise_optimizer.py +++ b/megatron/core/optimizer/layer_wise_optimizer.py @@ -1,15 +1,18 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import logging -from typing import Callable, List, Optional +import math +from collections import defaultdict +from typing import Callable, Dict, List, Optional, Tuple import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from megatron.core.dist_checkpointing.dict_utils import nested_values from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedStateDict +from megatron.core.distributed.param_and_grad_buffer import group_params_for_buffers from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.utils import get_pg_rank, get_pg_size +from megatron.core.utils import get_pg_rank, get_pg_size, log_single_rank from .clip_grads import count_zeros_fp32, get_grad_norm_fp32 from .optimizer import ( @@ -19,6 +22,13 @@ MegatronOptimizer, ) from .optimizer_config import OptimizerConfig +from .param_layout import ( + FullParamLayout, + PerBufferParamLayout, + bucket_end_divisor, + pad_param_start, + pad_to_divisor, +) logger = logging.getLogger(__name__) @@ -39,6 +49,305 @@ class LayerWiseDistributedOptimizer(ChainedOptimizer): 6. allgather updated params to every rank """ + @staticmethod + def _shard_divisor(data_parallel_world_size: int, ddp_config) -> int: + """Per-shard alignment divisor. + + Guarantees that ``dp_size * shard_size`` satisfies bucket-end alignment + and that every shard start is 64-element aligned (required by + :func:`pad_param_start`). + """ + dp_size = data_parallel_world_size + bucket_divisor = bucket_end_divisor(dp_size, ddp_config.pad_buckets_for_high_nccl_busbw) + return math.lcm(64, bucket_divisor // dp_size) + + @staticmethod + def _compute_per_buffer_param_layout( + params: List[torch.nn.Parameter], + bucket_size: Optional[int], + data_parallel_world_size: int, + ddp_config, + param_indices: Optional[List[int]] = None, + ) -> 'PerBufferParamLayout': + """Compute parameter layout with shard-aligned buckets via size-matching. 
+ + Assigns parameters to ``dp_size`` equal-sized shards within each bucket + so that no parameter is ever split across a shard boundary. + + **Algorithm** (operates in reverse model / backprop order): + + 1. Separate shared-embedding parameters (isolated buckets, emitted first). + 2. Pool the remaining parameters in backprop order, indexed by numel. + 3. Pop the next unassigned parameter and assign it to shard 0. + 4. For shards 1 … ``dp_size - 1``, assign the next unassigned parameter + of the same numel (also in backprop order). If none is available, + insert padding of that numel. Every shard grows by the same amount, + so all shards stay the same size. + 5. When the bucket total reaches *bucket_size*, finalise the bucket + (pad shard size to :meth:`_shard_divisor`) and start a new one. + 6. Repeat from 3 until all parameters are assigned. + + Because repeated layers produce many parameters of the same shape, + size-matching naturally keeps whole parameters together without any + name-parsing heuristic. Padding overhead is low (depending on number + of layers and number of shards) — zero when every shape group has a + count divisible by ``dp_size``. + + Args: + params: Parameters in model-definition (forward) order. + bucket_size: Approximate elements per bucket (``None`` → single bucket). + data_parallel_world_size: Size of the data-parallel group. + ddp_config: :class:`DistributedDataParallelConfig`. + param_indices: Optional per-param dtype indices (passed through). + + Returns: + :class:`PerBufferParamLayout` with shard-aligned buckets. + """ + dp_size = data_parallel_world_size + shard_divisor = LayerWiseDistributedOptimizer._shard_divisor(dp_size, ddp_config) + + # -- 0. Separate shared-embedding params. 
------------------------- + shared_embedding_params: List[torch.nn.Parameter] = [] + regular_params: List[torch.nn.Parameter] = [] + total_param_numel = 0 + for param in params: + total_param_numel += param.data.nelement() + if getattr(param, 'shared_embedding', False): + shared_embedding_params.append(param) + else: + regular_params.append(param) + + # -- 1. Build backprop-order pool & per-size index. --------------- + pool = list(reversed(regular_params)) + assigned_param_ids: set[int] = set() # id(param) of assigned params + + size_groups: Dict[int, List[torch.nn.Parameter]] = defaultdict(list) + for param in pool: + size_groups[param.data.nelement()].append(param) + size_cursors: Dict[int, int] = defaultdict(int) + + overall_cursor = 0 + + def _next_unassigned() -> Optional[torch.nn.Parameter]: + nonlocal overall_cursor + while overall_cursor < len(pool): + if id(pool[overall_cursor]) not in assigned_param_ids: + return pool[overall_cursor] + overall_cursor += 1 + return None + + def _next_with_size(param_numel: int) -> Optional[torch.nn.Parameter]: + """Next unassigned param of size *param_numel* in backprop order.""" + group = size_groups[param_numel] + cursor = size_cursors[param_numel] + while cursor < len(group): + if id(group[cursor]) not in assigned_param_ids: + size_cursors[param_numel] = cursor + return group[cursor] + cursor += 1 + size_cursors[param_numel] = cursor + return None + + # -- 2. Output accumulators and per-bucket shard state. ---------- + param_index_map: Dict[torch.nn.Parameter, Tuple[int, int, int]] = {} + bucket_indices: List[Tuple[int, int]] = [] + per_bucket_numel_unpadded: List[int] = [] + buffer_cursor = 0 # write position in the contiguous buffer + bucket_id = 0 + + # Per-shard state for the bucket currently being built. + # `shard_assignments[i]` holds an ordered list of (param | None, numel) + # entries to be written into shard i; a `None` entry is empty padding + # that keeps every shard the same size. 
+ shard_assignments: List[List[Tuple[Optional[torch.nn.Parameter], int]]] = [ + [] for _ in range(dp_size) + ] + shard_cursor = 0 # position within each shard (identical for all shards) + bucket_numel_unpadded = 0 + size_match_padding_numel = 0 # elements used for empty-shard-slot padding + + def _finalize_bucket() -> None: + nonlocal buffer_cursor, bucket_id, shard_assignments + nonlocal shard_cursor, bucket_numel_unpadded + if shard_cursor == 0: + return + padded_shard_size = pad_to_divisor(shard_cursor, shard_divisor) + bucket_start_index = buffer_cursor + + for shard_id in range(dp_size): + shard_start_index = bucket_start_index + shard_id * padded_shard_size + cursor = shard_start_index + for param, numel in shard_assignments[shard_id]: + cursor = pad_param_start(cursor) + if param is not None: + param_index_map[param] = (cursor, cursor + numel, bucket_id) + cursor += numel + + bucket_end_index = bucket_start_index + dp_size * padded_shard_size + bucket_indices.append((bucket_start_index, bucket_end_index)) + per_bucket_numel_unpadded.append(bucket_numel_unpadded) + buffer_cursor = bucket_end_index + bucket_id += 1 + + shard_assignments = [[] for _ in range(dp_size)] + shard_cursor = 0 + bucket_numel_unpadded = 0 + + # -- 3. Emit one isolated bucket per shared-embedding param. ----- + # Shared (tied) embeddings need their own bucket — typically because + # input and output embeddings are tied across pipeline-parallel + # stages and need a cross-stage all-reduce. Each shared embedding + # occupies shard 0 of its bucket alone; shards 1..dp_size-1 are + # filled with empty (padding) slots of the same numel so the bucket + # is shard-aligned and the embedding fits entirely within shard 0. + # + # NOTE: This is expensive. Padding cost per shared embedding is + # (dp_size - 1) * pad_to_divisor(numel, shard_divisor) elements, + # which for a vocab x hidden embedding (e.g. 
128k x 8192) at dp_size + # = 8 is roughly 7 * (vocab * hidden) elements — many GBs of the + # param buffer (and again of the grad buffer) per shared embedding. + # The cost is unavoidable while preserving the "no parameter crosses + # a shard boundary" invariant the layerwise scheme depends on for + # correct reduce-scatter + local optimizer step. + for param in reversed(shared_embedding_params): + param_numel = param.data.nelement() + assigned_param_ids.add(id(param)) + shard_assignments[0].append((param, param_numel)) + bucket_numel_unpadded += param_numel + # No size-matching: each shared embedding must be alone in its + # bucket. Pad shards 1..dp_size-1 with same-size empty slots. + for shard_id in range(1, dp_size): + shard_assignments[shard_id].append((None, param_numel)) + size_match_padding_numel += param_numel + shard_cursor = pad_param_start(shard_cursor) + param_numel + _finalize_bucket() + + # -- 4. Size-matching loop for regular params. -------------------- + while True: + param = _next_unassigned() + if param is None: + break + + param_numel = param.data.nelement() + assigned_param_ids.add(id(param)) + shard_assignments[0].append((param, param_numel)) + bucket_numel_unpadded += param_numel + + for shard_id in range(1, dp_size): + # Prefer an exact-numel peer; this gives the cleanest layout + # (no inner-shard padding). + matched_param = _next_with_size(param_numel) + if matched_param is not None: + assigned_param_ids.add(id(matched_param)) + shard_assignments[shard_id].append((matched_param, param_numel)) + bucket_numel_unpadded += param_numel + continue + + # No exact peer. Greedily pack as many smaller params from the + # queue as fit within this shard slot (sized to ``param_numel``). + # Cuts overhead from unique-large seeds (e.g. an embedding) + # that would otherwise force ``(dp_size - 1) * param_numel`` of + # empty padding. 
+ useful_in_slot = 0 + slot_cursor = 0 + while True: + candidate_param = _next_unassigned() + if candidate_param is None: + break + candidate_numel = candidate_param.data.nelement() + candidate_start = pad_param_start(slot_cursor) + if candidate_start + candidate_numel > param_numel: + break + assigned_param_ids.add(id(candidate_param)) + shard_assignments[shard_id].append((candidate_param, candidate_numel)) + bucket_numel_unpadded += candidate_numel + slot_cursor = candidate_start + candidate_numel + useful_in_slot += candidate_numel + + # Pad the remainder of the slot up to ``param_numel``. + padding_start = pad_param_start(slot_cursor) + padding_size = param_numel - padding_start + if padding_size > 0: + shard_assignments[shard_id].append((None, padding_size)) + size_match_padding_numel += param_numel - useful_in_slot + + shard_cursor = pad_param_start(shard_cursor) + param_numel + + if bucket_size is not None: + bucket_total = dp_size * pad_to_divisor(shard_cursor, shard_divisor) + if bucket_total >= bucket_size: + _finalize_bucket() + + _finalize_bucket() + + # -- 5. Log padding overhead. 
------------------------------------ + total_buffer_numel = bucket_indices[-1][1] if bucket_indices else 0 + total_padding = total_buffer_numel - total_param_numel + alignment_and_shard_end_padding = total_padding - size_match_padding_numel + log_single_rank( + logger, + logging.INFO, + f"Layerwise param layout: {len(params)} params, " + f"{len(bucket_indices)} buckets, " + f"dp_size={dp_size}, " + f"total_param_numel={total_param_numel}, " + f"total_buffer_numel={total_buffer_numel}, " + f"total_padding={total_padding} " + f"(size_match={size_match_padding_numel}, " + f"alignment+shard_end={alignment_and_shard_end_padding}), " + f"overhead={total_padding / max(total_param_numel, 1) * 100:.1f}%", + ) + + return PerBufferParamLayout( + param_index_map=param_index_map, + bucket_indices=bucket_indices, + per_bucket_numel_unpadded=per_bucket_numel_unpadded, + param_indices=param_indices if param_indices is not None else [], + ) + + @staticmethod + def compute_full_param_layout( + params: List[torch.nn.Parameter], + bucket_size: Optional[int], + data_parallel_world_size: int, + ddp_config, + expert_data_parallel_world_size: Optional[int] = None, + ) -> 'FullParamLayout': + """Compute parameter layouts for all buffer groups. + + Groups parameters by :class:`BufferKey` via :func:`group_params_for_buffers` + and produces a layerwise shard-aligned size-matching layout per buffer. + Every parameter stays within a single shard so the local optimizer step + (e.g. Newton-Schulz iteration for Muon) can run on whole tensors. + + Args: + params: All parameters to lay out. + bucket_size: Approximate elements per bucket (``None`` → single bucket). + data_parallel_world_size: DP group size for dense parameters. + ddp_config: :class:`DistributedDataParallelConfig`. + expert_data_parallel_world_size: Expert DP group size (defaults to + ``data_parallel_world_size``). + + Returns: + :class:`FullParamLayout` with a :class:`PerBufferParamLayout` per buffer group. 
+ """ + buffer_groups = group_params_for_buffers(params, ddp_config.grad_reduce_in_fp32) + layouts = {} + for buffer_key, (group_params, param_indices) in buffer_groups.items(): + if buffer_key.is_expert_parallel: + dp_world_size = ( + expert_data_parallel_world_size + if expert_data_parallel_world_size is not None + else data_parallel_world_size + ) + else: + dp_world_size = data_parallel_world_size + + layouts[buffer_key] = LayerWiseDistributedOptimizer._compute_per_buffer_param_layout( + group_params, bucket_size, dp_world_size, ddp_config, param_indices + ) + return FullParamLayout(layouts=layouts) + def __init__( self, optimizers: List[MegatronOptimizer], @@ -55,15 +364,34 @@ def __init__( config: OptimizerConfig. pg_collection: ProcessGroupCollection. init_state_fn_list: List of init state functions. - model_chunks: DDP-wrapped model chunks (needed for overlap_param_gather). + model_chunks: DDP-wrapped model chunks. """ self.pg_collection = pg_collection - self.shard_params(optimizers) + + full_param_layouts = None + if model_chunks is not None: + full_param_layouts = [ + chunk.full_param_layout + for chunk in model_chunks + if hasattr(chunk, 'full_param_layout') and chunk.full_param_layout is not None + ] or None + self.shard_params(optimizers, full_param_layouts) + + # When a full_param_layout is available, ddp_config.use_distributed_optimizer + # is True and model params are views into the DDP param buffer. After the + # optimizer step copies updated fp32 main params → bf16 model params, the + # buffer is already up-to-date in-place. We can use DDP's buffer-based + # all-gather (start_param_sync) instead of the flatten/unflatten allgather_params + # path. + self.use_buffer_param_sync = full_param_layouts is not None # Set up overlap param gather using DDP bucket infrastructure. 
self.overlap_param_gather = config.overlap_param_gather - if self.overlap_param_gather: + if self.overlap_param_gather and not self.use_buffer_param_sync: + # Legacy path: set up per-bucket param lists for variable-size all-gather. + # When use_buffer_param_sync is True, the standard distributed optimizer + # all-gather path is used and this setup is not needed. assert ( model_chunks is not None ), "model_chunks must be provided if overlap_param_gather is True" @@ -91,6 +419,13 @@ def __init__( super().__init__(optimizers) + # Assign self.model_chunks AFTER super().__init__: ChainedOptimizer.__init__ + # resets self.model_chunks to [] and then repopulates only from chained + # children that have a model_chunks attribute (DistOpt does, Float16-wrapped + # raw torch optimizers do not). Set it here so LayerWise.step's + # ``for model_chunk in self.model_chunks`` actually iterates. + self.model_chunks = model_chunks if model_chunks is not None else [] + # TODO(kunlun, deyuf): potential future perf optimization # since allreduce is unchanged and handled by megatron DDP, they're already in # contiguous gbuf. So instead of shard param by layer randomly, we can shard by @@ -98,32 +433,122 @@ def __init__( # This way each rank do some duplicated work but allgather_v is no longer needed # All current distopt optimization can also be potentially applied - def shard_params(self, optimizers): - """Shard all params into lists by rank.""" - # list of parameter are sorted by numel and assigned to ranks in ping-pong style - # example of 4 ranks and 10 parameters p0-p9 after sorting, then dp_cp_params_list will be - # [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]] + def shard_params(self, optimizers, full_param_layouts=None): + """Shard params across ranks according to the computed param layout. 
- # simplify when dp_cp group size is 1 - if get_pg_size(self.pg_collection.dp_cp) == 1: + Each param's shard assignment is derived from the :class:`FullParamLayout` + stored on the DDP model chunks. Within each bucket the buffer is divided + into ``dp_size`` equal shards; a param's shard index is determined by its + position in the buffer. + + Falls back to the legacy ping-pong-by-numel strategy when no layout is + available (e.g. ``dp_size == 1`` or no DDP wrapper). + + Args: + optimizers: Optimizers whose param groups will be narrowed to + the local rank's shard. + full_param_layouts: List of :class:`FullParamLayout` (one per model + chunk). ``None`` triggers the legacy fallback. + """ + # Simplify when dp_cp group size is 1. + dp_cp_size = get_pg_size(self.pg_collection.dp_cp) + if dp_cp_size == 1: self.dp_cp_params_list = None self.expt_dp_params_list = None return - dp_cp_idx, expt_dp_idx = 0, 0 - dp_cp_size = get_pg_size(self.pg_collection.dp_cp) expt_dp_size = get_pg_size(self.pg_collection.expt_dp) - # create ping-pong style loop so memory is more balanced + + if full_param_layouts is not None: + self._shard_params_from_layout(optimizers, full_param_layouts, dp_cp_size, expt_dp_size) + else: + self._shard_params_ping_pong(optimizers, dp_cp_size, expt_dp_size) + + def _shard_params_from_layout(self, optimizers, full_param_layouts, dp_cp_size, expt_dp_size): + """Derive shard assignments from the param layout.""" + dp_cp_rank = get_pg_rank(self.pg_collection.dp_cp) + expt_dp_rank = get_pg_rank(self.pg_collection.expt_dp) + + self.dp_cp_params_list = [[] for _ in range(dp_cp_size)] + self.expt_dp_params_list = [[] for _ in range(expt_dp_size)] + + # Map each param to its shard index. 
+ param_to_shard: Dict[torch.nn.Parameter, int] = {} + for full_layout in full_param_layouts: + for buffer_key, layout in full_layout.layouts.items(): + dp_size = expt_dp_size if buffer_key.is_expert_parallel else dp_cp_size + for param, ( + param_start_index, + param_end_index, + bucket_id, + ) in layout.param_index_map.items(): + bucket_start_index, bucket_end_index = layout.bucket_indices[bucket_id] + shard_size = (bucket_end_index - bucket_start_index) // dp_size + shard_id = (param_start_index - bucket_start_index) // shard_size + shard_end_index = bucket_start_index + (shard_id + 1) * shard_size + assert param_end_index <= shard_end_index, ( + f"Param (shape={tuple(param.shape)}, numel={param.numel()}) at " + f"({param_start_index}, {param_end_index}) crosses shard boundary " + f"in bucket ({bucket_start_index}, {bucket_end_index}) with " + f"shard_size={shard_size}, shard_id={shard_id}, " + f"shard_end_index={shard_end_index}. The layout must keep every " + f"param fully within one shard." + ) + param_to_shard[param] = shard_id + + # Collect all param groups and assign params to per-rank lists. + param_groups = [] + for optimizer in optimizers: + param_groups += optimizer.param_groups + param_groups_this_rank = [[] for _ in param_groups] + + for group_index, group in enumerate(param_groups): + is_expert = group.get("is_expert_parallel", False) + local_rank = expt_dp_rank if is_expert else dp_cp_rank + params_list = self.expt_dp_params_list if is_expert else self.dp_cp_params_list + + for param in group["params"]: + assert param in param_to_shard, ( + f"Optimizer param (shape={tuple(param.shape)}, numel={param.numel()}) " + f"not found in any param layout. Ensure all optimizer params are " + f"included in the full_param_layout passed to DDP." + ) + shard_id = param_to_shard[param] + params_list[shard_id].append(param) + if shard_id == local_rank: + param_groups_this_rank[group_index].append(param) + + # Now we modify the group to only handle local params. 
+ for group, local_params in zip(param_groups, param_groups_this_rank): + group["params"] = local_params + + # Simplify when expt_dp group size is 1 or expert parallel is off. + if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: + self.expt_dp_params_list = None + + def _shard_params_ping_pong(self, optimizers, dp_cp_size, expt_dp_size): + """Legacy ping-pong-by-numel shard assignment (no layout available). + + Legacy: this method is a fallback for when no ``full_param_layout`` + is provided. Once all call sites supply a layout, this can be removed + in favor of :meth:`_shard_params_from_layout`. + + List of parameters are sorted by numel and assigned to ranks in ping-pong style. + Example of 4 ranks and 10 parameters p0-p9 after sorting, then dp_cp_params_list + will be [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]]. + """ + dp_cp_idx, expt_dp_idx = 0, 0 + # Create ping-pong style loop so memory is more balanced. dp_cp_loop = list(range(dp_cp_size)) + list(range(dp_cp_size))[::-1] expt_dp_loop = list(range(expt_dp_size)) + list(range(expt_dp_size))[::-1] self.dp_cp_params_list = [[] for _ in range(dp_cp_size)] self.expt_dp_params_list = [[] for _ in range(expt_dp_size)] - # get all param groups + # Get all param groups. param_groups = [] for optimizer in optimizers: param_groups += optimizer.param_groups - # sort param in all groups by param numel and assign to each rank evenly + # Sort param in all groups by param numel and assign to each rank evenly. param_list = [] for group_index, group in enumerate(param_groups): for p in group["params"]: @@ -131,7 +556,7 @@ def shard_params(self, optimizers): param_list.sort(key=lambda x: x[0].numel()) param_groups_this_rank = [[] for g in param_groups] - # assign params to rank in ping-pong style loop + # Assign params to rank in ping-pong style loop. 
for p, group_index in param_list: if param_groups[group_index].get("is_expert_parallel", False): if expt_dp_loop[expt_dp_idx] == get_pg_rank(self.pg_collection.expt_dp): @@ -144,17 +569,23 @@ def shard_params(self, optimizers): self.dp_cp_params_list[dp_cp_loop[dp_cp_idx]].append(p) dp_cp_idx = (dp_cp_idx + 1) % len(dp_cp_loop) - # now we modify the group to only handle local params + # Now we modify the group to only handle local params. for groups, params in zip(param_groups, param_groups_this_rank): groups["params"] = params - # simplify when expt_dp group size is 1 or expert parallel is off + # Simplify when expt_dp group size is 1 or expert parallel is off. if expt_dp_size == 1 or len(self.expt_dp_params_list[0]) == 0: self.expt_dp_params_list = None def set_bucket_layerwise_params_list(self, model_chunks): """Map sharded params to DDP buckets for async all-gather. + Legacy: only used by the variable-size all-gather path + (``use_buffer_param_sync=False``). Once all call sites supply a + ``full_param_layout``, this can be removed — the standard distributed + optimizer buffer all-gather handles param sync without per-bucket + param lists. + For each bucket in each model chunk's bucket groups, build per-rank param lists by cross-referencing the layer-wise sharded param lists with the bucket's params. @@ -193,7 +624,13 @@ def set_bucket_layerwise_params_list(self, model_chunks): @torch.no_grad() def allgather_params(self) -> None: - """All-gather updated params from all ranks.""" + """All-gather updated params from all ranks. + + Legacy: only used when ``use_buffer_param_sync=False``. Once all + call sites supply a ``full_param_layout``, this can be removed — the + standard distributed optimizer buffer all-gather (via + ``start_param_sync``) replaces this flatten/unflatten path. 
+ """ # helper function to flatten local params, all-gather, # unflatten and copy to model params @@ -278,13 +715,27 @@ def count_zeros(self): @torch.no_grad() def step(self): # type: ignore[no-untyped-def] - """step function for layer-wise optimizer.""" + """step function for layer-wise optimizer. + + NOTE: bypassed when this optimizer is a child of an outer + ChainedOptimizer; in that case the sibling DistributedOptimizer's + step_with_ready_grads handles the param sync. + """ update_successful, grad_norm, num_zeros_in_grad = super().step() - # All gather updated params. If overlap_param_gather is True, the allgather + # All-gather updated params. If overlap_param_gather is True, the all-gather # is deferred to the forward pre-hooks via DDP bucket infrastructure. if not self.overlap_param_gather: - self.allgather_params() + if self.use_buffer_param_sync: + # Model params are views into the DDP param buffer + # (ddp_config.use_distributed_optimizer=True). The optimizer step + # already copied updated fp32 main params → bf16 model params (= + # buffer views), so the buffer is up-to-date. Trigger the standard + # buffer all-gather (matches DistributedOptimizer's call site). 
+ for model_chunk in self.model_chunks: + model_chunk.start_param_sync() + else: + self.allgather_params() return update_successful, grad_norm, num_zeros_in_grad diff --git a/megatron/core/optimizer/param_layout.py b/megatron/core/optimizer/param_layout.py index 6ebcc348f84..543af88f325 100644 --- a/megatron/core/optimizer/param_layout.py +++ b/megatron/core/optimizer/param_layout.py @@ -26,15 +26,20 @@ def pad_param_start(param_start_index: int) -> int: return pad_to_divisor(param_start_index, 64) +def bucket_end_divisor(data_parallel_world_size: int, pad_for_high_nccl_busbw: bool) -> int: + """Divisor used to pad bucket ends for DP-divisibility (and optional NCCL busbw).""" + if pad_for_high_nccl_busbw: + return math.lcm(data_parallel_world_size, 128, 2**16) + return math.lcm(data_parallel_world_size, 128) + + def pad_bucket_end( bucket_end_index: int, data_parallel_world_size: int, pad_for_high_nccl_busbw: bool ) -> int: """Pad bucket end for DP-divisibility (and optionally high NCCL bus bandwidth).""" - if pad_for_high_nccl_busbw: - divisor = math.lcm(data_parallel_world_size, 128, 2**16) - else: - divisor = math.lcm(data_parallel_world_size, 128) - return pad_to_divisor(bucket_end_index, divisor) + return pad_to_divisor( + bucket_end_index, bucket_end_divisor(data_parallel_world_size, pad_for_high_nccl_busbw) + ) @dataclass(frozen=True) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 0e14251c5aa..390a164de9d 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -356,6 +356,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): kv_cache_management_mode=KVCacheManagementMode(args.rl_kv_cache_management_mode), cuda_graph_mixed_prefill_count=args.inference_dynamic_batching_cuda_graph_mixed_prefill_count, # pylint: disable=line-too-long use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs, + cuda_graph_all_prefills=args.inference_cuda_graph_all_prefills, 
static_kv_memory_pointers=args.rl_persist_cuda_graphs, max_sequence_length=max_sequence_length, mamba_inference_state_config=mamba_inference_state_config, diff --git a/megatron/legacy/fp16_deprecated/loss_scaler.py b/megatron/legacy/fp16_deprecated/loss_scaler.py deleted file mode 100755 index cb64aa92892..00000000000 --- a/megatron/legacy/fp16_deprecated/loss_scaler.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""For backward compatibility, we need the class definitions to deserialize.""" - -class LossScaler: - def __init__(self, scale=1): - self.cur_scale = scale - -class DynamicLossScaler: - def __init__(self, - init_scale=2**32, - scale_factor=2., - scale_window=1000, - min_scale=1, - delayed_shift=1, - consecutive_hysteresis=False): - self.cur_scale = init_scale - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = scale_factor - self.scale_window = scale_window - self.min_scale = min_scale - self.delayed_shift = delayed_shift - self.cur_hysteresis = delayed_shift - self.consecutive_hysteresis = consecutive_hysteresis - diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index cec693a138f..3e8c8e90d02 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -113,7 +113,7 @@ async def launch(cls, model: GPTModel, **kwargs): tokenizer=inference_engine.controller.tokenizer, rank=dist.get_rank(), server_port=kwargs.get('port', 8294), - parsers=[], + parsers=args.rl_inference_parsers, verbose=kwargs.get('verbose', False), ) else: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index be3894999b4..f217c5c4a2d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1904,6 +1904,11 @@ def _add_inference_args(parser): group.add_argument('--decode-only-cuda-graphs', action='store_true', default=False, help='Only use cuda graphs for decode-only steps, not prefill and mixed 
steps.') + group.add_argument('--inference-cuda-graph-all-prefills', + action='store_true', default=False, + help='Extend prefill/mixed CUDA graph capture up to `max_tokens`. ' + 'By default, all graphs are limited by the decode limit of ' + '`max_requests * (num_speculative_tokens + 1)`.') group.add_argument('--inference-dynamic-batching-unified-memory-level', type=int, default=0, choices=[0, 1], help='Set unified memory usage within the dynamic ' @@ -1953,9 +1958,6 @@ def _add_inference_args(parser): help='GPU memory budget (in GB) for the Mamba state cache ' 'used by prefix caching on hybrid models. When set, Mamba ' 'states at block boundaries are cached for reuse.') - group.add_argument('--inference-dynamic-batching-cuda-graph-max-tokens', - type=int, default=16384, - help='Maximum number of tokens to capture in a cuda graph.') group.add_argument('--inference-dynamic-batching-cuda-graph-mixed-prefill-count', type=int, default=16, help='Number of mixed prefill requests to capture in a cuda graph.') diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index d4dae645e76..43a0eb9d12a 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -596,7 +596,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 # Collect args, model, RNG. + # For LEGACY checkpoints, every unique (tp_rank, ep_rank) shard must be written by + # exactly one rank. Neither dp_rank==0 nor edp_rank==0 alone covers all shards when + # the dense and expert parallelism layouts disagree (e.g. TP > EP*ETP); the union + # does, with at most one rank per (tp_rank, ep_rank) inside any DP group. 
if not torch.distributed.is_initialized() \ + or mpu.get_data_parallel_rank() == 0 \ or mpu.get_expert_data_parallel_rank() == 0 \ or ckpt_type != CheckpointType.LEGACY: if ckpt_type != CheckpointType.LEGACY: @@ -1366,21 +1371,6 @@ def _load_base_checkpoint( checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=False) try: state_dict = torch.load(checkpoint_name, map_location='cpu') - except ModuleNotFoundError: - from megatron.legacy.fp16_deprecated import loss_scaler - - # For backward compatibility. - if not rank0: - print_rank_0(' > deserializing using the old code structure ...') - sys.modules['fp16.loss_scaler'] = sys.modules['megatron.legacy.fp16_deprecated.loss_scaler'] - sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ - 'megatron.legacy.fp16_deprecated.loss_scaler' - ] - sys.modules['megatron.model'] = sys.modules['megatron.legacy.model'] - state_dict = torch.load(checkpoint_name, map_location='cpu') - sys.modules.pop('fp16.loss_scaler', None) - sys.modules.pop('megatron.fp16.loss_scaler', None) - sys.modules.pop('megatron.model', None) except Exception as e: print('could not load the checkpoint') print(e) diff --git a/megatron/training/training.py b/megatron/training/training.py index c6cab8df952..5c685d8a5c5 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -52,6 +52,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): import torch.distributed from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer +from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer from megatron.core.optimizer_param_scheduler import get_canonical_lr_for_logging from .log_handler import CustomHandler @@ -1308,6 +1309,114 @@ def update_train_iters(args): print_rank_0(f'setting training iterations to {args.train_iters}') +def wrap_model_chunks_with_ddp( + model_chunks, + config, + ddp_config, + *, + use_layer_wise_distributed_optimizer=False, + 
use_layer_wise_param_layout=True, + DP=DDP, + pg_collection=None, + bucket_sizes=None, + disable_bucketing_per_chunk=None, +): + """Wrap each model chunk in DDP, pre-computing per-chunk param layouts as needed. + + Centralises the DDP-wrapping wiring shared between :func:`get_model` and + unit tests. + + For ``use_layer_wise_distributed_optimizer=True`` and ``use_layer_wise_param_layout=True``: + forces ``ddp_config.use_distributed_optimizer=True`` (mutated in place; needed + for reduce-scatter), and computes per-chunk shard-aligned layouts via + :meth:`LayerWiseDistributedOptimizer.compute_full_param_layout`. With + ``use_layer_wise_param_layout=False``, no layout is supplied and LayerWise falls back + to its legacy ``allgather_params`` sync path. + + For non-layerwise with ``ddp_config.use_distributed_optimizer=True``: + computes per-chunk byte-level layouts via + :meth:`DistributedOptimizer.compute_full_param_layout`. + + Otherwise: no layouts are computed. + + Layouts are only computed when ``DP is DDP`` (i.e. the standard + ``DistributedDataParallel``); FSDP variants don't accept + ``full_param_layout``. + + Args: + model_chunks: List of model chunks to wrap (un-DDP-wrapped). + config: :class:`TransformerConfig`. + ddp_config: :class:`DistributedDataParallelConfig`. Mutated in place when + ``use_layer_wise_distributed_optimizer=True`` and ``use_layer_wise_param_layout=True``. + use_layer_wise_distributed_optimizer: Whether the layerwise wiring runs. + use_layer_wise_param_layout: When ``use_layer_wise_distributed_optimizer=True``, + controls whether to compute and supply a shard-aligned param layout + to DDP. ``False`` keeps LayerWise on its legacy sync path. + DP: The DDP class to construct (``DistributedDataParallel`` or an FSDP + variant). + pg_collection: Optional :class:`ProcessGroupCollection`. Forwarded to + FSDP-style DPs only. + bucket_sizes: Optional per-chunk bucket size override; defaults to + ``[ddp_config.bucket_size] * len(model_chunks)``. 
+ disable_bucketing_per_chunk: Optional per-chunk disable_bucketing flag; + defaults to ``[False] * len(model_chunks)``. + + Returns: + List of DDP-wrapped chunks. + """ + n = len(model_chunks) + if bucket_sizes is None: + bucket_sizes = [ddp_config.bucket_size] * n + if disable_bucketing_per_chunk is None: + disable_bucketing_per_chunk = [False] * n + + # Compute per-chunk layouts (DDP only). + per_chunk_layouts = [None] * n + if DP is DDP: + if use_layer_wise_distributed_optimizer and use_layer_wise_param_layout: + ddp_config.use_distributed_optimizer = True + compute_layout = LayerWiseDistributedOptimizer.compute_full_param_layout + elif not use_layer_wise_distributed_optimizer and ddp_config.use_distributed_optimizer: + compute_layout = DistributedOptimizer.compute_full_param_layout + else: + compute_layout = None + if compute_layout is not None: + data_parallel_world_size = mpu.get_data_parallel_world_size( + with_context_parallel=True + ) + expert_data_parallel_world_size = mpu.get_expert_data_parallel_world_size() + for i, (chunk, bucket_size) in enumerate(zip(model_chunks, bucket_sizes)): + all_params = [p for p in chunk.parameters() if p.requires_grad] + per_chunk_layouts[i] = compute_layout( + all_params, + bucket_size, + data_parallel_world_size, + ddp_config, + expert_data_parallel_world_size=expert_data_parallel_world_size, + ) + + # Wrap each chunk. 
+ wrapped = [] + for chunk, layout, disable_bucketing in zip( + model_chunks, per_chunk_layouts, disable_bucketing_per_chunk + ): + chunk_kwargs = {} + if pg_collection is not None and DP is not DDP: + chunk_kwargs["pg_collection"] = pg_collection + if layout is not None: + chunk_kwargs["full_param_layout"] = layout + wrapped.append( + DP( + config=config, + ddp_config=ddp_config, + module=chunk, + disable_bucketing=disable_bucketing, + **chunk_kwargs, + ) + ) + return wrapped + + def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True, config=None, pg_collection=None): """Build the model.""" args = get_args() @@ -1478,6 +1587,19 @@ def build_model(): if not ddp_config.overlap_grad_reduce: ddp_config.bucket_size = None + # Compute per-chunk bucket sizes / disable_bucketing flags. Bucketing is + # disabled for non-first chunks, when overlap_param_gather_with_optimizer_step + # is on, or for non-zero pipeline-parallel ranks. + pp_rank = mpu.get_pipeline_model_parallel_rank() + per_chunk_disable_bucketing = [ + (chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step + for chunk_idx in range(len(model)) + ] + per_chunk_bucket_sizes = [ + None if (disable or pp_rank > 0) else ddp_config.bucket_size + for disable in per_chunk_disable_bucketing + ] + # Setup stream for ddp initialization. The side-stream may be necessary for cuda graph # capture support with DDP, but we sync it with the current stream to avoid races. ddp_stream = torch.cuda.Stream() @@ -1485,53 +1607,19 @@ def build_model(): ddp_stream.wait_stream(torch.cuda.current_stream()) # Make ddp_stream start after whatever the default stream already queued with torch.cuda.stream(ddp_stream): - # Megatron-FSDP reads dtypes from ddp_config; pass pg_collection for AG/RS overlap. 
- dp_init_kwargs = {} - if args.use_megatron_fsdp: - dp_init_kwargs["pg_collection"] = pg_collection - - wrapped_model = [] - for model_chunk_idx, model_chunk in enumerate(model): - chunk_kwargs = dict(dp_init_kwargs) - disable_bucketing = ( - (model_chunk_idx > 0) - or args.overlap_param_gather_with_optimizer_step - ) - - # Pre-compute parameter layouts for the distributed optimizer. - # Only pass to DDP; FSDP variants don't accept full_param_layout. - if args.use_distributed_optimizer and DP is DDP: - all_params = [ - p for p in model_chunk.parameters() if p.requires_grad - ] - pp_rank = mpu.get_pipeline_model_parallel_rank() - effective_bucket_size = ( - None - if disable_bucketing or pp_rank > 0 - else ddp_config.bucket_size - ) - chunk_kwargs["full_param_layout"] = ( - DistributedOptimizer.compute_full_param_layout( - all_params, - effective_bucket_size, - mpu.get_data_parallel_world_size(with_context_parallel=True), - ddp_config, - expert_data_parallel_world_size=( - mpu.get_expert_data_parallel_world_size() - ), - ) - ) - - wrapped_model.append( - DP( - config=config, - ddp_config=ddp_config, - module=model_chunk, - disable_bucketing=disable_bucketing, - **chunk_kwargs, - ) - ) - model = wrapped_model + model = wrap_model_chunks_with_ddp( + model, + config, + ddp_config, + use_layer_wise_distributed_optimizer=getattr( + args, 'use_layer_wise_distributed_optimizer', False + ), + use_layer_wise_param_layout=False, + DP=DP, + pg_collection=pg_collection if args.use_megatron_fsdp else None, + bucket_sizes=per_chunk_bucket_sizes, + disable_bucketing_per_chunk=per_chunk_disable_bucketing, + ) # End of setup_stream # Critical: ensure side-stream work completes before touching params on default stream torch.cuda.current_stream().wait_stream(ddp_stream) @@ -3936,6 +4024,9 @@ def should_disable_forward_pre_hook(args): return ( not args.use_megatron_fsdp and has_optimizer - and (args.use_distributed_optimizer or args.use_layer_wise_distributed_optimizer) + and ( + 
args.use_distributed_optimizer + or getattr(args, 'use_layer_wise_distributed_optimizer', False) + ) and args.overlap_param_gather ) diff --git a/pyproject.toml b/pyproject.toml index a7bc48ec9a3..e271b7d2686 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -206,7 +206,7 @@ override-dependencies = [ flash_mla = [ { git = "https://github.com/deepseek-ai/FlashMLA", rev = "9edee0c022cd0938148a18e334203b0aab43aa19" }, ] -transformer-engine = { git = "https://github.com/NVIDIA/TransformerEngine.git", rev = "f031cf87bd054c7558b887df7bed93975456667f" } +transformer-engine = { git = "https://github.com/NVIDIA/TransformerEngine.git", rev = "42b840051647eef89761a16dfdff87e82bb253ab" } nemo-run = { git = "https://github.com/NVIDIA-NeMo/Run.git", rev = "17ae86b64d7f75653351664f5d8c9e466faede00" } emerging_optimizers = { git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git", rev = "v0.2.0" } nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "b2bb3d728a18795807d9f76c535e005a609a1b01" } diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml index aa4fde5e512..49c13b648b5 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml @@ -42,7 +42,6 @@ MODEL_ARGS: --return-log-probs: true --num-tokens-from-file: true --inference-dynamic-batching-buffer-size-gb: 20 - --inference-dynamic-batching-cuda-graph-max-tokens: 2048 --cuda-graph-impl: local --cuda-graph-scope: full --disable-chunked-prefill: true diff --git a/tests/test_utils/recipes/h100/gpt-static-inference.yaml b/tests/test_utils/recipes/h100/gpt-static-inference.yaml index 87046588b2b..de9b0235203 
100644 --- a/tests/test_utils/recipes/h100/gpt-static-inference.yaml +++ b/tests/test_utils/recipes/h100/gpt-static-inference.yaml @@ -68,7 +68,7 @@ products: - test_case: [gpt_static_inference_tp1_pp1_583m_fp8_cudagraphs] products: - environment: [dev] - scope: [mr, mr-github] + scope: [mr-broken, mr-github-broken] platforms: [dgx_h100] - test_case: [gpt_static_inference_tp1_pp1_16b_multiprompt_tokensmatch] products: diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index 8a9df54ddc8..f5acc373a2f 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -177,11 +177,38 @@ def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False): def setup_model_and_optimizer( - seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True, optimizer='adam' + seed, + tp, + pp, + initialize_fn=initialize_gpt_model, + bf16=True, + dist_opt=True, + optimizer='adam', + use_param_layout=False, ): + optimizer_type = optimizer + use_layer_wise = False + if optimizer_type == 'dist_muon': + optimizer = 'muon' + use_layer_wise = True + if optimizer_type in ('muon', 'dist_muon') and dist_opt: + use_layer_wise = True + + # When use_layer_wise is True and use_param_layout is False, route DDP + # construction through the legacy path (no precomputed param layout, no + # ``use_distributed_optimizer=True`` flip). LayerWiseDistributedOptimizer + # then syncs via its legacy ``allgather_params()`` codepath rather than + # ``start_param_sync``. 
+ ddp_use_dist_opt = dist_opt and not (use_layer_wise and not use_param_layout) + ddp_use_layer_wise = use_layer_wise and use_param_layout + mock_args = parse_args(ignore_unknown_args=True) with mock.patch('megatron.training.training.get_args', new=lambda: mock_args): init_basic_mock_args(mock_args, tp, pp, bf16=bf16) + mock_args.use_distributed_optimizer = ddp_use_dist_opt + mock_args.use_layer_wise_distributed_optimizer = ddp_use_layer_wise + if ddp_use_layer_wise: + mock_args.optimizer = optimizer model = get_model( partial( initialize_fn, @@ -193,19 +220,10 @@ def setup_model_and_optimizer( ) ) - optimizer_type = optimizer - use_layer_wise = False - if optimizer_type == 'dist_muon': - optimizer = 'muon' - use_layer_wise = True - if optimizer_type in ('muon', 'dist_muon') and dist_opt: - use_layer_wise = True - dist_opt = False - config = OptimizerConfig( bf16=bf16, params_dtype=torch.bfloat16 if bf16 else torch.float, - use_distributed_optimizer=dist_opt, + use_distributed_optimizer=ddp_use_dist_opt, use_layer_wise_distributed_optimizer=use_layer_wise, optimizer=optimizer, ) @@ -272,10 +290,27 @@ def setup_moe_model_and_optimizer( use_grouped_mlp=False, use_glu=False, optimizer='adam', + use_param_layout=False, ): + optimizer_type = optimizer + use_layer_wise = False + if optimizer_type == 'dist_muon': + optimizer = 'muon' + use_layer_wise = True + if optimizer_type in ('muon', 'dist_muon') and dist_opt: + use_layer_wise = True + + # See setup_model_and_optimizer for the use_param_layout semantics. 
+ ddp_use_dist_opt = dist_opt and not (use_layer_wise and not use_param_layout) + ddp_use_layer_wise = use_layer_wise and use_param_layout + mock_args = parse_args(ignore_unknown_args=True) with mock.patch('megatron.training.training.get_args', new=lambda: mock_args): init_basic_mock_args(mock_args, tp, pp, bf16=bf16) + mock_args.use_distributed_optimizer = ddp_use_dist_opt + mock_args.use_layer_wise_distributed_optimizer = ddp_use_layer_wise + if ddp_use_layer_wise: + mock_args.optimizer = optimizer model = get_model( partial( initialize_fn, @@ -292,19 +327,10 @@ def setup_moe_model_and_optimizer( ) ) - optimizer_type = optimizer - use_layer_wise = False - if optimizer_type == 'dist_muon': - optimizer = 'muon' - use_layer_wise = True - if optimizer_type in ('muon', 'dist_muon') and dist_opt: - use_layer_wise = True - dist_opt = False - config = OptimizerConfig( bf16=bf16, params_dtype=torch.bfloat16 if bf16 else torch.float, - use_distributed_optimizer=dist_opt, + use_distributed_optimizer=ddp_use_dist_opt, use_layer_wise_distributed_optimizer=use_layer_wise, optimizer=optimizer, ) diff --git a/tests/unit_tests/distributed/test_layer_wise_param_layout.py b/tests/unit_tests/distributed/test_layer_wise_param_layout.py new file mode 100644 index 00000000000..669f4ab23ec --- /dev/null +++ b/tests/unit_tests/distributed/test_layer_wise_param_layout.py @@ -0,0 +1,393 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Tests for LayerWiseDistributedOptimizer parameter layout computation. + +These tests verify the size-matching shard-aligned layout logic without +requiring GPU or distributed setup. 
+""" + +import math +from collections import Counter +from unittest import mock + +import pytest +import torch + +from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.optimizer.param_layout import BufferKey, pad_param_start, pad_to_divisor + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_LWO = LayerWiseDistributedOptimizer + + +def _make_param(shape, dtype=torch.bfloat16, **attrs): + param = torch.nn.Parameter(torch.randn(shape, dtype=dtype)) + for attr_name, attr_value in attrs.items(): + setattr(param, attr_name, attr_value) + return param + + +def _make_ddp_config(pad_for_high_busbw=False, grad_reduce_in_fp32=True): + cfg = mock.Mock() + cfg.pad_buckets_for_high_nccl_busbw = pad_for_high_busbw + cfg.grad_reduce_in_fp32 = grad_reduce_in_fp32 + return cfg + + +# --------------------------------------------------------------------------- +# Tests for _shard_divisor +# --------------------------------------------------------------------------- + + +class TestShardDivisor: + + def _verify(self, dp_size, high_busbw=False): + cfg = _make_ddp_config(pad_for_high_busbw=high_busbw) + sd = _LWO._shard_divisor(dp_size, cfg) + assert sd % 64 == 0, f"shard_divisor {sd} not 64-aligned" + if high_busbw: + bucket_div = math.lcm(dp_size, 128, 2**16) + else: + bucket_div = math.lcm(dp_size, 128) + assert (dp_size * sd) % bucket_div == 0 + + def test_dp2(self): + self._verify(2) + + def test_dp4(self): + self._verify(4) + + def test_dp8(self): + self._verify(8) + + def test_dp8_high_busbw(self): + self._verify(8, high_busbw=True) + + def test_dp1(self): + self._verify(1) + + +# --------------------------------------------------------------------------- +# Helpers for layout verification +# --------------------------------------------------------------------------- + + +def _get_shard_for_param(layout, 
param, dp_size): + """Return which shard a param lands in.""" + param_start_index, param_end_index, bucket_id = layout.param_index_map[param] + bucket_start_index, bucket_end_index = layout.bucket_indices[bucket_id] + shard_size = (bucket_end_index - bucket_start_index) // dp_size + shard_id = (param_start_index - bucket_start_index) // shard_size + return shard_id + + +def _assert_param_within_shard(layout, param, dp_size): + """Assert that a param lies entirely within one shard.""" + param_start_index, param_end_index, bucket_id = layout.param_index_map[param] + bucket_start_index, bucket_end_index = layout.bucket_indices[bucket_id] + shard_size = (bucket_end_index - bucket_start_index) // dp_size + shard_id = (param_start_index - bucket_start_index) // shard_size + shard_start_index = bucket_start_index + shard_id * shard_size + shard_end_index = shard_start_index + shard_size + assert ( + shard_start_index <= param_start_index + ), f"param start {param_start_index} before shard start {shard_start_index}" + assert ( + param_end_index <= shard_end_index + ), f"param end {param_end_index} past shard end {shard_end_index}" + + +# --------------------------------------------------------------------------- +# Tests for _compute_per_buffer_param_layout (size-matching) +# --------------------------------------------------------------------------- + + +class TestSizeMatchingLayout: + + # -- uniform params: all same size, dp_size divides count -- + + def test_uniform_params_exact_fit(self): + """8 same-size params with dp_size=4 → each shard gets 2 params.""" + dp_size = 4 + params = [_make_param((256,)) for _ in range(8)] + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + for param in params: + _assert_param_within_shard(layout, param, dp_size) + + # With 8 params and dp_size=4, 2 rounds of size-matching. + # Each round fills all 4 shards. No padding needed. 
+ shard_counts = Counter(_get_shard_for_param(layout, param, dp_size) for param in params) + assert set(shard_counts.values()) == {2} + + def test_uniform_params_remainder_gets_padding(self): + """5 same-size params with dp_size=4 → 1 round fills 4, 1 round fills 1 + 3 padding.""" + dp_size = 4 + params = [_make_param((256,)) for _ in range(5)] + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + for param in params: + _assert_param_within_shard(layout, param, dp_size) + + # All 5 params assigned; shard 0 gets 2, shards 1-3 get 1 each. + shard_counts = Counter(_get_shard_for_param(layout, param, dp_size) for param in params) + assert shard_counts[0] == 2 + assert sum(shard_counts.values()) == 5 + + # -- mixed sizes -- + + def test_mixed_sizes_no_param_split(self): + """Params with different sizes: each param stays within one shard.""" + dp_size = 2 + params = [ + _make_param((100,)), + _make_param((200,)), + _make_param((100,)), + _make_param((200,)), + _make_param((100,)), + ] + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + for param in params: + _assert_param_within_shard(layout, param, dp_size) + + def test_size_matching_prefers_same_size(self): + """When shard 0 gets a 256-elem param, other shards should also get 256-elem params.""" + dp_size = 4 + big_params = [_make_param((256,)) for _ in range(4)] + small_params = [_make_param((64,)) for _ in range(4)] + # Interleave: big_params[0], small_params[0], big_params[1], small_params[1], ... + params = [] + for big_param, small_param in zip(big_params, small_params): + params.extend([big_param, small_param]) + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + for param in params: + _assert_param_within_shard(layout, param, dp_size) + + # All 4 big params should be in the same round (one per shard). 
+ big_param_bucket_ids = set() + for param in big_params: + _, _, bucket_id = layout.param_index_map[param] + big_param_bucket_ids.add(bucket_id) + # They should share a bucket (matched in one round). + assert len(big_param_bucket_ids) == 1 + + # -- fallback packing for unique-large seeds -- + + def test_unique_large_seed_packs_smaller_params(self): + """A unique-large seed's empty shard slots should absorb smaller params. + + Pool order is the *reverse* of input/forward order, so the + ``unique_large_param`` is placed last in the input list to make it the + first seed (top of pool). Without the packing fallback, the unique + seed would emit a bucket with ``dp_size - 1`` shards of pure padding + and the trailing smaller params would form their own bucket. With the + fallback, the smaller params land in the large param's bucket, + eliminating the second bucket entirely. + """ + dp_size = 4 + unique_large_param = _make_param((1024,)) + filler_params = [_make_param((128,)) for _ in range(3)] + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout( + filler_params + [unique_large_param], None, dp_size, cfg + ) + + # Invariant still holds: each param lies entirely within one shard. + for param in [unique_large_param] + filler_params: + _assert_param_within_shard(layout, param, dp_size) + + # All four params share a single bucket (no second bucket for the + # filler params). + assert len(layout.bucket_indices) == 1 + _, _, large_bucket_id = layout.param_index_map[unique_large_param] + for filler_param in filler_params: + _, _, filler_bucket_id = layout.param_index_map[filler_param] + assert filler_bucket_id == large_bucket_id + + # Filler params land in shard slots other than the unique-large's. 
+ large_shard = _get_shard_for_param(layout, unique_large_param, dp_size) + for filler_param in filler_params: + assert _get_shard_for_param(layout, filler_param, dp_size) != large_shard + + # -- shared_embedding isolation -- + + def test_shared_embedding_isolated(self): + """shared_embedding params go in their own bucket and fit within shard 0.""" + dp_size = 2 + regular_params = [_make_param((128,)) for _ in range(4)] + embedding_param = _make_param((128,), shared_embedding=True) + params = [embedding_param] + regular_params + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + # The shared embedding must fit entirely within one shard so reduce-scatter + # delivers the full reduced gradient to its owner rank. + _assert_param_within_shard(layout, embedding_param, dp_size) + + # And it must be the sole real param in its bucket. + _, _, embedding_bucket_id = layout.param_index_map[embedding_param] + for param in regular_params: + _, _, bucket_id = layout.param_index_map[param] + assert bucket_id != embedding_bucket_id, "shared_embedding should be in its own bucket" + + # Embedding lives in shard 0; shards 1..dp_size-1 are pure padding. + assert _get_shard_for_param(layout, embedding_param, dp_size) == 0 + + def test_shared_embedding_fits_in_shard_at_high_dp(self): + """With dp_size > 2, a vocab-sized shared embedding still fits in one shard. + + Regression test for the case where the isolated bucket was sized to ~numel + instead of dp_size * numel, causing the embedding to silently span multiple + shards on dp_size > 1. 
+ """ + for dp_size in [2, 4, 8]: + embedding_param = _make_param((1024,), shared_embedding=True) + cfg = _make_ddp_config() + layout = _LWO._compute_per_buffer_param_layout([embedding_param], None, dp_size, cfg) + _assert_param_within_shard(layout, embedding_param, dp_size) + assert _get_shard_for_param(layout, embedding_param, dp_size) == 0 + + # -- bucket size threshold -- + + def test_bucket_size_creates_multiple_buckets(self): + """When bucket_size is small, multiple buckets are created.""" + dp_size = 2 + params = [_make_param((256,)) for _ in range(8)] + cfg = _make_ddp_config() + + # Each round: 256 elements per shard. shard_pos after round = 256. + # Padded shard size = pad_to_divisor(256, shard_div). + # bucket_total = dp_size * padded_shard_size. + # Set bucket_size small enough to force a split after 1 round. + shard_div = _LWO._shard_divisor(dp_size, cfg) + padded = pad_to_divisor(256, shard_div) + small_bucket = dp_size * padded # triggers after 1 round + + layout = _LWO._compute_per_buffer_param_layout(params, small_bucket, dp_size, cfg) + + assert len(layout.bucket_indices) == 4 # 8 params / (dp_size per round) = 4 rounds + + # -- bucket alignment -- + + def test_bucket_dp_divisible(self): + """Every bucket total must be divisible by dp_size.""" + for dp_size in [2, 4, 8]: + params = [_make_param((333,)) for _ in range(dp_size * 2)] + cfg = _make_ddp_config() + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + for bucket_start_index, bucket_end_index in layout.bucket_indices: + assert (bucket_end_index - bucket_start_index) % dp_size == 0 + + def test_bucket_global_alignment(self): + """Bucket end must be a multiple of lcm(dp_size, 128).""" + dp_size = 4 + params = [_make_param((333,)) for _ in range(8)] + cfg = _make_ddp_config() + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + divisor = math.lcm(dp_size, 128) + for _, bucket_end_index in layout.bucket_indices: + assert ( + bucket_end_index % 
divisor == 0 + ), f"bucket end {bucket_end_index} not aligned to {divisor}" + + # -- backprop order -- + + def test_backprop_order_in_shards(self): + """Within each shard, params should appear in backprop (reverse model) order.""" + dp_size = 2 + params = [_make_param((128,)) for _ in range(6)] + cfg = _make_ddp_config() + + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + # Group params by shard in backprop (reverse model) order. + shard_params: dict[int, list] = {i: [] for i in range(dp_size)} + for param in reversed(params): + shard_id = _get_shard_for_param(layout, param, dp_size) + param_start_index, _, _ = layout.param_index_map[param] + shard_params[shard_id].append((param_start_index, param)) + + # Within each shard, buffer positions should be increasing in backprop order. + for shard_id, items in shard_params.items(): + param_start_indices = [param_start_index for param_start_index, _ in items] + assert param_start_indices == sorted( + param_start_indices + ), f"shard {shard_id} not in order" + + # -- dp_size=1 -- + + def test_dp_size_1(self): + """With dp_size=1, every param goes to shard 0 (trivially no splitting).""" + params = [_make_param((100,)) for _ in range(5)] + cfg = _make_ddp_config() + layout = _LWO._compute_per_buffer_param_layout(params, None, 1, cfg) + + for param in params: + _assert_param_within_shard(layout, param, 1) + + # -- single param -- + + def test_single_param(self): + """A single param should produce one bucket.""" + dp_size = 4 + params = [_make_param((512,))] + cfg = _make_ddp_config() + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + assert len(layout.bucket_indices) == 1 + _assert_param_within_shard(layout, params[0], dp_size) + + # -- all params assigned -- + + def test_all_params_in_layout(self): + """Every input param appears in param_index_map exactly once.""" + dp_size = 4 + params = [_make_param((numel,)) for numel in [64, 128, 256, 64, 128, 256, 64]] + cfg = 
_make_ddp_config() + layout = _LWO._compute_per_buffer_param_layout(params, None, dp_size, cfg) + + assert set(id(param) for param in layout.param_index_map.keys()) == set( + id(param) for param in params + ) + + +# --------------------------------------------------------------------------- +# Tests for compute_full_param_layout +# --------------------------------------------------------------------------- + + +class TestLayerwiseFullParamLayout: + + def test_basic_full_layout(self): + """End-to-end: params grouped by dtype, then laid out with shard alignment.""" + dp_size = 2 + params = [_make_param((256,)) for _ in range(4)] + cfg = _make_ddp_config() + layout = _LWO.compute_full_param_layout(params, None, dp_size, cfg) + assert len(layout.layouts) == 1 + key = list(layout.layouts.keys())[0] + assert key == BufferKey(torch.bfloat16, torch.float, False) + + def test_expert_parallel_separate_buffer(self): + """Expert-parallel params should be in a separate buffer group.""" + dp_size = 2 + dense = _make_param((256,)) + expert = _make_param((256,), allreduce=False) + cfg = _make_ddp_config() + layout = _LWO.compute_full_param_layout([dense, expert], None, dp_size, cfg) + assert len(layout.layouts) == 2 diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 7bcf21882c1..f21df2db81d 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -121,6 +121,7 @@ class DynamicEngineTestConfig: use_fixed_output_lengths: bool = False num_cuda_graphs: int = None use_cuda_graphs_for_non_decode_steps: bool = True + cuda_graph_all_prefills: bool = False fp8: bool = False model_provider: str = "gpt" return_log_probs: bool = False @@ -256,7 +257,10 @@ def _build_inference_context( inference_config=InferenceConfig( max_sequence_length=test_config.max_sequence_length, num_cuda_graphs=test_config.num_cuda_graphs, - 
use_cuda_graphs_for_non_decode_steps=True, + use_cuda_graphs_for_non_decode_steps=( + test_config.use_cuda_graphs_for_non_decode_steps + ), + cuda_graph_all_prefills=test_config.cuda_graph_all_prefills, buffer_size_gb=test_config.context_buffer_size_gb, paused_buffer_size_gb=test_config.context_paused_buffer_size_gb, block_size_tokens=test_config.context_block_size_tokens, @@ -604,6 +608,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None num_tokens_to_generate = 16 # Run test. + # Force decode-only CG capture: capturing mixed graphs across the full range will OOM. env = self._run_test( num_tokens_to_generate=num_tokens_to_generate, model_provider=model_provider, @@ -611,6 +616,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None cuda_graph_scope=cuda_graph_scope, force_build_cuda_graphs=True, context_max_requests=128, + use_cuda_graphs_for_non_decode_steps=False, ) # Validate max_requests, max_tokens. @@ -783,11 +789,11 @@ def test_fixed_output_lengths(self, model_provider: str) -> None: @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) - def test_cuda_graph_token_counts(self) -> None: + @pytest.mark.parametrize("use_non_decode", [False, True]) + def test_cuda_graph_token_counts(self, use_non_decode: bool) -> None: """Test initialization of `cuda_graph_token_counts` in dynamic context.""" - # Test num_cuda_graphs. 
- for num_cuda_graphs, expected_cuda_graph_token_counts in [ + decode_only_cases = [ (0, [80]), (1, [80]), (2, [80, 40]), @@ -795,21 +801,39 @@ def test_cuda_graph_token_counts(self) -> None: (8, [80, 64, 48, 32, 16]), (16, [80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), (32, [80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), - ]: + ] + non_decode_cases = [ + (0, [80]), + (1, [80]), + (2, [80, 40]), + (4, [80, 72, 48, 24]), + (8, [80, 64, 48, 32, 16]), + (16, [1024, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), + (32, [1024, 512, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), + ] + cases = non_decode_cases if use_non_decode else decode_only_cases + + for num_cuda_graphs, expected_cuda_graph_token_counts in cases: # Build cuda graphs (inside dynamic engine). env = self._build_test_env( DynamicEngineTestConfig( - context_buffer_size_gb=0.01, num_cuda_graphs=num_cuda_graphs + context_buffer_size_gb=0.01, + num_cuda_graphs=num_cuda_graphs, + use_cuda_graphs_for_non_decode_steps=use_non_decode, + cuda_graph_all_prefills=use_non_decode, ) ) actual_cuda_graph_token_counts = env.engine.context.cuda_graph_token_counts - assert ( - actual_cuda_graph_token_counts == expected_cuda_graph_token_counts - ), "num_cuda_graphs %d ... cuda_graph_token_counts: expected %s, found %s." % ( - num_cuda_graphs, - expected_cuda_graph_token_counts, - actual_cuda_graph_token_counts, + assert actual_cuda_graph_token_counts == expected_cuda_graph_token_counts, ( + "num_cuda_graphs %d use_non_decode=%s ... cuda_graph_token_counts: " + "expected %s, found %s." 
+ % ( + num_cuda_graphs, + use_non_decode, + expected_cuda_graph_token_counts, + actual_cuda_graph_token_counts, + ) ) @pytest.mark.internal diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index ce9bb579ee6..005281e3481 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -114,6 +114,77 @@ def test_mixed_token_counts_subset_of_decode(self, num_cuda_graphs): ) +class TestGenerateCUDAGraphEdgeCases: + """Single-process tests for graph generation edge cases.""" + + def test_generate_cuda_graph_edge_cases(self): + """Edge cases in graph generation: + max_tokens > max_requests, small max_tokens, step_size floor, speculative decoding. + """ + + # max_tokens > max_requests: decode graphs capped, prefill graphs span full budget + g_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=1, + num_cuda_graphs=8, + cuda_graph_max_tokens=512, + cuda_graph_mixed_prefill_request_count=MIXED_PREFILL_COUNT, + max_requests=64, + max_tokens=512, + max_sequence_length=4096, + use_cuda_graphs_for_non_decode_steps=True, + ) + decode_graphs = [g for g in g_list if g.prefill_req_count == 0] + prefill_graphs = [g for g in g_list if g.prefill_req_count > 0] + assert all(g.token_count <= 64 for g in decode_graphs) + assert prefill_graphs and max(g.token_count for g in prefill_graphs) > 64 + + # max_tokens < num_cuda_graphs: step_size could round to zero + g_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=1, + num_cuda_graphs=32, + cuda_graph_max_tokens=10, + cuda_graph_mixed_prefill_request_count=0, + max_requests=10, + max_tokens=10, + max_sequence_length=4096, + use_cuda_graphs_for_non_decode_steps=False, + ) + assert len(g_list) > 0 and all(g.token_count > 0 for g in g_list) + + # Step size >= tp_size for various TP sizes + for tp_size in (1, 
2, 4, 8): + g_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=tp_size, + num_cuda_graphs=64, + cuda_graph_max_tokens=tp_size, + cuda_graph_mixed_prefill_request_count=0, + max_requests=tp_size, + max_tokens=tp_size, + max_sequence_length=4096, + use_cuda_graphs_for_non_decode_steps=False, + ) + assert len(g_list) > 0 + for g in g_list: + assert g.token_count % tp_size == 0 + + # Speculative decoding with max_tokens >> max_requests * (spec+1) + for num_spec in (1, 3, 7): + g_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=1, + num_cuda_graphs=8, + cuda_graph_max_tokens=1024, + cuda_graph_mixed_prefill_request_count=MIXED_PREFILL_COUNT, + max_requests=32, + max_tokens=1024, + max_sequence_length=4096, + use_cuda_graphs_for_non_decode_steps=True, + num_speculative_tokens=num_spec, + ) + for g in [g for g in g_list if g.prefill_req_count == 0]: + assert g.token_count == g.decode_req_count * (num_spec + 1) + assert g.decode_req_count <= 32 + + class TestMatchGraphConfigWithEP: """Tests for match_graph_config with expert parallelism. 
diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py index 67cee551a0f..5b253ee1d2e 100644 --- a/tests/unit_tests/models/test_mimo_colocated_communicator.py +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -234,8 +234,8 @@ def test_rank_offset_mismatch(self): "side,dim,expected", [ ("src", "pp", "src PP must be 1"), - ("dest", "pp", "dest PP must be 1"), ("src", "cp", "CP must be 1"), + ("dest", "cp", "CP must be 1"), ], ) def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): @@ -250,6 +250,13 @@ def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): with pytest.raises(ValueError, match=expected): make_comm(src_grid, dest_grid) + def test_dest_pp_gt_one_accepted(self): + # Dest PP>1 is valid: the three-phase colocated schedule handles + # the LLM pipeline orchestration. The bridge only needs src PP=1. + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, pp=2, dp=2) + make_comm(src_grid, dest_grid) + def test_dp_not_divisible(self): # 6-rank grids with DP sizes (3 vs 2) that neither divides the other. # Fits inside an 8-rank world (HyperCommGrid enforces size <= world - offset). 
diff --git a/tests/unit_tests/models/test_mimo_colocated_correctness.py b/tests/unit_tests/models/test_mimo_colocated_correctness.py index e2d91bdf83e..1432a9c839f 100644 --- a/tests/unit_tests/models/test_mimo_colocated_correctness.py +++ b/tests/unit_tests/models/test_mimo_colocated_correctness.py @@ -51,6 +51,7 @@ """ import os +import re from functools import partial import pytest @@ -61,8 +62,10 @@ import megatron.core.pipeline_parallel.schedules as schedule from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp from megatron.core.models.mimo.optimizer import get_mimo_optimizer from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.transformer.enums import ModelType from megatron.core.utils import unwrap_model from tests.unit_tests.models.test_mimo_1f1b_schedule import ( @@ -164,7 +167,7 @@ def _set_deterministic_env(): os.environ.pop('NVTE_UNFUSED_ATTN', None) -def _wire_training_hooks(mimo_model, language_pg, vision_pg): +def _wire_training_hooks(mimo_model, language_pg, vision_pg, llm_grid=None): """Attach no_sync / finalize_grads / grad_scale hooks to a MimoModel. The finalize hook implements the heterogeneous-DP grad-scaling story @@ -185,6 +188,12 @@ def _wire_training_hooks(mimo_model, language_pg, vision_pg): 3. Calls ``scale_gradients(1/N_global)`` on each side — lands the true global per-token mean uniformly on encoder and LLM grads. + ``llm_grid`` is required for LLM PP>1 callers: with PP>1 the inner + schedule only populates ``num_tokens`` on the last LLM PP stage; this + hook broadcasts it from the last PP rank to earlier stages before the + DP all-reduce so every rank arrives at the same ``N_global``. 
+ Pass ``None`` (default) for PP=1, where the broadcast is a no-op. + Note: encoder has no loss_func (so nothing emits a per-encoder-DP ``num_tokens`` to feed ``finalize_model_grads``' internal all-reduce). Doing the all-reduce once ourselves and calling ``scale_gradients`` @@ -193,6 +202,7 @@ def _wire_training_hooks(mimo_model, language_pg, vision_pg): """ no_sync_func = build_no_sync_func(mimo_model) + pp_group = llm_grid.get_pg("pp") if llm_grid is not None else None def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs): # Schedule passes the per-rank sum-across-microbatches of what the @@ -203,6 +213,13 @@ def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs "TransformerConfig so the schedule forwards total_num_tokens; got None." ) + # PP>1: only the last LLM PP stage emits a non-zero num_tokens + # from the loss_func. Broadcast to earlier stages so every rank + # holds the same value before the DP all-reduce below. + if pp_group is not None and pp_group.size() > 1: + last_rank = dist.get_global_rank(pp_group, pp_group.size() - 1) + dist.broadcast(num_tokens, src=last_rank, group=pp_group) + # Phase 1: lift the all-reduce. After this, every rank (including # encoder-only replicas) has N_global = total non-padded tokens in # the global batch. @@ -828,6 +845,173 @@ def _assert_encoder_weights_match(ref_module, dist_module, rtol=1e-3, atol=1e-3) ) +_LLM_LAYER_RX = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') + + +def _llm_pp_remap_name(name, pp_rank, layers_per_stage): + """Remap a dist LLM param name (local layer idx) to its ref PP=1 name (global idx). + + Dist's ``decoder.layers.{local_idx}`` on PP stage ``s`` corresponds to + ref's global layer ``s * layers_per_stage + local_idx``. Non-layer + params (embedding, final_layernorm, output_layer) are present only on + stages that own them and their names match exactly between ref and dist. 
+ """ + m = _LLM_LAYER_RX.match(name) + if not m: + return name + prefix, local_idx_s, suffix = m.groups() + return f"{prefix}{pp_rank * layers_per_stage + int(local_idx_s)}{suffix}" + + +def _copy_llm_params_pp_aware(ref_module, dist_module, pp_rank, pp_size, num_layers): + """Copy LLM params ref (PP=1) → dist (PP>=1) with layer-index remapping. + + Assumes ``dist_llm_tp == ref_llm_tp`` so shards line up 1:1; callers + must verify (the consolidated correctness test only enables LLM PP-aware + copy/oracle when this holds). + """ + assert num_layers % pp_size == 0, ( + f"num_layers={num_layers} not divisible by pp_size={pp_size}; " + f"oracle requires even PP split." + ) + layers_per_stage = num_layers // pp_size + ref_params = dict(ref_module.named_parameters()) + + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + ref_name = _llm_pp_remap_name(name, pp_rank, layers_per_stage) + assert ref_name in ref_params, ( + f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " + f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)} — oracle requires " + f"dist_llm_tp == ref_llm_tp." + ) + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) + + +def _copy_ref_llm_with_tp_and_pp_remap( + ref_module, dist_module, ref_tp_group, dist_tp_group, pp_rank, pp_size, num_layers +): + """Copy ref LLM (PP=1, ``ref_llm_tp``) → dist LLM (PP>=1, ``dist_llm_tp``). + + Combines the PP-aware layer-index remap (from + :func:`_llm_pp_remap_name`) with the TP reshard + (all-gather-across-ref-TP + slice-by-dist-TP). Needed when fan-out + PP>1 forces ``dist_llm_tp != enc_tp`` on a fixed rank count (e.g. + fan-out PP=2 on 8 GPUs). + + Two-phase to avoid cross-PP-stage collectives: + + * Phase 1 — gather full ref params across ``ref_tp_group``. 
Iterates + ``ref_module.named_parameters()``, which is identical on every + rank in the same ``ref_tp_group``, so the all-gather collective is + lockstep regardless of how dist's PP layout splits those ranks. + * Phase 2 — copy from ``full_ref`` into this rank's local dist params + (PP-staged) using the layer-index remap. No collectives. + + The naive "iterate dist_module and all-gather inside the loop" + approach hangs whenever dist's PP split spreads across a ref TP + group: ranks on different dist PP stages iterate different params + and never reach the same all-gather call together. + """ + assert num_layers % pp_size == 0, f"num_layers={num_layers} not divisible by pp_size={pp_size}." + layers_per_stage = num_layers // pp_size + ref_tp_size = dist.get_world_size(ref_tp_group) + dist_tp_rank = dist.get_rank(dist_tp_group) + dist_tp_size = dist.get_world_size(dist_tp_group) + + # Phase 1: gather full ref params across ref_tp_group. Safe because + # every rank in ref_tp_group iterates ref_module.named_parameters() + # in the same order. + full_ref = {} + with torch.no_grad(): + for name, ref_param in ref_module.named_parameters(): + partition_dim = getattr(ref_param, 'partition_dim', -1) + if ref_tp_size <= 1 or partition_dim < 0: + full_ref[name] = ref_param.data.detach().clone() + continue + shards = [torch.empty_like(ref_param.data) for _ in range(ref_tp_size)] + dist.all_gather(shards, ref_param.data.contiguous(), group=ref_tp_group) + full_ref[name] = torch.cat(shards, dim=partition_dim) + + # Phase 2: per-rank local copy into dist's PP-staged params, with + # PP-aware layer-index remap and dist-TP slicing. No collectives. + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + ref_name = _llm_pp_remap_name(name, pp_rank, layers_per_stage) + assert ref_name in full_ref, ( + f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " + f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." 
+ ) + full_weight = full_ref[ref_name] + partition_dim = getattr(dist_param, 'partition_dim', -1) + + if dist_tp_size <= 1 or partition_dim < 0: + # Replicated on dist (or no TP): full ref weight should + # match dist's local shape. + assert full_weight.shape == dist_param.shape, ( + f"Param '{name}' (ref '{ref_name}'): full_ref.shape=" + f"{tuple(full_weight.shape)} != dist.shape=" + f"{tuple(dist_param.shape)} (dist_tp={dist_tp_size}, " + f"partition_dim={partition_dim})" + ) + dist_param.data.copy_(full_weight.to(dist_param.dtype)) + continue + + dist_slice = torch.tensor_split(full_weight, dist_tp_size, dim=partition_dim)[ + dist_tp_rank + ] + assert dist_slice.shape == dist_param.shape, ( + f"Param '{name}' (ref '{ref_name}'): sliced.shape=" + f"{tuple(dist_slice.shape)} != dist.shape=" + f"{tuple(dist_param.shape)} (ref_tp={ref_tp_size}, " + f"dist_tp={dist_tp_size}, partition_dim={partition_dim})" + ) + dist_param.data.copy_(dist_slice.to(dist_param.dtype)) + + +def _assert_llm_weights_match_pp_aware( + ref_module, dist_module, pp_rank, pp_size, num_layers, rtol=1e-2, atol=1e-2 +): + """Assert dist LLM shards match ref (PP=1) via the PP-aware layer-index remap. + + Counterpart to :func:`_copy_llm_params_pp_aware`. Non-layer params + (embedding, final_layernorm, output_layer) only exist on stages that + own them and their names are unchanged between ref and dist. + """ + layers_per_stage = num_layers // pp_size + ref_params = dict(ref_module.named_parameters()) + + mismatches = [] + for name, dist_param in dist_module.named_parameters(): + ref_name = _llm_pp_remap_name(name, pp_rank, layers_per_stage) + assert ref_name in ref_params, ( + f"LLM param '{name}' maps to ref '{ref_name}' which does not exist " + f"(ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)}." 
+ ) + try: + torch.testing.assert_close(dist_param.data, ref_param.data, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((name, ref_name, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n} -> {rn}: {msg}" for n, rn, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} LLM param(s) diverged between " + f"PP>1 dist model and PP=1 reference:\n{details}" + ) + + class _BatchIterator: """Minimal iterator over a pre-generated list of batches.""" @@ -857,17 +1041,39 @@ def _run_forward_backward( seq_length, num_microbatches, ): - """One forward/backward pass through the mimo schedule.""" - return schedule.forward_backward_no_pipelining( - forward_step_func=partial( - forward_step, encoder_grid=enc_grid, llm_grid=llm_grid, encoder_name=encoder_name - ), + """Dispatch to no-pipelining (LLM PP=1) or three-phase (LLM PP>1) schedule. + + PP=1 path uses :func:`forward_step` so per-rank slicing for fan-in/ + fan-out happens at forward-time inside the no-pipelining schedule. + PP>1 path uses :func:`colocated_forward_backward_with_pp`, which + applies the same fan-in/fan-out narrowing internally (encoder side + in ``_slice_for_encoder_dp``, LLM side in ``_build_lm_microbatches``). 
+ """ + pp_size = llm_grid.get_pg("pp").size() if 'pp' in llm_grid.dim_names else 1 + if pp_size <= 1: + return schedule.forward_backward_no_pipelining( + forward_step_func=partial( + forward_step, encoder_grid=enc_grid, llm_grid=llm_grid, encoder_name=encoder_name + ), + data_iterator=_BatchIterator(batches), + model=[mimo_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + pg_collection=language_pg, + ) + + return colocated_forward_backward_with_pp( + mimo_model=mimo_model, data_iterator=_BatchIterator(batches), - model=[mimo_model], num_microbatches=num_microbatches, + encoder_grid=enc_grid, + llm_grid=llm_grid, + encoder_name=encoder_name, seq_length=seq_length, micro_batch_size=micro_batch_size, - forward_only=False, + p2p_communicator=P2PCommunicator(pp_group=llm_grid.get_pg("pp"), config=mimo_model.config), pg_collection=language_pg, ) @@ -919,43 +1125,93 @@ def teardown_method(self): version.parse(torch.__version__) < version.parse("2.3.0"), reason="Requires PyTorch 2.3+" ) @pytest.mark.parametrize( - "enc_tp,enc_dp,llm_tp,llm_dp", [(2, 4, 4, 2), (4, 2, 2, 4)], ids=["fan_in", "fan_out"] + "enc_tp,enc_dp,llm_tp,llm_pp,llm_dp", + [ + (2, 4, 4, 1, 2), # fan-in, PP=1 + (4, 2, 2, 1, 4), # fan-out, PP=1 + (2, 4, 2, 2, 2), # fan-in, PP=2 (dist_llm_tp == enc_tp → LLM weight oracle on) + (4, 2, 1, 2, 4), # fan-out, PP=2 (dist_llm_tp != enc_tp → LLM weight oracle off) + ], + ids=["fan_in_pp1", "fan_out_pp1", "fan_in_pp2", "fan_out_pp2"], ) @pytest.mark.parametrize( "mask_pattern", ["uniform", "asymmetric"], ids=["uniform", "asymmetric"] ) @pytest.mark.parametrize("num_microbatches", [1, 4], ids=["mbs1", "mbs4"]) def test_dist_matches_dp1_reference_post_step_weights( - self, enc_tp, enc_dp, llm_tp, llm_dp, mask_pattern, num_microbatches + self, enc_tp, enc_dp, llm_tp, llm_pp, llm_dp, mask_pattern, num_microbatches ): - """Heterogeneous-DP dist post-step encoder weights match equal-DP 
reference. + """Heterogeneous-(TP/DP/PP) dist post-step weights match equal-DP PP=1 reference. Builds two MimoModels on every rank: - * Dist: the heterogeneous TP/DP config under test, with + * Dist: the heterogeneous TP/DP/PP config under test, with ``calculate_per_token_loss=True`` + custom finalize hook that pure-SUMs DDP and externally divides by ``N_global``. * Ref: equal-DP uniform with ``enc_tp=dist_enc_tp``, - ``enc_dp=dist_enc_dp``, ``llm_tp=dist_enc_tp``, - ``llm_dp=dist_enc_dp`` — bridge is - ``BridgeDirection.EQUAL`` (identity passthrough), and the - encoder TP sharding matches dist's exactly so shards line up - 1:1 for comparison. - - Both models run the same finalize wiring; both DDPs pure-SUM - across their own DP group, then divide uniformly by ``N_global``. - LLM TP differs between the two models, which introduces fp32 TP - accumulation-order drift in the gradient flowing back to the - encoder but does not change the per-token-mean invariant that the - post-step encoder oracle checks. + ``enc_dp=dist_enc_dp``, ``llm_tp=dist_enc_tp``, ``llm_pp=1``, + ``llm_dp=dist_enc_dp`` — bridge is ``BridgeDirection.EQUAL`` + (identity passthrough), and the encoder TP sharding matches + dist's exactly so shards line up 1:1 for comparison. + + For ``llm_pp == 1`` the dist side runs the no-pipelining schedule + with the existing ``forward_step`` (which narrows for fan-in/ + fan-out at forward time). For ``llm_pp > 1`` the dist side runs + :func:`colocated_forward_backward_with_pp` (three-phase: encoder + forward → LLM 1F1B → encoder backward), which applies the same + narrowing internally. Ref always runs no-pipelining (``llm_pp=1``). Reference weights are copied into the distributed model so both start from identical state. One Adam step later, the dist shards - should match the ref shards within fp32 precision. + should match the ref shards within fp32 precision. Oracles: + + * Always: encoder weights, first-layer encoder grads. 
+ * ``llm_pp == 1``: LLM input + LLM logits (TP+DP-gathered, robust + to different LLM TP layouts). + * ``llm_pp > 1`` AND ``dist_llm_tp == enc_tp``: LLM weights via + PP-aware layer-index remap (shards align 1:1 only if dist and + ref share the same LLM TP). + + Fan-out PP>1 with ``dist_llm_tp == enc_tp`` is impossible on 8 + GPUs (``enc_tp * 2 * llm_dp = 8`` and ``llm_dp > enc_dp = 8/enc_tp`` + contradict), so the LLM weight oracle is skipped there — encoder + weight match alone is the gold-standard end-to-end signal because + the encoder's grads pass through the LLM forward + backward. """ if self.world_size != 8: pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + if num_microbatches < llm_pp: + pytest.skip( + f"PP={llm_pp} requires num_microbatches >= {llm_pp}; " f"got {num_microbatches}" + ) + # Wrap the entire test body so we can catch and PRINT any + # exception before pytest's distributed traceback formatter gets + # jumbled across 8 ranks. Without this, NCCL teardown across + # ranks that fail asymmetrically (some raise, some don't) tends + # to SIGABRT before pytest emits per-rank tracebacks. 
+ rank = dist.get_rank() + try: + self._run_test_body( + rank, enc_tp, enc_dp, llm_tp, llm_pp, llm_dp, mask_pattern, num_microbatches + ) + except Exception: + import traceback as _tb + + print( + f"\n=== rank {rank} TEST EXCEPTION ===\n" + f"config: enc_tp={enc_tp} enc_dp={enc_dp} llm_tp={llm_tp} " + f"llm_pp={llm_pp} llm_dp={llm_dp} mbs={num_microbatches} " + f"mask={mask_pattern}\n" + f"{_tb.format_exc()}\n" + f"=== end rank {rank} exception ===\n", + flush=True, + ) + raise + + def _run_test_body( + self, rank, enc_tp, enc_dp, llm_tp, llm_pp, llm_dp, mask_pattern, num_microbatches + ): _set_deterministic_env() torch.use_deterministic_algorithms(True) torch.backends.cudnn.deterministic = True @@ -963,17 +1219,21 @@ def test_dist_matches_dp1_reference_post_step_weights( encoder_name = "images" hidden_size, seq_length, vocab_size = 256, 64, 1000 + num_layers = 2 + # PP-aware param copy/oracle requires layers divisible by pp_size. + assert num_layers % llm_pp == 0 micro_batch_size = 2 # Global batch spans the larger DP side; dist pre-slices per rank - # before forward_step (which further slices encoder/LLM side). + # via _slice_global_batch_for_dist (LLM-DP-sized for fan-in, + # encoder-DP-sized for fan-out). global_batch_size = micro_batch_size * max(enc_dp, llm_dp) # Grids: dist is heterogeneous; ref is equal-DP uniform matching # dist's encoder so the bridge is identity and encoder shards # align 1:1 for direct comparison. 
dist_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) - dist_llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=1, dp=llm_dp) + dist_llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) ref_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) ref_llm_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) create_all_embedding_groups([dist_enc_grid, dist_llm_grid, ref_enc_grid, ref_llm_grid]) @@ -988,14 +1248,14 @@ def test_dist_matches_dp1_reference_post_step_weights( overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True ) - # Build dist first (heterogeneous TP/DP). + # Build dist first (heterogeneous TP/DP/PP). torch.manual_seed(12345) dist_mimo, _, _, dist_language_pg, dist_vision_pg = get_mimo_model( encoder_name=encoder_name, encoder_grid=dist_enc_grid, llm_grid=dist_llm_grid, hidden_size=hidden_size, - num_layers=2, + num_layers=num_layers, vocab_size=vocab_size, seq_len=seq_length, ddp_config=ddp_config, @@ -1007,14 +1267,14 @@ def test_dist_matches_dp1_reference_post_step_weights( dist_mimo.model_type = ModelType.encoder_or_decoder self._mimo_models.append(dist_mimo) - # Reference with equal-DP uniform (enc_tp == llm_tp, enc_dp == llm_dp). + # Reference with equal-DP uniform (enc_tp == llm_tp, enc_dp == llm_dp, PP=1). torch.manual_seed(12345) ref_mimo, _, _, ref_language_pg, ref_vision_pg = get_mimo_model( encoder_name=encoder_name, encoder_grid=ref_enc_grid, llm_grid=ref_llm_grid, hidden_size=hidden_size, - num_layers=2, + num_layers=num_layers, vocab_size=vocab_size, seq_len=seq_length, ddp_config=ddp_config, @@ -1026,25 +1286,57 @@ def test_dist_matches_dp1_reference_post_step_weights( ref_mimo.model_type = ModelType.encoder_or_decoder self._mimo_models.append(ref_mimo) - # Force identical initial state: encoder shards already match - # (same TP layout), so the helper copies shard-to-shard. 
LLM - # shards don't match (ref_llm_tp=enc_tp, dist_llm_tp=llm_tp), so - # the helper all-gathers ref's shards across ref's TP group and - # re-slices for dist's TP group. + # Force identical initial state. Encoder shards already match + # (same TP layout), so the helper copies shard-to-shard. _copy_ref_params_to_dist( ref_mimo.modality_submodules[encoder_name].module, dist_mimo.modality_submodules[encoder_name].module, ref_enc_grid.get_pg("tp"), dist_enc_grid.get_pg("tp"), ) - _copy_ref_params_to_dist( - ref_mimo.language_model.module, - dist_mimo.language_model.module, - ref_llm_grid.get_pg("tp"), - dist_llm_grid.get_pg("tp"), - ) + if llm_pp == 1: + # LLM shards may not match (ref_llm_tp=enc_tp, dist_llm_tp=llm_tp); + # the helper all-gathers ref's shards across ref's TP group and + # re-slices for dist's TP group. + _copy_ref_params_to_dist( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + ref_llm_grid.get_pg("tp"), + dist_llm_grid.get_pg("tp"), + ) + elif llm_tp == enc_tp: + # PP>1 with matching TP: dist's local layers map to a slice of + # ref's global layers and shards align 1:1 post-remap. + _copy_llm_params_pp_aware( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=llm_pp, + num_layers=num_layers, + ) + else: + # PP>1 with mismatched TP (e.g. fan-out PP=2 on 8 GPUs): + # combine TP-reshard (all-gather ref's TP shards, slice for + # dist's TP) with PP-aware layer-index remap. + _copy_ref_llm_with_tp_and_pp_remap( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + ref_llm_grid.get_pg("tp"), + dist_llm_grid.get_pg("tp"), + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=llm_pp, + num_layers=num_layers, + ) - _wire_training_hooks(dist_mimo, dist_language_pg, dist_vision_pg) + # PP>1 dist needs the broadcast-from-last-PP-stage variant of the + # finalize hook so num_tokens lands consistently on every rank. 
+ # Ref is always PP=1 (no broadcast needed). + _wire_training_hooks( + dist_mimo, + dist_language_pg, + dist_vision_pg, + llm_grid=dist_llm_grid if llm_pp > 1 else None, + ) _wire_training_hooks(ref_mimo, ref_language_pg, ref_vision_pg) # Distributed optimizers snapshot current param.data into fp32 master @@ -1061,7 +1353,7 @@ def test_dist_matches_dp1_reference_post_step_weights( dist_optimizer = get_mimo_optimizer(dist_mimo, opt_config) ref_optimizer = get_mimo_optimizer(ref_mimo, opt_config) - # Data: one deterministic global batch, identical on every rank. + # Data: deterministic global batches, identical on every rank. torch.manual_seed(99999) global_batches = _generate_and_broadcast_global_batches( global_mbs=global_batch_size, @@ -1083,16 +1375,23 @@ def test_dist_matches_dp1_reference_post_step_weights( ] ref_per_rank_batch_size = global_batch_size // enc_dp - # Logits capture: hook fires on every microbatch forward. - # Registered before forward/backward, removed right after so the - # hook doesn't leak across the second model's run. - dist_logits, dist_logits_hook = _register_logits_capture(dist_mimo) - ref_logits, ref_logits_hook = _register_logits_capture(ref_mimo) - dist_llm_input, dist_input_hook = _register_llm_input_capture(dist_mimo) - ref_llm_input, ref_input_hook = _register_llm_input_capture(ref_mimo) + # Capture hooks: only meaningful for PP=1 (output_layer / decoder + # captures fire on every microbatch; for PP>1 they fire only on + # specific PP stages of dist, breaking the per-microbatch + # alignment with ref's PP=1 captures). Skip registration for PP>1. 
+ capture_hooks = [] + if llm_pp == 1: + dist_logits, dist_logits_hook = _register_logits_capture(dist_mimo) + ref_logits, ref_logits_hook = _register_logits_capture(ref_mimo) + dist_llm_input, dist_input_hook = _register_llm_input_capture(dist_mimo) + ref_llm_input, ref_input_hook = _register_llm_input_capture(ref_mimo) + capture_hooks = [dist_logits_hook, ref_logits_hook, dist_input_hook, ref_input_hook] + else: + dist_logits = ref_logits = dist_llm_input = ref_llm_input = None try: - # One optimizer step on dist (heterogeneous forward_step slicing). + # One optimizer step on dist (PP=1: no-pipelining + forward_step; + # PP>1: three-phase schedule with internal narrowing). dist_optimizer.zero_grad() _run_forward_backward( mimo_model=dist_mimo, @@ -1115,7 +1414,7 @@ def test_dist_matches_dp1_reference_post_step_weights( "silently zeroed by wrong scaling" ) - # One optimizer step on ref (enc_dp == llm_dp → forward_step skips slicing). + # One optimizer step on ref (always PP=1, equal-DP). ref_optimizer.zero_grad() _run_forward_backward( mimo_model=ref_mimo, @@ -1133,18 +1432,14 @@ def test_dist_matches_dp1_reference_post_step_weights( assert ref_success, "Ref optimizer step failed" assert ref_grad_norm is not None and ref_grad_norm > 0, f"Ref grad_norm={ref_grad_norm}" finally: - dist_logits_hook.remove() - ref_logits_hook.remove() - dist_input_hook.remove() - ref_input_hook.remove() - - # Run all three oracles regardless of individual failures so the - # diff-stats print covers every layer. Order: encoder weights / - # first-layer grads first (tightest — same encoder TP/DP layout - # → shards align 1:1), then LLM logits last (loosest — different - # LLM TP layout drives fp32 accumulation drift). Each oracle - # printed its own min/mean/p95/p99/max before its assertion ran, - # so the user sees the full drift distribution for every test. 
+ for h in capture_hooks: + h.remove() + + # Run all oracles regardless of individual failures so the diff- + # stats print covers every layer. Order: encoder weights / first- + # layer grads first (tightest — same encoder TP/DP layout → shards + # align 1:1), then LLM oracles (looser — different LLM TP layout + # drives fp32 accumulation drift). failures = [] try: @@ -1164,20 +1459,60 @@ def test_dist_matches_dp1_reference_post_step_weights( except AssertionError as e: failures.append(('first_layer_grads', str(e))) - try: - _assert_llm_input_match( - ref_llm_input, dist_llm_input, ref_llm_grid, dist_llm_grid, rtol=1e-3, atol=1e-3 - ) - except AssertionError as e: - failures.append(('llm_input', str(e))) + if llm_pp == 1: + # LLM input + logits oracles use TP+DP all-gather, so they + # work for any LLM TP layout. They expect one capture per + # microbatch, which only PP=1 satisfies. + try: + _assert_llm_input_match( + ref_llm_input, dist_llm_input, ref_llm_grid, dist_llm_grid, rtol=1e-3, atol=1e-3 + ) + except AssertionError as e: + failures.append(('llm_input', str(e))) - try: - _assert_llm_logits_match( - ref_logits, dist_logits, ref_llm_grid, dist_llm_grid, rtol=1e-2, atol=1e-2 - ) - except AssertionError as e: - failures.append(('llm_logits', str(e))) + try: + _assert_llm_logits_match( + ref_logits, dist_logits, ref_llm_grid, dist_llm_grid, rtol=1e-2, atol=1e-2 + ) + except AssertionError as e: + failures.append(('llm_logits', str(e))) + elif llm_tp == enc_tp: + # PP>1 with matching TP: assert LLM weights match ref via + # PP-aware layer-index remap. (LLM forward differs between + # 1F1B and no-pipelining, plus TP shards may accumulate in + # different order; tolerances absorb that drift even in fp32.) 
+ try: + _assert_llm_weights_match_pp_aware( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=llm_pp, + num_layers=num_layers, + rtol=1e-2, + atol=1e-2, + ) + except AssertionError as e: + failures.append(('llm_weights_pp_aware', str(e))) + # else: PP>1 with mismatched TP (fan-out on 8 GPUs). The init copy + # via _copy_ref_llm_with_tp_and_pp_remap aligns starting state, but + # post-step shape comparison would require the same TP-reshard of + # ref's PP=1 weights. Skipped here — encoder weight oracle alone + # is sufficient end-to-end (it requires a working LLM forward + + # backward + bridge for the encoder grads to land correctly). if failures: summary = "\n\n".join(f"== {oracle} ==\n{msg}" for oracle, msg in failures) + # Print before raising so the message lands in stdout even when + # post-test cleanup blows up (NCCL teardown across asymmetric + # pass/fail ranks can SIGABRT before pytest formats the + # traceback). + rank = dist.get_rank() + print( + f"\n=== rank {rank} test_dist_matches failures ===\n" + f"config: enc_tp={enc_tp} enc_dp={enc_dp} llm_tp={llm_tp} " + f"llm_pp={llm_pp} llm_dp={llm_dp} mbs={num_microbatches} mask={mask_pattern}\n" + f"{summary}\n" + f"=== end rank {rank} failures ===\n", + flush=True, + ) raise AssertionError(f"{len(failures)} oracle(s) failed:\n{summary}") diff --git a/tests/unit_tests/test_emerging_optimizers.py b/tests/unit_tests/test_emerging_optimizers.py index cbd2bed5ec6..2eddb60893d 100644 --- a/tests/unit_tests/test_emerging_optimizers.py +++ b/tests/unit_tests/test_emerging_optimizers.py @@ -153,6 +153,30 @@ def create_ddp_model(self, model): TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model ) + def create_ddp_model_for_layerwise(self, model, use_param_layout=False): + """Wrap model in DDP for layer-wise distributed optimizer tests. + + Args: + model: Model to wrap. 
+ use_param_layout: If True, supply DDP a precomputed shard-aligned + ``full_param_layout`` (turns on ``ddp_config.use_distributed_optimizer=True`` + + ``start_param_sync``). If False (default), build DDP without a layout + so ``LayerWiseDistributedOptimizer`` syncs via the legacy + flatten / ``all_gather_v`` / unflatten ``allgather_params()`` codepath. + """ + if use_param_layout: + from megatron.training.training import wrap_model_chunks_with_ddp + + ddp_config = DistributedDataParallelConfig() + wrapped = wrap_model_chunks_with_ddp( + [model], + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config, + use_layer_wise_distributed_optimizer=True, + ) + return wrapped[0] + return self.create_ddp_model(model) + def test_get_megatron_optimizer_smoke(self): """Smoke test for get_megatron_optimizer function.""" model = Net().bfloat16().cuda() @@ -258,7 +282,7 @@ def test_get_megatron_optimizer_layer_wise(self): """Test get_megatron_optimizer with layer-wise distributed optimizer.""" model = Net().bfloat16().cuda() model.requires_grad_(True) - model = self.create_ddp_model(model) + model = self.create_ddp_model_for_layerwise(model) optimizer_config = OptimizerConfig( optimizer='muon', @@ -302,7 +326,7 @@ def test_get_megatron_muon_optimizer_backward_compatible(self): """Test get_megatron_muon_optimizer with backward compatible layer-wise distributed optimizer.""" model = Net().bfloat16().cuda() model.requires_grad_(True) - model = self.create_ddp_model(model) + model = self.create_ddp_model_for_layerwise(model) optimizer_config = OptimizerConfig( optimizer='muon', diff --git a/tests/unit_tests/test_layer_wise_optimizer.py b/tests/unit_tests/test_layer_wise_optimizer.py index 36e5fe11b67..572bdb83c6f 100644 --- a/tests/unit_tests/test_layer_wise_optimizer.py +++ b/tests/unit_tests/test_layer_wise_optimizer.py @@ -81,6 +81,7 @@ def create_model_and_optimizer( model_kwargs=None, use_layer_wise=True, copy_from=None, + use_param_layout=False, ): """Create 
model, DDP wrapper, and optimizer. @@ -90,6 +91,11 @@ def create_model_and_optimizer( model_kwargs: Optional kwargs for model initialization use_layer_wise: If True, use LayerWiseDistributedOptimizer via dist_muon; if False, use standard muon ChainedOptimizer (for reference) + use_param_layout: If True, supply DDP a precomputed shard-aligned + ``full_param_layout`` (turns on ``ddp_config.use_distributed_optimizer=True`` + + ``start_param_sync``). If False (default), build DDP without a layout + so ``LayerWiseDistributedOptimizer`` syncs via the legacy + flatten / ``all_gather_v`` / unflatten ``allgather_params()`` codepath. Returns: tuple: (model, optimizer, pg_collection) @@ -100,10 +106,21 @@ def create_model_and_optimizer( model = model_class(**model_kwargs).bfloat16().cuda() model.requires_grad_(True) - ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=False) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) + if use_param_layout: + from megatron.training.training import wrap_model_chunks_with_ddp + + ddp_config = DistributedDataParallelConfig() + model = wrap_model_chunks_with_ddp( + [model], + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config, + use_layer_wise_distributed_optimizer=use_layer_wise, + )[0] + else: + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=False) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) if copy_from: model.module.load_state_dict(copy_from.module.state_dict()) else: @@ -136,6 +153,7 @@ def create_model_and_optimizer_with_overlap_param_gather( overlap_param_gather=True, grad_reduce_in_fp32=False, bucket_size=None, + use_param_layout=False, ): """Create model, DDP wrapper, and optimizer with overlap-param-gather enabled. 
@@ -151,6 +169,9 @@ def create_model_and_optimizer_with_overlap_param_gather( overlap_param_gather: If True, defer param all-gather to bucket infrastructure grad_reduce_in_fp32: If True, reduce grads in fp32 (regression test for dtype fix) bucket_size: Maximum number of parameters per bucket (None = single bucket) + use_param_layout: If True, supply DDP a precomputed shard-aligned + ``full_param_layout`` (turns on ``ddp_config.use_distributed_optimizer=True`` + + ``start_param_sync``). If False (default), build DDP without a layout. Returns: tuple: (model, optimizer, pg_collection) @@ -161,16 +182,37 @@ def create_model_and_optimizer_with_overlap_param_gather( model = model_class(**model_kwargs).bfloat16().cuda() model.requires_grad_(True) - ddp_config = DistributedDataParallelConfig( - use_distributed_optimizer=False, - overlap_param_gather=True, - overlap_grad_reduce=True, - grad_reduce_in_fp32=grad_reduce_in_fp32, - bucket_size=bucket_size, - ) - model = DistributedDataParallel( - TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model - ) + # overlap_param_gather=True requires bucketing, which only happens when + # overlap_grad_reduce=True (otherwise DDP forces bucket_size=None). Couple + # the two so a caller only has to flip one. 
+ overlap_grad_reduce = overlap_param_gather + + if use_param_layout: + from megatron.training.training import wrap_model_chunks_with_ddp + + ddp_config = DistributedDataParallelConfig( + overlap_param_gather=overlap_param_gather, + overlap_grad_reduce=overlap_grad_reduce, + grad_reduce_in_fp32=grad_reduce_in_fp32, + bucket_size=bucket_size, + ) + model = wrap_model_chunks_with_ddp( + [model], + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config, + use_layer_wise_distributed_optimizer=True, + )[0] + else: + ddp_config = DistributedDataParallelConfig( + use_distributed_optimizer=False, + overlap_param_gather=overlap_param_gather, + overlap_grad_reduce=overlap_grad_reduce, + grad_reduce_in_fp32=grad_reduce_in_fp32, + bucket_size=bucket_size, + ) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) if copy_from: model.module.load_state_dict(copy_from.module.state_dict()) else: @@ -206,9 +248,12 @@ def create_reference_model(self, model): reference_model.load_state_dict(model.module.state_dict()) return reference_model - def test_basic(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_basic(self, use_param_layout): """Test basic LayerWiseDistributedOptimizer initialization and step with bf16.""" - model, optimizer, pg_collection = self.create_model_and_optimizer() + model, optimizer, pg_collection = self.create_model_and_optimizer( + use_param_layout=use_param_layout + ) # Verify basic properties assert optimizer is not None, "Optimizer should not be None" @@ -336,13 +381,16 @@ def test_sharded_state_dict(self): sh_base.replica_id[2] == 0 ), f'Expected DP replica_id to be 0 for layer-wise optimizer, got: {sh_base.replica_id[2]}' - def test_multiple_optimizers(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_multiple_optimizers(self, use_param_layout): """Test LayerWiseDistributedOptimizer with multiple chained optimizers. 
Uses get_megatron_muon_optimizer which produces multiple chained optimizers (muon for 2D weights + adam for 1D biases). Tests allgather with multiple ranks. """ - model, optimizer, pg_collection = self.create_model_and_optimizer() + model, optimizer, pg_collection = self.create_model_and_optimizer( + use_param_layout=use_param_layout + ) ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=False) model = DistributedDataParallel( @@ -436,17 +484,23 @@ def test_bf16_error(self): ): LayerWiseDistributedOptimizer([wrapped_optimizer], lw_config, pg_collection) - def _run_parameter_update_test(self, model_class=SimpleModel): + def _run_parameter_update_test(self, use_param_layout, model_class=SimpleModel): """Helper method to test parameter updates with a given model class. Args: - model_class: Model class to use for testing + use_param_layout: forwarded to create_model_and_optimizer. + model_class: Model class to use for testing. """ - model, optimizer, pg_collection = self.create_model_and_optimizer(model_class=model_class) + model, optimizer, pg_collection = self.create_model_and_optimizer( + model_class=model_class, use_param_layout=use_param_layout + ) # Create reference model and optimizer using the same function reference_model, reference_optimizer, _ = self.create_model_and_optimizer( - model_class=model_class, use_layer_wise=False, copy_from=model + model_class=model_class, + use_layer_wise=False, + copy_from=model, + use_param_layout=use_param_layout, ) # Set same gradients on both models @@ -474,17 +528,19 @@ def _run_parameter_update_test(self, model_class=SimpleModel): for param, ref_param in zip(model.parameters(), reference_model.parameters()): torch.testing.assert_close(param.data, ref_param.data, rtol=1e-5, atol=1e-5) - def test_parameter_updates(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_parameter_updates(self, use_param_layout): """Test LayerWiseDistributedOptimizer actually updates model 
parameters.""" - self._run_parameter_update_test() + self._run_parameter_update_test(use_param_layout=use_param_layout) - def test_parameter_updates_insufficient_parameters(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_parameter_updates_insufficient_parameters(self, use_param_layout): """Test LayerWiseDistributedOptimizer when there are insufficient parameters for all ranks. Uses a tiny model with only 1 layer (2 parameters: weight and bias). This will be insufficient when world size > 2. """ - self._run_parameter_update_test(model_class=TinyModel) + self._run_parameter_update_test(use_param_layout=use_param_layout, model_class=TinyModel) def test_broadcast_vs_allgather(self): """Test LayerWiseDistributedOptimizer allgather code agains broadcast code.""" @@ -524,10 +580,11 @@ def test_broadcast_vs_allgather(self): # ---- Overlap-param-gather tests ---- - def test_overlap_param_gather_basic(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_basic(self, use_param_layout): """Test overlap-param-gather path: init, forward/backward/step, bucket-based param sync.""" - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather() + model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( + use_param_layout=use_param_layout ) assert optimizer is not None, "Optimizer should not be None" @@ -572,15 +629,16 @@ def test_overlap_param_gather_basic(self): msg=f"Parameter {name} differs between rank 0 and rank {i}", ) - def test_overlap_param_gather_parameter_updates(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_parameter_updates(self, use_param_layout): """Test overlap-param-gather produces same parameter updates as standard optimizer.""" - model, optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather() + model, optimizer, pg_collection = 
self.create_model_and_optimizer_with_overlap_param_gather( + use_param_layout=use_param_layout ) # Create reference model with standard (non-layer-wise) optimizer reference_model, reference_optimizer, _ = self.create_model_and_optimizer( - use_layer_wise=False, copy_from=model + use_layer_wise=False, copy_from=model, use_param_layout=use_param_layout ) # Set same gradients on both models @@ -602,7 +660,8 @@ def test_overlap_param_gather_parameter_updates(self): for param, ref_param in zip(model.parameters(), reference_model.parameters()): torch.testing.assert_close(param.data, ref_param.data, rtol=1e-5, atol=1e-5) - def test_overlap_param_gather_vs_sync_allgather(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_vs_sync_allgather(self, use_param_layout): """Key correctness test: overlap path and sync allgather produce identical updates. Compares: @@ -611,12 +670,14 @@ def test_overlap_param_gather_vs_sync_allgather(self): """ # Create overlap model overlap_model, overlap_optimizer, pg_collection = ( - self.create_model_and_optimizer_with_overlap_param_gather(overlap_param_gather=True) + self.create_model_and_optimizer_with_overlap_param_gather( + overlap_param_gather=True, use_param_layout=use_param_layout + ) ) # Create sync model with same weights (overlap_param_gather=True but sync allgather) sync_model, sync_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=overlap_model + overlap_param_gather=False, copy_from=overlap_model, use_param_layout=use_param_layout ) # Verify initial parameters match @@ -722,18 +783,22 @@ def test_overlap_param_gather_vs_standard_ddp(self): msg="Overlap-param-gather and standard paths produced different updates", ) - def test_overlap_param_gather_insufficient_parameters(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_insufficient_parameters(self, use_param_layout): 
"""Test overlap-param-gather with TinyModel (only 2 params). Many ranks will have no assigned params when world_size > 2. """ model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - model_class=TinyModel + model_class=TinyModel, use_param_layout=use_param_layout ) # Create reference model with standard (non-layer-wise) optimizer reference_model, reference_optimizer, _ = self.create_model_and_optimizer( - model_class=TinyModel, use_layer_wise=False, copy_from=model + model_class=TinyModel, + use_layer_wise=False, + copy_from=model, + use_param_layout=use_param_layout, ) # Set same gradients on both models @@ -793,7 +858,8 @@ def test_overlap_param_gather_broadcast_vs_allgather(self): for param, ref_param in zip(model.parameters(), reference_model.parameters()): torch.testing.assert_close(param.data, ref_param.data, rtol=0, atol=0) - def test_overlap_param_gather_multi_iteration(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_multi_iteration(self, use_param_layout): """Test overlap-param-gather correctness over multiple training iterations. Runs multiple forward/backward/step iterations using the async allgather path. @@ -801,12 +867,12 @@ def test_overlap_param_gather_multi_iteration(self): model using the sync path. 
""" model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True + overlap_param_gather=True, use_param_layout=use_param_layout ) # Create reference model with sync allgather for comparison ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model + overlap_param_gather=False, copy_from=model, use_param_layout=use_param_layout ) for iteration in range(3): @@ -841,17 +907,18 @@ def test_overlap_param_gather_multi_iteration(self): msg=f"Parameters diverged at iteration {iteration}", ) - def test_overlap_param_gather_async_dispatch_and_finish(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_async_dispatch_and_finish(self, use_param_layout): """Test async dispatch + finish_param_sync cycle (the actual runtime path). start_param_sync() (no force_sync) dispatches async all-gathers, then finish_param_sync() waits on the handle and unflattens gathered params. """ model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True + overlap_param_gather=True, use_param_layout=use_param_layout ) ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model + overlap_param_gather=False, copy_from=model, use_param_layout=use_param_layout ) # Set identical gradients on both models @@ -900,14 +967,15 @@ def test_overlap_param_gather_async_dispatch_and_finish(self): msg=f"Parameter {name} differs between rank 0 and rank {i}", ) - def test_overlap_param_gather_finish_chains_next_bucket(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_finish_chains_next_bucket(self, use_param_layout): """Test that finish_param_sync() dispatches next_param_gather_bucket_group. 
Uses a small bucket_size to force multiple bucket groups, then dispatches only the last bucket group and verifies that finishing it chains to the next. """ model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True, bucket_size=2000 + overlap_param_gather=True, bucket_size=2000, use_param_layout=use_param_layout ) bucket_groups = model.bucket_groups @@ -915,7 +983,10 @@ def test_overlap_param_gather_finish_chains_next_bucket(self): pytest.skip("Need multiple bucket groups to test chaining") ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model, bucket_size=2000 + overlap_param_gather=False, + copy_from=model, + bucket_size=2000, + use_param_layout=use_param_layout, ) # Set identical gradients on both models @@ -961,17 +1032,18 @@ def test_overlap_param_gather_finish_chains_next_bucket(self): msg="Chained bucket finish produced different params than sync path", ) - def test_overlap_param_gather_forward_pre_hook(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_forward_pre_hook(self, use_param_layout): """Test forward pre-hooks trigger finish_param_sync during model(input). After async dispatch, running model(input) fires forward pre-hooks that call finish_param_sync() on each bucket group, completing the param sync. 
""" model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True + overlap_param_gather=True, use_param_layout=use_param_layout ) ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model + overlap_param_gather=False, copy_from=model, use_param_layout=use_param_layout ) # Set identical gradients on both models @@ -1002,7 +1074,8 @@ def test_overlap_param_gather_forward_pre_hook(self): msg="Forward pre-hook path produced different params than sync path", ) - def test_overlap_param_gather_grad_reduce_in_fp32(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_grad_reduce_in_fp32(self, use_param_layout): """Regression test: grad_reduce_in_fp32 must not cause dtype mismatch in broadcasts. When grad_reduce_in_fp32=True, the grad buffer dtype is fp32 but broadcast @@ -1010,10 +1083,13 @@ def test_overlap_param_gather_grad_reduce_in_fp32(self): would cause a dtype mismatch error in the per-rank broadcast calls. 
""" model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True, grad_reduce_in_fp32=True + overlap_param_gather=True, grad_reduce_in_fp32=True, use_param_layout=use_param_layout ) ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model, grad_reduce_in_fp32=True + overlap_param_gather=False, + copy_from=model, + grad_reduce_in_fp32=True, + use_param_layout=use_param_layout, ) # Set identical gradients on both models @@ -1040,17 +1116,18 @@ def test_overlap_param_gather_grad_reduce_in_fp32(self): msg="grad_reduce_in_fp32 path produced different params than reference", ) - def test_overlap_param_gather_hook_enable_disable_cycle(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_hook_enable_disable_cycle(self, use_param_layout): """Test the training loop's hook lifecycle: disable → manual sync → enable → forward. The training loop disables hooks before iteration 1 (for initialization), then enables them for subsequent iterations. This test exercises that cycle. 
""" model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True + overlap_param_gather=True, use_param_layout=use_param_layout ) ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model + overlap_param_gather=False, copy_from=model, use_param_layout=use_param_layout ) input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') @@ -1102,7 +1179,8 @@ def test_overlap_param_gather_hook_enable_disable_cycle(self): msg="Params diverged after iteration 2 (hooks re-enabled)", ) - def test_overlap_param_gather_multi_iteration_with_hooks(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_multi_iteration_with_hooks(self, use_param_layout): """Test multiple iterations using forward pre-hooks (not manual force_sync). Runs 3 iterations where each iteration uses: set grads → step → async dispatch → @@ -1110,10 +1188,10 @@ def test_overlap_param_gather_multi_iteration_with_hooks(self): allgather after each iteration. 
""" model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True + overlap_param_gather=True, use_param_layout=use_param_layout ) ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model + overlap_param_gather=False, copy_from=model, use_param_layout=use_param_layout ) input_tensor = torch.randn(16, 80, dtype=torch.bfloat16, device='cuda') @@ -1144,7 +1222,8 @@ def test_overlap_param_gather_multi_iteration_with_hooks(self): msg=f"Parameters diverged at iteration {iteration}", ) - def test_overlap_param_gather_start_sync_with_autograd(self): + @pytest.mark.parametrize('use_param_layout', [False, True]) + def test_overlap_param_gather_start_sync_with_autograd(self, use_param_layout): """Regression test: start_param_sync must work when autograd is active. _flatten_dense_tensors on params with requires_grad=True produces a tensor @@ -1160,10 +1239,10 @@ def test_overlap_param_gather_start_sync_with_autograd(self): next bucket group. """ model, optimizer, pg_collection = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=True + overlap_param_gather=True, use_param_layout=use_param_layout ) ref_model, ref_optimizer, _ = self.create_model_and_optimizer_with_overlap_param_gather( - overlap_param_gather=False, copy_from=model + overlap_param_gather=False, copy_from=model, use_param_layout=use_param_layout ) # Confirm params require grad (the precondition for this bug). 
diff --git a/uv.lock b/uv.lock index 890fbb29034..482e3b1df79 100644 --- a/uv.lock +++ b/uv.lock @@ -2759,7 +2759,7 @@ requires-dist = [ { name = "torch", specifier = ">=2.6.0" }, { name = "tqdm", marker = "extra == 'dev'" }, { name = "tqdm", marker = "extra == 'lts'" }, - { name = "transformer-engine", extras = ["core-cu13", "pytorch"], marker = "extra == 'te'", git = "https://github.com/NVIDIA/TransformerEngine.git?rev=f031cf87bd054c7558b887df7bed93975456667f" }, + { name = "transformer-engine", extras = ["core-cu13", "pytorch"], marker = "extra == 'te'", git = "https://github.com/NVIDIA/TransformerEngine.git?rev=42b840051647eef89761a16dfdff87e82bb253ab" }, { name = "transformers", marker = "extra == 'mlm'" }, { name = "transformers", marker = "extra == 'training'" }, { name = "wandb", marker = "extra == 'mlm'" }, @@ -6749,8 +6749,8 @@ wheels = [ [[package]] name = "transformer-engine" -version = "2.14.0+f031cf87" -source = { git = "https://github.com/NVIDIA/TransformerEngine.git?rev=f031cf87bd054c7558b887df7bed93975456667f#f031cf87bd054c7558b887df7bed93975456667f" } +version = "2.15.0+42b84005" +source = { git = "https://github.com/NVIDIA/TransformerEngine.git?rev=42b840051647eef89761a16dfdff87e82bb253ab#42b840051647eef89761a16dfdff87e82bb253ab" } dependencies = [ { name = "einops" }, { name = "importlib-metadata" },