From 2b267a7335179fc07b786e794f84350f01582c1f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Nov 2025 10:20:26 -0800 Subject: [PATCH 01/44] initial test --- gpt_builders.py | 3 +- mamba_builders.py | 3 +- .../data_parallel_inference_coordinator.py | 2 +- megatron/core/model_refitting.py | 166 +++++++++ megatron/core/transformer/cuda_graphs.py | 2 +- megatron/core/transformer/utils.py | 4 +- megatron/rl/inference/megatron.py | 17 +- megatron/rl/rl_utils.py | 38 ++- megatron/training/arguments.py | 7 + megatron/training/training.py | 73 +++- model_provider.py | 4 +- .../inference/test_nccl_model_swap.py | 320 ++++++++++++++++++ 12 files changed, 613 insertions(+), 26 deletions(-) create mode 100644 megatron/core/model_refitting.py create mode 100644 tests/unit_tests/inference/test_nccl_model_swap.py diff --git a/gpt_builders.py b/gpt_builders.py index 89b228815ff..bd676c2ad9b 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -20,7 +20,7 @@ # NOTE: Loading `megatron.legacy.model` earlier fails due to circular import -def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): +def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): print_rank_0('building GPT model ...') if config is None: if args.yaml_cfg is not None: @@ -89,6 +89,7 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): rope_scaling=args.use_rope_scaling, mtp_block_spec=mtp_block_spec, vp_stage=vp_stage, + pg_collection=pg_collection, ) return model diff --git a/mamba_builders.py b/mamba_builders.py index 0ccfc29b86c..53d675bc3cc 100644 --- a/mamba_builders.py +++ b/mamba_builders.py @@ -8,7 +8,7 @@ from megatron.training.arguments import core_transformer_config_from_args -def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None): +def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): print_rank_0('building MAMBA model ...') if config is None: config = core_transformer_config_from_args(args, TransformerConfig) @@ -35,6 +35,7 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None): position_embedding_type=args.position_embedding_type, rotary_percent=args.rotary_percent, rotary_base=args.rotary_base, + pg_collection=pg_collection, ) for l in range(model.decoder.num_layers_per_pipeline_rank): diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index ea0560183d8..8644acaf3cc 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -23,7 +23,7 @@ except: HAVE_MSGPACK = False - +#TODO We need to see where the process group collection is used. class DataParallelInferenceCoordinator: """ Coordinates inference requests between clients and distributed model engines. diff --git a/megatron/core/model_refitting.py b/megatron/core/model_refitting.py new file mode 100644 index 00000000000..33c50d79a47 --- /dev/null +++ b/megatron/core/model_refitting.py @@ -0,0 +1,166 @@ +from megatron.core.models.common.language_module.language_module import LanguageModule +import torch +import torch.distributed as dist +from typing import Any +from megatron.core import parallel_state +from mcore_reshard import reshard_with_general_planner +from typing import Any, Optional + + + +def _unwrap_module(module: LanguageModule) -> Any: + return module.module.module if hasattr(module, 'module') and hasattr(module.module, 'module') else module.module if hasattr(module, 'module') else module + + +def _move_module(module: Any, device: torch.device | str) -> None: + for p in module.parameters(recurse=True): + if p is not None and p.data is not None: + p.data = p.data.to(device, non_blocking=True) + if p is not None and p._grad is not None: + p._grad = p._grad.to(device, non_blocking=True) + for buf_name, buf in module._buffers.items(): # type: ignore[attr-defined] + if buf is not None: + module._buffers[buf_name] = buf.to(device, non_blocking=True) # type: ignore[index] + + + +def naive_model_swap(src_model: LanguageModule, target_model: LanguageModule): + print(f"Swapping train to inference model") + # Handle list-wrapped modules used throughout training utils + src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model + target_lm = target_model[0] if isinstance(target_model, (list, tuple)) else target_model + + # Unwrap possible precision/wrapper modules to reach the module that owns parameters + src_model = _unwrap_module(src_lm) + target_model = _unwrap_module(target_lm) + + + src_tp_size = dist.get_world_size(src_model.pg_collection.tp) + target_tp_size = dist.get_world_size(target_model.pg_collection.tp) + src_tp_group = src_model.pg_collection.tp + target_tp_group = target_model.pg_collection.tp + + # Build name->param map for inference module + infer_params = {name: p for name, p in target_model.named_parameters(recurse=True)} + + + # Utility: reconstruct global tensor from TP sharded local tensors + def _gather_master_from_training(local_shard: torch.Tensor, dim: int, stride: int, group) -> torch.Tensor: + world_size = dist.get_world_size(group=group) + if world_size == 1: + return local_shard.detach().clone() + # Gather shards from all TP ranks + gather_list = [torch.empty_like(local_shard) for _ in range(world_size)] + dist.all_gather(gather_list, local_shard.contiguous(), group=group) + if stride == 1: + return torch.cat(gather_list, dim=dim).contiguous() + # Strided partition: split each shard into stride chunks and interleave + per_part = local_shard.size(dim) + assert per_part % stride == 0, "Local shard size must be divisible by stride" + per_stride = per_part // stride + blocks: list[torch.Tensor] = [None] * (world_size * stride) # type: ignore[assignment] + for r in range(world_size): + chunks = torch.split(gather_list[r], per_stride, dim=dim) + assert len(chunks) == stride, "Unexpected number of stride chunks" + for i in range(stride): + blocks[r + i * world_size] = chunks[i] + return torch.cat(blocks, dim=dim).contiguous() + + # Utility: shard master tensor to the target inference layout + def _shard_master_to_infer(master: torch.Tensor, dim: int, stride: int, group) -> torch.Tensor: + world_size = dist.get_world_size(group=group) + if world_size == 1: + return master + rank = dist.get_rank(group=group) + full = master + assert full.size(dim) % world_size == 0, "Master size not divisible by TP world size" + per_part = full.size(dim) // world_size + assert per_part % stride == 0, "Per-part size must be divisible by stride" + per_stride = per_part // stride + weight_list = torch.split(full, per_stride, dim=dim) + # Pick this rank's stride segments and concatenate along partition dim + my_chunks = weight_list[rank::world_size][:stride] + return torch.cat(my_chunks, dim=dim).contiguous() + + # Perform transfer + with torch.no_grad(): + # Simple same-TP copy path for models with the same TP size + if src_tp_size == target_tp_size: + for name, src in src_model.named_parameters(recurse=True): + if name not in infer_params: + raise ValueError(f"Parameter {name} in training model not found in inference model") + dst = infer_params[name] + if src.shape == dst.shape: + dst.copy_(src) + else: + raise ValueError(f"Parameter {name} in training model has different shape than in inference model") + return + + # General reshard path + for name, src in src_model.named_parameters(recurse=True): + dst = infer_params.get(name, None) + if dst is None: + raise ValueError(f"Parameter {name} in training model not found in inference model") + + # Non-TP params: direct copy + is_tp = bool(getattr(src, 'tensor_model_parallel', False)) + if not is_tp: + if src.shape != dst.shape: + # If shapes differ unexpectedly for non-TP, try to broadcast master (rank 0) value + #dst.copy_(src.detach().clone().to(dst.dtype)) + raise ValueError(f"Parameter {name} in training model has different shape than in inference model") + else: + dst.copy_(src) + continue + + # Resolve sharding attributes and groups + dim = int(getattr(src, 'partition_dim', 0)) + stride = int(getattr(src, 'partition_stride', 1)) + # Gather training shards -> master + master = _gather_master_from_training(src, dim=dim, stride=stride, group=src_tp_group) + + # Use dst's own TP attributes for target stride/dim if present + #TODO when do these just match the the original expecailly dim? + target_dim = int(getattr(dst, 'partition_dim', dim)) + target_stride = int(getattr(dst, 'partition_stride', stride)) + local_target = _shard_master_to_infer(master, dim=target_dim, stride=target_stride, group=target_tp_group) + # Cast and copy + if local_target.dtype != dst.dtype: + raise ValueError(f"Parameter {name} in training model has different dtype than in inference model") + # local_target = local_target.to(dst.dtype) + dst.copy_(local_target) + print(f"finished Swapped train to inference model") + + +def swap_model_weights(src_model: LanguageModule, target_model: LanguageModule, refit_method: str): + if refit_method == "naive": + naive_model_swap(src_model, target_model) + elif refit_method == "nccl": + nccl_model_swap(src_model, target_model) + else: + raise ValueError(f"Invalid refit method: {refit_method}") + +def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): + # Handle list-wrapped modules used throughout training utils + src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model + tgt_lm = target_model[0] if isinstance(target_model, (list, tuple)) else target_model + + # Unwrap to get owning modules (with parameters and pg_collection) + src_core = _unwrap_module(src_lm) + tgt_core = _unwrap_module(tgt_lm) + + # Ensure pg_collection exists + if not hasattr(src_core, "pg_collection") or src_core.pg_collection is None: + raise RuntimeError("Source model missing pg_collection required for NCCL reshard") + if not hasattr(tgt_core, "pg_collection") or tgt_core.pg_collection is None: + raise RuntimeError("Target model missing pg_collection required for NCCL reshard") + + #TODO(Peter): We should figure out why this happens. + # Fill missing DP group on the source using Megatron's parallel state if not provided + if getattr(src_core.pg_collection, "dp", None) is None: + src_core.pg_collection.dp = parallel_state.get_data_parallel_group() + # caching plan for reuse + cached_plan: Optional[Any] = getattr(tgt_core, "_cached_reshard_plan", None) + plan = reshard_with_general_planner(src_core, tgt_core, cached_plan=cached_plan) + if cached_plan is None: + setattr(tgt_core, "_cached_reshard_plan", plan) \ No newline at end of file diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index f75eff7399a..63c23c2836e 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -58,7 +58,7 @@ logger = logging.getLogger(__name__) - +#TODO(Peter) We have changes needed in this for refit to work properly. def is_graph_capturing(): """Query if currently capturing.""" global _IS_GRAPH_CAPTURING diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py index ac00e6557cf..4b9a0ec9e22 100644 --- a/megatron/core/transformer/utils.py +++ b/megatron/core/transformer/utils.py @@ -29,13 +29,13 @@ def get_linear_layer(rows, columns, init_method, perform_initialization=True): return layer -@lru_cache(maxsize=32) +#@lru_cache(maxsize=32) def get_default_causal_mask(sq: int) -> torch.Tensor: """Return the causal upper triangular mask for softmax input.""" return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() -@lru_cache(maxsize=32) +#@lru_cache(maxsize=32) def get_sliding_window_causal_mask(sq, skv, window_size): """Create the equivalent attention mask for SWA in [sq, skv] shape""" m = torch.ones(sq, skv, dtype=torch.bool, device="cuda") diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index 4c739b709c1..02c5847b5e2 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -100,6 +100,20 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule) -> Abst module = model.module.module if hasattr(model.module, "module") else model.module + + # DynamicInferenceContext must use the inference model's TP size, not the + # training TP size from global args. The inference model may have a custom + # ProcessGroupCollection with a different TP size. + pg_collection = get_attr_wrapped_model(model, "pg_collection") + tp_group = getattr(pg_collection, 'tp', None) if pg_collection is not None else None + if tp_group is not None: + inference_tp_size = dist.get_world_size(group=tp_group) + elif getattr(args, 'rl_inference_tensor_model_parallel_size', None) is not None: + inference_tp_size = args.rl_inference_tensor_model_parallel_size + else: + inference_tp_size = args.tensor_model_parallel_size + + # Inference context. inference_context = DynamicInferenceContext( params_dtype=args.params_dtype, @@ -116,7 +130,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule) -> Abst buffer_overflow_factor=args.inference_dynamic_batching_buffer_overflow_factor, max_requests_override=args.inference_dynamic_batching_max_requests_override, max_tokens_override=args.inference_dynamic_batching_max_tokens_override, - tensor_model_parallel_size=args.tensor_model_parallel_size, + tensor_model_parallel_size=inference_tp_size, materialize_only_last_token_logits=True, unified_memory_kvcache=args.inference_dynamic_batching_unified_memory_kvcache, is_hybrid_model=args.is_hybrid_model, @@ -197,6 +211,7 @@ async def launch(cls, model: GPTModel, **kwargs): ) inference_engine: DynamicInferenceEngine = get_dynamic_inference_engine(args, model) + # TODO(Peter) We need to pass the pg_collection to the coordinator, but like where is the coordinator even defined coordinator = DynamicEngineCoordinator( inference_engine, inference_max_requests=inference_engine.context.max_requests, diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index c0992778d57..7e9a2c39418 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -587,12 +587,13 @@ def get_rollout_generator(args, inference_interface, n_prompts, samples_per_grou def get_environment_rollouts( - model: LanguageModule, optimizer: MegatronOptimizer, n_prompts: int, samples_per_group: int + model: LanguageModule, inference_model: LanguageModule, optimizer: MegatronOptimizer, n_prompts: int, samples_per_group: int ): """Sample environment rollouts from an LLM. Args: model: Model to sample from. + inference_model: Inference model to use for inference. n_prompts: Number of prompts to sample for across *all* data parallel workers. samples_per_group: Amount of trajectories per prompt. @@ -602,6 +603,13 @@ def get_environment_rollouts( args = get_args() nvtx_range = get_nvtx_range() + # If we have seperate training and inference models we to refit weights from the training model to the inference model. + if inference_model is not None: + swap_train_to_inference_model(model, inference_model, args.refit_method) + else: + inference_model = model + + #TODO(peter): We need to get the models process group collection and use that for these checks assert ( n_prompts % mpu.get_expert_data_parallel_world_size() == 0 ), "n_prompts must be divisible by data_parallel_world_size" @@ -609,7 +617,7 @@ def get_environment_rollouts( with nvtx_range("rollout-collection"): loop = get_event_loop() with megatron_rl_inference_mode( - model, + inference_model, optimizer, args.enable_cuda_graph, args.rl_reset_cuda_graphs, @@ -649,6 +657,7 @@ def get_environment_rollouts( torch.distributed.broadcast_object_list(rollouts, src=0) print(f"Got rollouts on rank {rank}") + #TODO(Peter): We need to use the proper models MPU here. if lang_rl_log_dir and rank == get_tensor_model_parallel_src_rank(): with open( lang_rl_log_dir @@ -1930,6 +1939,7 @@ def prepare_data_for_update( def get_rollout_data_iterator( model: LanguageModule, + inference_model: LanguageModule | None, optimizer: MegatronOptimizer, iteration: int, ref_state_dict: Dict[str, torch.Tensor], @@ -1939,7 +1949,7 @@ def get_rollout_data_iterator( tokenizer = get_tokenizer() buffered_rollouts = get_environment_rollouts( - model, optimizer, args.grpo_prompts_per_step, args.grpo_group_size + model, inference_model, optimizer, args.grpo_prompts_per_step, args.grpo_group_size ) buffered_rollouts = prepare_data_for_update(model, ref_state_dict, buffered_rollouts, tokenizer) @@ -1948,6 +1958,7 @@ def get_rollout_data_iterator( def setup_grpo_data_iterator( model: LanguageModule, + inference_model: LanguageModule | None, optimizer: MegatronOptimizer, iteration: int, ref_state_dict: Dict[str, torch.Tensor], @@ -1968,13 +1979,18 @@ def setup_grpo_data_iterator( """ args = get_args() + if inference_model is not None: + inference_mpu = inference_model.pg_collection + else: + inference_mpu = mpu + # We collect new rollouts when we've gone over the collected data 'grpo_iterations' times. if ( iteration % (args.grpo_iterations * ((args.grpo_samples_per_iteration) // args.global_batch_size)) == 0 ): - buffered_rollouts = get_rollout_data_iterator(model, optimizer, iteration, ref_state_dict) + buffered_rollouts = get_rollout_data_iterator(model, inference_model, optimizer, iteration, ref_state_dict) # Reset packing step counter when new rollouts are collected runtime_state = get_rl_runtime_state() @@ -2006,7 +2022,7 @@ def setup_grpo_data_iterator( if bin_idx.item() < len(my_bin_seq_indices) ) # Estimate global sequences for this step - est_global_sequences = step_sequences * mpu.get_data_parallel_world_size() + est_global_sequences = step_sequences * inference_mpu.get_data_parallel_world_size() print_rank_0( f"[Sequence Packing] Optimizer step {plan['current_step']}/{plan['total_steps']}: " f"processing {len(step_bin_indices)} bins (~{est_global_sequences} sequences globally)" @@ -2366,9 +2382,21 @@ def get_sequence_packing_tensorboard_metrics(args): """Get tensorboard metrics for sequence packing mode.""" metrics = {} if args.consumed_train_bins > 0: + # TODO(Peter) We need to use the proper models MPU for refitting. If you forget you probably need to change this all over this + # file bin_batch_size = ( mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() ) metrics['bin-batch-size'] = bin_batch_size metrics['consumed-bins'] = args.consumed_train_bins return metrics + +def swap_train_to_inference_model(train_model: LanguageModule, inference_model: LanguageModule, refit_method: str): + """Swap the train model to the inference model. + + Args: + train_model: The train model to swap to the inference model. + inference_model: The inference model to swap to the train model. + """ + from megatron.core.model_refitting import swap_model_weights + swap_model_weights(train_model, inference_model, refit_method) \ No newline at end of file diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5d1bc2e40a3..ee9427bae68 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1972,6 +1972,13 @@ def _add_rl_args(parser): help='Algorithm for distributing packed bins across ranks. ' 'fifo: first-in-first-out sequential distribution, ' 'round-robin: distribute bins cyclically across ranks for better load balancing') + group.add_argument('--rl-inference-tensor-model-parallel-size', type=int, default=None, + help='Degree of tensor model parallelism for inference for RL.') + group.add_argument('--refit-method', type=str, default='naive', + choices=['naive', 'nccl'], + help=('Method to refit the model weights between training and inference models during RL. ' + 'naive: naive method to refit the model weights between training and inference models during RL. ' + 'nccl: use NCCLCopyService to refit the model weights between training and inference models during RL.')) return parser def _add_training_args(parser): diff --git a/megatron/training/training.py b/megatron/training/training.py index eb2c89fec9b..816ca435c33 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -53,6 +53,9 @@ get_model_config, StragglerDetector, ) +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.core.fp8_utils import correct_amax_history_if_needed from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import save_checkpoint @@ -671,6 +674,46 @@ def pretrain( print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') config = get_model_config(model[0]) + # Build a separate inference model for RL if requested. + inference_model = None + if args.perform_rl_step: + pg_collection = None + if args.rl_inference_tensor_model_parallel_size is not None: + print_rank_0(f"Setting tensor model parallel size to {args.rl_inference_tensor_model_parallel_size} for inference model") + # Build custom process groups for inference with a different TP size, keeping CP and PP the same as training + tp_size = args.rl_inference_tensor_model_parallel_size + cp_size = mpu.get_context_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + dp_size = args.world_size // (tp_size * cp_size * pp_size) + assert dp_size >= 1 and (tp_size * cp_size * pp_size * dp_size) == args.world_size, \ + "World size must be divisible by tp*cp*pp for inference PG layout" + + grid = HyperCommGrid([tp_size, cp_size, 1, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + tp_group = grid.create_pg("tp") + cp_group = grid.create_pg("cp") + pp_group = grid.create_pg("pp") + ep_group = grid.create_pg("ep") + dp_group = grid.create_pg("dp") + embd_group_ranks = mpu.default_embedding_ranks( + torch.distributed.get_process_group_ranks(pp_group) + ) + embd_group = torch.distributed.new_group(ranks=embd_group_ranks) + inference_pg_collection = ProcessGroupCollection(tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group, embd=embd_group, dp=dp_group) + + # Build an isolated inference config so training config remains unchanged + inference_config = copy.deepcopy(config) + inference_config.tensor_model_parallel_size = args.rl_inference_tensor_model_parallel_size + + inference_model = get_model( + model_provider, + model_type, + wrap_with_ddp=False, + pg_collection=inference_pg_collection, + config=inference_config, + ) + inference_model[0].eval() + + # Data stuff. app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms() timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True) @@ -745,6 +788,7 @@ def pretrain( config, checkpointing_context, non_loss_data_func, + inference_model, ) print_datetime('after training is done') @@ -850,31 +894,34 @@ def update_train_iters(args): print_rank_0(f'setting training iterations to {args.train_iters}') -def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): +def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True, pg_collection=None, config=None): """Build the model.""" args = get_args() args.model_type = model_type + if pg_collection is None: + pg_collection = mpu + # Build model. def build_model(): if ( - mpu.get_pipeline_model_parallel_world_size() > 1 + pg_collection.get_pipeline_model_parallel_world_size() > 1 and args.virtual_pipeline_model_parallel_size is not None ): model = [] for i in range(args.virtual_pipeline_model_parallel_size): # Set pre_process and post_process only after virtual rank is set. - pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) - post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) + pre_process = pg_collection.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) + post_process = pg_collection.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) this_model = model_provider_func( - pre_process=pre_process, post_process=post_process, vp_stage=i) + pre_process=pre_process, post_process=post_process, vp_stage=i, pg_collection=pg_collection, config=config) this_model.model_type = model_type this_model.vp_stage = i model.append(this_model) else: - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - model = model_provider_func(pre_process=pre_process, post_process=post_process) + pre_process = pg_collection.is_pipeline_first_stage() + post_process = pg_collection.is_pipeline_last_stage() + model = model_provider_func(pre_process=pre_process, post_process=post_process, pg_collection=pg_collection, config=config) model.model_type = model_type return model @@ -893,18 +940,19 @@ def build_model(): # are set for all params so the optimizer can use them. for model_module in model: for param in model_module.parameters(): + #TODO(Peter) We need to use the proper models MPU here. tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) # Print number of parameters. num_parameters = sum( [sum([p.nelement() for p in model_module.parameters()]) for model_module in model] ) - if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0: + if pg_collection.get_data_parallel_rank() == 0 and pg_collection.get_context_parallel_rank() == 0: print( ' > number of parameters on (tensor, pipeline) ' 'model parallel rank ({}, {}): {}'.format( - mpu.get_tensor_model_parallel_rank(), - mpu.get_pipeline_model_parallel_rank(), + pg_collection.get_tensor_model_parallel_rank(), + pg_collection.get_pipeline_model_parallel_rank(), num_parameters, ), flush=True, @@ -1952,6 +2000,7 @@ def train( config, checkpointing_context, non_loss_data_func, + inference_model=None, ): """Training function: run train_step desired number of times, run validation, checkpoint.""" args = get_args() @@ -2288,7 +2337,7 @@ def get_e2e_base_metrics(): if getattr(args, 'perform_rl_step', False): with torch.no_grad(): train_data_iterator = rl_utils.setup_grpo_data_iterator( - model, optimizer, iteration, ref_state_dict, buffered_rollouts + model, inference_model, optimizer, iteration, ref_state_dict, buffered_rollouts ) buffered_rollouts = train_data_iterator diff --git a/model_provider.py b/model_provider.py index 4d8b0daac71..ae91c972c82 100644 --- a/model_provider.py +++ b/model_provider.py @@ -23,7 +23,7 @@ def model_provider( - model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None + model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None, pg_collection=None, config=None ) -> Union[GPTModel, megatron.legacy.model.GPTModel, MambaModel]: """Builds the model. @@ -65,7 +65,7 @@ def oom_observer(device, alloc, device_alloc, device_free): torch._C._cuda_attach_out_of_memory_observer(oom_observer) - return model_builder(args, pre_process, post_process, vp_stage) + return model_builder(args, pre_process, post_process, vp_stage, pg_collection, config) def count_parameters_in_layer(model, layer_name): diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py new file mode 100644 index 00000000000..e99c042cc6b --- /dev/null +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -0,0 +1,320 @@ +import os +import copy +import types +import pytest +import torch +import torch.distributed as dist + +from tests.unit_tests.test_utilities import Utils +from megatron.core.model_refitting import swap_model_weights +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from megatron.core import parallel_state as mpu +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.model_parallel_config import ModelParallelConfig +from mcore_reshard import reshard_with_general_planner +from typing import Tuple + + +def _build_pg_collection(tp_size: int, pp_size: int = None) -> ProcessGroupCollection: + cp_size = mpu.get_context_parallel_world_size() + if pp_size is None: + pp_size = mpu.get_pipeline_model_parallel_world_size() + world_size = dist.get_world_size() + dp_size = world_size // (tp_size * cp_size * pp_size) + assert dp_size >= 1 and (tp_size * cp_size * pp_size * dp_size) == world_size + + grid = HyperCommGrid([tp_size, cp_size, 1, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + tp_group = grid.create_pg("tp") + cp_group = grid.create_pg("cp") + pp_group = grid.create_pg("pp") + ep_group = grid.create_pg("ep") + dp_group = grid.create_pg("dp") + embd_group_ranks = mpu.default_embedding_ranks(dist.get_process_group_ranks(pp_group)) + embd_group = dist.new_group(ranks=embd_group_ranks) + return ProcessGroupCollection(tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group, embd=embd_group, dp=dp_group) + + +def _build_gpt(config: TransformerConfig, vocab_size: int, seq_len: int, pg_collection, parallel_output: bool = True) -> GPTModel: + model = GPTModel( + config=config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=vocab_size, + max_sequence_length=seq_len, + pre_process=True, + post_process=True, + fp16_lm_cross_entropy=False, + parallel_output=parallel_output, + share_embeddings_and_output_weights=True, + position_embedding_type="rope", + rotary_percent=1.0, + pg_collection=pg_collection, + ) + return model + + +def _mp_config() -> ModelParallelConfig: + return ModelParallelConfig( + params_dtype=torch.float32, + use_cpu_initialization=True, + sequence_parallel=False, + gradient_accumulation_fusion=False, + ) + + +def _set_pg_collection(module, tp_group, dp_group): + module.pg_collection = types.SimpleNamespace(tp=tp_group, dp=dp_group, ep=None, pp=None) + return module + +@pytest.mark.parametrize( + "src_tp,src_pp,dst_tp,dst_pp", + [ + (2, 1, 1, 1), # TP2 -> TP1 + (1, 1, 2, 1), # TP1 -> TP2 + (1, 2, 1, 1), # PP2 -> PP1 + (1, 1, 1, 2), # PP1 -> PP2 + (2, 2, 1, 1), # TP2,PP2 -> TP1,PP1 + (1, 1, 2, 2), # TP1,PP1 -> TP2,PP2 + (2, 1, 1, 2), # TP2,PP1 -> TP1,PP2 + (1, 2, 2, 1), # TP1,PP2 -> TP2,PP1 + ], +) +def test_nccl_swap_gpt_parametrized(src_tp: int, src_pp: int, dst_tp: int, dst_pp: int): + # Initialize environment with source MP sizing + Utils.initialize_model_parallel(tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=src_pp) + # Validate divisibility post-init using the default PG safely + world = dist.get_world_size() + if (world % (src_tp * src_pp) != 0) or (world % (dst_tp * dst_pp) != 0): + Utils.destroy_model_parallel() + pytest.skip("WORLD_SIZE must be divisible by both src_tp*src_pp and dst_tp*dst_pp") + model_parallel_cuda_manual_seed(1234) + + torch.manual_seed(1234) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + # Small GPT config + seq_len = 8 + vocab_size = 128 + cfg = TransformerConfig( + num_layers=4 if (src_pp > 1 or dst_pp > 1) else 2, + hidden_size=32, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + + # Build PGs and models + src_pgs = ProcessGroupCollection.use_mpu_process_groups() + dst_pgs = _build_pg_collection(tp_size=dst_tp, pp_size=dst_pp) + # Use parallel_output=False to gather vocab-parallel outputs inside model and emit only on last PP stage + src_model = _build_gpt(copy.deepcopy(cfg), vocab_size, seq_len, src_pgs, parallel_output=False).to(device).eval() + dst_model = _build_gpt(copy.deepcopy(cfg), vocab_size, seq_len, dst_pgs, parallel_output=False).to(device).eval() + + # Inputs + batch = 2 + tokens = torch.randint(low=0, high=vocab_size, size=(batch, seq_len), device=device, dtype=torch.long) + position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1) + attention_mask = torch.ones((batch, 1, seq_len, seq_len), device=device, dtype=torch.bool) + + # Collect source reference logits (parallel_output=False ensures full vocab on last PP stage) + ref_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) + src_pp_ranks = dist.get_process_group_ranks(src_pgs.pp) + src_last_pp_rank = src_pp_ranks[-1] + with torch.no_grad(): + src_out = src_model(tokens, position_ids, attention_mask) + if dist.get_rank() == src_last_pp_rank: + ref = src_out # [b, s, vocab] + ref_logits.copy_(ref) + dist.broadcast(ref_logits, src=src_last_pp_rank, group=src_pgs.pp) + + # Swap weights + swap_model_weights([src_model], [dst_model], refit_method="nccl") + + # Collect destination logits (parallel_output=False ensures full vocab on last PP stage) + dst_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) + dst_pp_ranks = dist.get_process_group_ranks(dst_pgs.pp) + dst_last_pp_rank = dst_pp_ranks[-1] + with torch.no_grad(): + dst_out = dst_model(tokens, position_ids, attention_mask) # last stage returns tensor, others return None + if dist.get_rank() == dst_last_pp_rank: + dst_logits.copy_(dst_out) # [b, s, vocab] + dist.broadcast(dst_logits, src=dst_last_pp_rank, group=dst_pgs.pp) + + # Compare + assert ref_logits.shape == dst_logits.shape + assert torch.allclose(dst_logits, ref_logits, atol=1e-4, rtol=1e-4), f"Refit src(TP={src_tp},PP={src_pp})->dst(TP={dst_tp},PP={dst_pp}) GPT outputs differ" + + dist.barrier() + Utils.destroy_model_parallel() + +# def test_nccl_swap_row_parallel_linear_tp2_to_tp1(): +# Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) +# model_parallel_cuda_manual_seed(1234) +# device = torch.device(f"cuda:{torch.cuda.current_device()}") + +# # Build TP=2 source and TP=1 dest groups +# src_pgs = ProcessGroupCollection.use_mpu_process_groups() +# infer_pgs = _build_pg_collection(tp_size=1) + +# in_features = 12 +# out_features = 16 +# cfg = _mp_config() + +# # Source RowParallelLinear (TP=2), input_is_parallel=False so it scatters internally +# src_layer = RowParallelLinear( +# input_size=in_features, +# output_size=out_features, +# config=cfg, +# init_method=torch.nn.init.zeros_, +# bias=False, +# input_is_parallel=False, +# skip_bias_add=True, +# tp_group=src_pgs.tp, +# ).to(device) +# _set_pg_collection(src_layer, src_pgs.tp, src_pgs.dp) +# # Ensure TP metadata is present for planner (row-parallel shards input dim=1) +# src_layer.weight.tensor_model_parallel = True +# src_layer.weight.partition_dim = 1 +# src_layer.weight.partition_stride = 1 + +# # Deterministic per-rank weights (sharded along dim=1) +# rank = dist.get_rank(src_pgs.tp) +# with torch.no_grad(): +# src_layer.weight.copy_( +# torch.arange(src_layer.weight.numel(), device=device, dtype=torch.float32).reshape_as( +# src_layer.weight +# ) +# + rank * 1000.0 +# ) + +# # Destination RowParallelLinear (TP=1) +# dst_layer = RowParallelLinear( +# input_size=in_features, +# output_size=out_features, +# config=_mp_config(), +# init_method=torch.nn.init.zeros_, +# bias=False, +# input_is_parallel=False, +# skip_bias_add=True, +# tp_group=infer_pgs.tp, +# ).to(device) +# _set_pg_collection(dst_layer, infer_pgs.tp, infer_pgs.dp) +# # Destination is unsharded (TP=1) but keep metadata consistent +# dst_layer.weight.tensor_model_parallel = False +# dst_layer.weight.partition_dim = 1 +# dst_layer.weight.partition_stride = 1 + +# # Use layers directly to simplify parameter name matching +# src = src_layer +# dst = dst_layer +# # Attach pg_collection to layers so reshard can find process groups +# src.pg_collection = src_pgs +# dst.pg_collection = infer_pgs + +# # Input and reference (gather master weight along dim=1 from TP=2) +# x = torch.randn(4, in_features, device=device) +# parts = [torch.empty_like(src_layer.weight) for _ in range(dist.get_world_size(src_pgs.tp))] +# dist.all_gather(parts, src_layer.weight.contiguous(), group=src_pgs.tp) +# master_w = torch.cat(parts, dim=1).contiguous() # [out, in] +# ref = x @ master_w.t() + +# # Use resharder directly for per-layer validation and inspect plan +# plan = reshard_with_general_planner(src, dst) +# assert (len(plan.recv_ops) + len(plan.local_copy_ops)) > 0, "No transfers scheduled for RowParallelLinear" +# # Verify weights transferred correctly +# with torch.no_grad(): +# assert dst_layer.weight.shape == master_w.shape +# assert torch.allclose(dst_layer.weight, master_w, atol=1e-6, rtol=1e-6), "RowParallelLinear weights mismatch after transfer" +# y, _ = dst(x) +# assert torch.allclose(y, ref, atol=1e-4, rtol=1e-4), "RowParallelLinear TP2->TP1 mismatch" + +# dist.barrier() +# Utils.destroy_model_parallel() + +# def test_nccl_swap_column_parallel_linear_tp2_to_tp1(): +# Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) +# model_parallel_cuda_manual_seed(1234) +# device = torch.device(f"cuda:{torch.cuda.current_device()}") + +# # Build TP=2 source and TP=1 dest groups +# src_pgs = ProcessGroupCollection.use_mpu_process_groups() +# infer_pgs = _build_pg_collection(tp_size=1) + +# in_features = 12 +# out_features = 16 +# cfg = _mp_config() + +# # Source ColumnParallelLinear (TP=2) +# src_layer = ColumnParallelLinear( +# input_size=in_features, +# output_size=out_features, +# config=cfg, +# init_method=torch.nn.init.zeros_, +# bias=False, +# gather_output=False, +# tp_group=src_pgs.tp, +# ).to(device) +# _set_pg_collection(src_layer, src_pgs.tp, src_pgs.dp) +# # Ensure TP metadata is present for planner +# src_layer.weight.tensor_model_parallel = True +# src_layer.weight.partition_dim = 0 +# src_layer.weight.partition_stride = 1 + +# # Deterministic per-rank weights +# rank = dist.get_rank(src_pgs.tp) +# with torch.no_grad(): +# src_layer.weight.copy_( +# torch.arange(src_layer.weight.numel(), device=device, dtype=torch.float32).reshape_as( +# src_layer.weight +# ) +# + rank * 1000.0 +# ) + +# # Destination ColumnParallelLinear (TP=1) +# dst_layer = ColumnParallelLinear( +# input_size=in_features, +# output_size=out_features, +# config=_mp_config(), +# init_method=torch.nn.init.zeros_, +# bias=False, +# gather_output=False, +# tp_group=infer_pgs.tp, +# ).to(device) +# _set_pg_collection(dst_layer, infer_pgs.tp, infer_pgs.dp) +# # Destination is unsharded (TP=1) but keep metadata consistent +# dst_layer.weight.tensor_model_parallel = False +# dst_layer.weight.partition_dim = 0 +# dst_layer.weight.partition_stride = 1 + +# # Use layers directly to simplify parameter name matching +# src = src_layer +# dst = dst_layer +# # Attach pg_collection to layers so reshard can find process groups +# src.pg_collection = src_pgs +# dst.pg_collection = infer_pgs + +# # Input and reference (gather master weight from TP=2) +# x = torch.randn(4, in_features, device=device) +# parts = [torch.empty_like(src_layer.weight) for _ in range(dist.get_world_size(src_pgs.tp))] +# dist.all_gather(parts, src_layer.weight.contiguous(), group=src_pgs.tp) +# master_w = torch.cat(parts, dim=0).contiguous() # [out, in] +# ref = x @ master_w.t() + +# # Use resharder directly for per-layer validation and inspect plan +# plan = reshard_with_general_planner(src, dst) +# assert (len(plan.recv_ops) + len(plan.local_copy_ops)) > 0, "No transfers scheduled for ColumnParallelLinear" +# # Verify weights transferred correctly +# with torch.no_grad(): +# assert dst_layer.weight.shape == master_w.shape +# assert torch.allclose(dst_layer.weight, master_w, atol=1e-6, rtol=1e-6), "ColumnParallelLinear weights mismatch after transfer" +# y, _ = dst(x) +# assert torch.allclose(y, ref, atol=1e-4, rtol=1e-4), "ColumnParallelLinear TP2->TP1 mismatch" + +# dist.barrier() +# Utils.destroy_model_parallel() \ No newline at end of file From 2f804dbd313b16ed6a379ed638a409c5bd659e90 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 11:16:38 -0800 Subject: [PATCH 02/44] move eveything into megatron --- megatron/core/model_refitting.py | 138 +------ megatron/core/resharding/__init__.py | 17 + .../copy_services/nccl_copy_service.py | 64 ++++ megatron/core/resharding/execution.py | 71 ++++ megatron/core/resharding/planner.py | 354 ++++++++++++++++++ megatron/core/resharding/utils.py | 231 ++++++++++++ .../inference/test_nccl_model_swap.py | 297 +++++---------- 7 files changed, 849 insertions(+), 323 deletions(-) create mode 100644 megatron/core/resharding/__init__.py create mode 100644 megatron/core/resharding/copy_services/nccl_copy_service.py create mode 100644 megatron/core/resharding/execution.py create mode 100644 megatron/core/resharding/planner.py create mode 100644 megatron/core/resharding/utils.py diff --git a/megatron/core/model_refitting.py b/megatron/core/model_refitting.py index 33c50d79a47..d9aed2c93c4 100644 --- a/megatron/core/model_refitting.py +++ b/megatron/core/model_refitting.py @@ -3,7 +3,7 @@ import torch.distributed as dist from typing import Any from megatron.core import parallel_state -from mcore_reshard import reshard_with_general_planner +from megatron.core.resharding import build_centralized_reshard_plan, execute_reshard_plan from typing import Any, Optional @@ -11,131 +11,8 @@ def _unwrap_module(module: LanguageModule) -> Any: return module.module.module if hasattr(module, 'module') and hasattr(module.module, 'module') else module.module if hasattr(module, 'module') else module - -def _move_module(module: Any, device: torch.device | str) -> None: - for p in module.parameters(recurse=True): - if p is not None and p.data is not None: - p.data = p.data.to(device, non_blocking=True) - if p is not None and p._grad is not None: - p._grad = p._grad.to(device, non_blocking=True) - for buf_name, buf in module._buffers.items(): # type: ignore[attr-defined] - if buf is not None: - module._buffers[buf_name] = buf.to(device, non_blocking=True) # type: ignore[index] - - - -def naive_model_swap(src_model: LanguageModule, target_model: LanguageModule): - print(f"Swapping train to inference model") - # Handle list-wrapped modules used throughout training utils - src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model - target_lm = target_model[0] if isinstance(target_model, (list, tuple)) else target_model - - # Unwrap possible precision/wrapper modules to reach the module that owns parameters - src_model = _unwrap_module(src_lm) - target_model = _unwrap_module(target_lm) - - - src_tp_size = dist.get_world_size(src_model.pg_collection.tp) - target_tp_size = dist.get_world_size(target_model.pg_collection.tp) - src_tp_group = src_model.pg_collection.tp - target_tp_group = target_model.pg_collection.tp - - # Build name->param map for inference module - infer_params = {name: p for name, p in target_model.named_parameters(recurse=True)} - - - # Utility: reconstruct global tensor from TP sharded local tensors - def _gather_master_from_training(local_shard: torch.Tensor, dim: int, stride: int, group) -> torch.Tensor: - world_size = dist.get_world_size(group=group) - if world_size == 1: - return local_shard.detach().clone() - # Gather shards from all TP ranks - gather_list = [torch.empty_like(local_shard) for _ in range(world_size)] - dist.all_gather(gather_list, local_shard.contiguous(), group=group) - if stride == 1: - return torch.cat(gather_list, dim=dim).contiguous() - # Strided partition: split each shard into stride chunks and interleave - per_part = local_shard.size(dim) - assert per_part % stride == 0, "Local shard size must be divisible by stride" - per_stride = per_part // stride - blocks: list[torch.Tensor] = [None] * (world_size * stride) # type: ignore[assignment] - for r in range(world_size): - chunks = torch.split(gather_list[r], per_stride, dim=dim) - assert len(chunks) == stride, "Unexpected number of stride chunks" - for i in range(stride): - blocks[r + i * world_size] = chunks[i] - return torch.cat(blocks, dim=dim).contiguous() - - # Utility: shard master tensor to the target inference layout - def _shard_master_to_infer(master: torch.Tensor, dim: int, stride: int, group) -> torch.Tensor: - world_size = dist.get_world_size(group=group) - if world_size == 1: - return master - rank = dist.get_rank(group=group) - full = master - assert full.size(dim) % world_size == 0, "Master size not divisible by TP world size" - per_part = full.size(dim) // world_size - assert per_part % stride == 0, "Per-part size must be divisible by stride" - per_stride = per_part // stride - weight_list = torch.split(full, per_stride, dim=dim) - # Pick this rank's stride segments and concatenate along partition dim - my_chunks = weight_list[rank::world_size][:stride] - return torch.cat(my_chunks, dim=dim).contiguous() - - # Perform transfer - with torch.no_grad(): - # Simple same-TP copy path for models with the same TP size - if src_tp_size == target_tp_size: - for name, src in src_model.named_parameters(recurse=True): - if name not in infer_params: - raise ValueError(f"Parameter {name} in training model not found in inference model") - dst = infer_params[name] - if src.shape == dst.shape: - dst.copy_(src) - else: - raise ValueError(f"Parameter {name} in training model has different shape than in inference model") - return - - # General reshard path - for name, src in src_model.named_parameters(recurse=True): - dst = infer_params.get(name, None) - if dst is None: - raise ValueError(f"Parameter {name} in training model not found in inference model") - - # Non-TP params: direct copy - is_tp = bool(getattr(src, 'tensor_model_parallel', False)) - if not is_tp: - if src.shape != dst.shape: - # If shapes differ unexpectedly for non-TP, try to broadcast master (rank 0) value - #dst.copy_(src.detach().clone().to(dst.dtype)) - raise ValueError(f"Parameter {name} in training model has different shape than in inference model") - else: - dst.copy_(src) - continue - - # Resolve sharding attributes and groups - dim = int(getattr(src, 'partition_dim', 0)) - stride = int(getattr(src, 'partition_stride', 1)) - # Gather training shards -> master - master = _gather_master_from_training(src, dim=dim, stride=stride, group=src_tp_group) - - # Use dst's own TP attributes for target stride/dim if present - #TODO when do these just match the the original expecailly dim? - target_dim = int(getattr(dst, 'partition_dim', dim)) - target_stride = int(getattr(dst, 'partition_stride', stride)) - local_target = _shard_master_to_infer(master, dim=target_dim, stride=target_stride, group=target_tp_group) - # Cast and copy - if local_target.dtype != dst.dtype: - raise ValueError(f"Parameter {name} in training model has different dtype than in inference model") - # local_target = local_target.to(dst.dtype) - dst.copy_(local_target) - print(f"finished Swapped train to inference model") - - def swap_model_weights(src_model: LanguageModule, target_model: LanguageModule, refit_method: str): - if refit_method == "naive": - naive_model_swap(src_model, target_model) - elif refit_method == "nccl": + if refit_method == "nccl": nccl_model_swap(src_model, target_model) else: raise ValueError(f"Invalid refit method: {refit_method}") @@ -145,6 +22,8 @@ def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model tgt_lm = target_model[0] if isinstance(target_model, (list, tuple)) else target_model + num_experts = src_lm.config.num_moe_experts + # Unwrap to get owning modules (with parameters and pg_collection) src_core = _unwrap_module(src_lm) tgt_core = _unwrap_module(tgt_lm) @@ -155,12 +34,15 @@ def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): if not hasattr(tgt_core, "pg_collection") or tgt_core.pg_collection is None: raise RuntimeError("Target model missing pg_collection required for NCCL reshard") - #TODO(Peter): We should figure out why this happens. + #TODO(Peter): We should figure out why this happens. Seems like a bug in Orthotope. # Fill missing DP group on the source using Megatron's parallel state if not provided if getattr(src_core.pg_collection, "dp", None) is None: src_core.pg_collection.dp = parallel_state.get_data_parallel_group() # caching plan for reuse cached_plan: Optional[Any] = getattr(tgt_core, "_cached_reshard_plan", None) - plan = reshard_with_general_planner(src_core, tgt_core, cached_plan=cached_plan) if cached_plan is None: - setattr(tgt_core, "_cached_reshard_plan", plan) \ No newline at end of file + plan = build_centralized_reshard_plan(src_core, tgt_core, num_experts=num_experts) + setattr(tgt_core, "_cached_reshard_plan", plan) + else: + plan = cached_plan + execute_reshard_plan(plan, src_core, tgt_core) \ No newline at end of file diff --git a/megatron/core/resharding/__init__.py b/megatron/core/resharding/__init__.py new file mode 100644 index 00000000000..cb06ddebe2e --- /dev/null +++ b/megatron/core/resharding/__init__.py @@ -0,0 +1,17 @@ +from .planner import build_centralized_reshard_plan +from .execution import execute_reshard_plan +from .utils import ( + ParameterMetadata, + ShardingDescriptor, + TransferOp, + ReshardPlan, +) + +__all__ = [ + "build_centralized_reshard_plan", + "execute_reshard_plan", + "ParameterMetadata", + "ShardingDescriptor", + "TransferOp", + "ReshardPlan", +] diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py new file mode 100644 index 00000000000..c81a05c80dc --- /dev/null +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import List + +import torch +import torch.distributed as dist + + +logger = logging.getLogger(__name__) + + +@dataclass +class SendOp: + tensor: torch.Tensor + dest_rank: int + + +@dataclass +class RecvOp: + tensor: torch.Tensor + src_rank: int + + +class NCCLCopyService: + """ + Thin wrapper around torch.distributed batch_isend_irecv to submit and execute + a batch of point-to-point sends and recvs. + """ + + def __init__(self): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.send_ops: List[SendOp] = [] + self.recv_ops: List[RecvOp] = [] + logger.info(f"NCCLCopyService initialized with {self.world_size} ranks") + + def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): + self.send_ops.append(SendOp(tensor=src_tensor, dest_rank=dest_rank)) + + def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + self.recv_ops.append(RecvOp(tensor=dest_tensor, src_rank=src_rank)) + + def run(self): + total_ops = len(self.send_ops) + len(self.recv_ops) + logger.info(f"Executing batched communication: {len(self.send_ops)} sends + {len(self.recv_ops)} recvs = {total_ops} ops") + + p2p_ops = [] + for op in self.send_ops: + p2p_ops.append(dist.P2POp(dist.isend, op.tensor, op.dest_rank)) + for op in self.recv_ops: + p2p_ops.append(dist.P2POp(dist.irecv, op.tensor, op.src_rank)) + + if p2p_ops: + reqs = dist.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + logger.info("Batched communication completed") + self.send_ops.clear() + self.recv_ops.clear() + + diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py new file mode 100644 index 00000000000..6a12e01fe08 --- /dev/null +++ b/megatron/core/resharding/execution.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import logging +from typing import List, Tuple + +import torch +import torch.distributed as dist + +from .utils import ReshardPlan +from .copy_services.nccl_copy_service import NCCLCopyService + + +logger = logging.getLogger(__name__) + + +def execute_reshard_plan( + plan: ReshardPlan, + src_module: torch.nn.Module, + dst_module: torch.nn.Module, +) -> None: + """Execute a reshard plan (from centralized controller).""" + service = NCCLCopyService() + + src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} + dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} + + #TODO(Peter) do this on like a separate stream? + # Execute local copies + for param_name, src_param, dst_param, src_slice, dst_slice in plan.local_copy_ops: + if src_param is None: + src_param = src_params.get(param_name) + if dst_param is None: + dst_param = dst_params.get(param_name) + if src_param is not None and dst_param is not None: + with torch.no_grad(): + src_view = src_param.data[src_slice] + dst_view = dst_param.data[dst_slice] + dst_view.copy_(src_view) + + # Submit sends + for op in plan.send_ops: + src_param = src_params.get(op.param_name) + if src_param is not None: + src_view = src_param.data[op.my_slice].contiguous() + service.submit_send(src_view, op.peer_rank) + + # Submit recvs + recv_writebacks: List[Tuple[torch.Tensor, torch.nn.Parameter, tuple[slice, ...]]] = [] + for op in plan.recv_ops: + dst_param = dst_params.get(op.param_name) + if dst_param is not None: + dst_slice_view = dst_param.data[op.my_slice] + recv_buffer = torch.empty_like(dst_slice_view.contiguous()) + service.submit_recv(recv_buffer, op.peer_rank) + recv_writebacks.append((recv_buffer, dst_param, op.my_slice)) + + # Execute + logger.info(f"Executing {len(plan.send_ops)} sends + {len(plan.recv_ops)} recvs") + service.run() + #TODO(Peter) remove this eventually? + dist.barrier() + torch.cuda.synchronize() + + # Write back received buffers into their destination parameter slices + for recv_buffer, dst_param, dst_slice in recv_writebacks: + with torch.no_grad(): + dst_param.data[dst_slice].copy_(recv_buffer) + + logger.info("Reshard complete") + + diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py new file mode 100644 index 00000000000..b134c15cf3d --- /dev/null +++ b/megatron/core/resharding/planner.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +import logging +import math +import sys +import traceback +from typing import Optional + +import torch +import torch.distributed as dist + +from .utils import ( + ParameterMetadata, + ShardingDescriptor, + TransferOp, + ReshardPlan, + _get_rank_in_group, + extract_param_metadata, + select_src_metadata_balanced, +) + + +logger = logging.getLogger(__name__) + + +def _build_descriptors_for_param( + src_metadata: ParameterMetadata, + dst_metadata: ParameterMetadata, +) -> list[ShardingDescriptor]: + """Construct sharding descriptors (currently TP) for this parameter based on actual layout. + Guard TP descriptor with size conservation so we don't mis-classify replicated tensors. + """ + descriptors: list[ShardingDescriptor] = [] + + # TP descriptor: allow when either side participates in TP + if src_metadata.is_tp or dst_metadata.is_tp: + # Prefer destination partition_dim, else source + tp_dim = dst_metadata.partition_dim if dst_metadata.is_tp else src_metadata.partition_dim + src_tp_ranks = src_metadata.tensor_parallel_group_ranks + dst_tp_ranks = dst_metadata.tensor_parallel_group_ranks + if src_tp_ranks is None or dst_tp_ranks is None: + # Not enough context to build TP descriptor + return descriptors + src_stride = src_metadata.partition_stride if src_metadata.is_tp else 1 + dst_stride = dst_metadata.partition_stride if dst_metadata.is_tp else 1 + + # Size conservation check on partition dim + src_world = len(src_tp_ranks) + dst_world = len(dst_tp_ranks) + src_local = src_metadata.shape[tp_dim] + dst_local = dst_metadata.shape[tp_dim] + if src_world * src_local != dst_world * dst_local: + # Not truly TP-sharded for this param; let DP handle it + logger.debug( + f"Skipping TP descriptor for {dst_metadata.name} dim{tp_dim}: " + f"src_world*src_local={src_world}*{src_local} != {dst_world}*{dst_local}" + ) + return descriptors + + descriptors.append( + ShardingDescriptor( + name="tp", + dim=tp_dim, + src_stride=src_stride, + dst_stride=dst_stride, + src_dim_ranks=src_tp_ranks, + dst_dim_ranks=dst_tp_ranks, + ) + ) + return descriptors + + +def _plan_multi_dim_lcm( + param_name: str, + src_metadata: ParameterMetadata, + dst_metadata: ParameterMetadata, + descriptors: list[ShardingDescriptor], + my_global_rank: int, +) -> list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]: + """ + TP-only planner using LCM tiling to support strides on source/destination. + - Requires exactly one TP descriptor + - Supports arbitrary integer strides (contiguous micro-tiles) + """ + if not descriptors: + return [] + if len(descriptors) != 1: + raise NotImplementedError(f"{param_name}: _plan_multi_dim_lcm supports TP-only (one descriptor)") + if descriptors[0].name != "tp": + raise NotImplementedError(f"{param_name}: _plan_multi_dim_lcm expects TP descriptor") + d = descriptors[0] + if my_global_rank not in d.dst_dim_ranks: + return [] + src_shape = tuple(src_metadata.shape) + dst_shape = tuple(dst_metadata.shape) + dim = d.dim + src_world = len(d.src_dim_ranks) + dst_world = len(d.dst_dim_ranks) + src_local = src_shape[dim] + dst_local = dst_shape[dim] + if src_world * src_local != dst_world * dst_local: + raise RuntimeError( + f"{param_name}: size mismatch on TP dim{dim} " + f"(src_world={src_world}, src_local={src_local}, dst_world={dst_world}, dst_local={dst_local})" + ) + # LCM tiling with strides + Ns = src_world * max(1, d.src_stride) + Nd = dst_world * max(1, d.dst_stride) + full_len = dst_local * dst_world + g = math.gcd(Ns, Nd) + L = (Ns // g) * Nd + if full_len % L != 0: + raise RuntimeError( + f"{param_name}: TP dim{dim} full_len {full_len} not divisible by LCM {L} " + f"(Ns={Ns}, Nd={Nd})" + ) + unit = full_len // L # micro-tile length + cps = L // Ns # micro-tiles per source segment + cpd = L // Nd # micro-tiles per destination segment + seg_src = cps * unit # contiguous length per source segment + seg_dst = cpd * unit # contiguous length per destination segment + dst_local_rank = _get_rank_in_group(my_global_rank, d.dst_dim_ranks) + ops: list[tuple[int, tuple[slice, ...], tuple[slice, ...]]] = [] + # Sweep destination segments owned by this rank (handle destination stride) + for k in range(max(1, d.dst_stride)): + g_dst_seg = dst_local_rank + k * dst_world + # Within this segment, enumerate the cpd micro-tiles + for off in range(cpd): + g_micro = g_dst_seg * cpd + off + s_idx = g_micro // cps + in_seg = g_micro % cps + src_owner_in_dim = s_idx % src_world + src_global_rank = d.src_dim_ranks[src_owner_in_dim] + src_local_seg_idx = s_idx // src_world + src_start = src_local_seg_idx * seg_src + in_seg * unit + dst_start = k * seg_dst + off * unit + # Build full N-D slices + src_slice = [slice(None)] * len(src_shape) + dst_slice = [slice(None)] * len(dst_shape) + src_slice[dim] = slice(src_start, src_start + unit) + dst_slice[dim] = slice(dst_start, dst_start + unit) + ops.append((src_global_rank, tuple(src_slice), tuple(dst_slice))) + # Stable order by destination offset + def dst_key(op): + _, _, dsl = op + s = dsl[dim] + return s.start if isinstance(s, slice) else 0 + + ops.sort(key=dst_key) + return ops + + +def _plan_dp_recv( + param_name: str, + src_metadata: ParameterMetadata, + dst_metadata: ParameterMetadata, + my_global_rank: int, +) -> list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]: + """Plan DP transfer for a replicated (non-TP) parameter (receiver side).""" + dst_dp_ranks = dst_metadata.data_parallel_group_ranks + src_dp_ranks = src_metadata.data_parallel_group_ranks + if my_global_rank not in dst_dp_ranks: + return [] + + my_dst_dp_rank = _get_rank_in_group(my_global_rank, dst_dp_ranks) + dst_shape = dst_metadata.shape + + # Same DP layout - local copy + if src_dp_ranks == dst_dp_ranks: + full_slice = tuple(slice(None) for _ in range(len(dst_shape))) + return [(my_global_rank, full_slice, full_slice)] + + # Different DP groups - use round-robin for load balancing + src_global_rank = src_dp_ranks[my_dst_dp_rank % len(src_dp_ranks)] + full_slice = tuple(slice(None) for _ in range(len(dst_shape))) + return [(src_global_rank, full_slice, full_slice)] + + +def _determine_source_ranks_for_dst_param( + param_name: str, + src_metadata: ParameterMetadata, + dst_metadata: ParameterMetadata, + my_global_rank: int, +) -> list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]: + """Route to dimension-specific planner based on parameter sharding type.""" + + # PP filtering (simple, symmetric) + src_pp_ranks = src_metadata.pipeline_parallel_group_ranks + dst_pp_ranks = dst_metadata.pipeline_parallel_group_ranks + if len(dst_pp_ranks) > 1 and my_global_rank not in dst_pp_ranks: + return [] + if len(src_pp_ranks) > 1 and len(dst_pp_ranks) > 1: + my_dst_pp_rank = _get_rank_in_group(my_global_rank, dst_pp_ranks) + if my_dst_pp_rank >= len(src_pp_ranks): + return [] + + # Regular TP/DP planning with EP-resolved metadata + descriptors = _build_descriptors_for_param(src_metadata=src_metadata, dst_metadata=dst_metadata) + if descriptors: + return _plan_multi_dim_lcm( + param_name=param_name, + src_metadata=src_metadata, + dst_metadata=dst_metadata, + descriptors=descriptors, + my_global_rank=my_global_rank, + ) + # DP / replicated fallback + return _plan_dp_recv(param_name, src_metadata, dst_metadata, my_global_rank) + + +def build_centralized_reshard_plan( + src_module: torch.nn.Module, + dst_module: torch.nn.Module, + num_experts: int = None, + validate_config: bool = True, +) -> ReshardPlan: + """ + Centralized planning: Rank 0 builds complete plan for all ranks, then scatters. + """ + my_global_rank = dist.get_rank() + world_size = dist.get_world_size() + + # Get process groups + src_pg = getattr(src_module, "pg_collection", None) + dst_pg = getattr(dst_module, "pg_collection", None) + if src_pg is None or dst_pg is None: + raise ValueError("Both modules must have pg_collection") + if not hasattr(src_pg, 'dp'): + raise ValueError("src_pg must have dp process group") + + src_num_experts = num_experts + dst_num_experts = num_experts + + # Gather param metadata from all ranks + my_src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} + my_dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} + + my_src_metadata = [ + extract_param_metadata(p, name, my_global_rank, src_pg, num_experts=src_num_experts) + for name, p in my_src_params.items() + ] + my_dst_metadata = [ + extract_param_metadata(p, name, my_global_rank, dst_pg, num_experts=dst_num_experts) + for name, p in my_dst_params.items() + ] + + all_src_metadata_by_rank = [None] * world_size + all_dst_metadata_by_rank = [None] * world_size + dist.all_gather_object(all_src_metadata_by_rank, my_src_metadata) + dist.all_gather_object(all_dst_metadata_by_rank, my_dst_metadata) + + # Parameter to metadata maps keyed by resolved_name + src_param_metadata_by_rank = {} + dst_param_metadata_by_rank = {} + src_param_metadata: dict[str, list[ParameterMetadata]] = {} + + for rank_id, rank_metadata_list in enumerate(all_src_metadata_by_rank): + src_param_metadata_by_rank[rank_id] = {m.resolved_name: m for m in rank_metadata_list} + for rank_id, rank_metadata_list in enumerate(all_dst_metadata_by_rank): + dst_param_metadata_by_rank[rank_id] = {m.resolved_name: m for m in rank_metadata_list} + for rank_metadata_list in all_src_metadata_by_rank: + for metadata in rank_metadata_list: + key = metadata.resolved_name + if key not in src_param_metadata: + src_param_metadata[key] = [] + src_param_metadata[key].append(metadata) + + # Build the plan on global rank 0 and broadcast to all ranks with error propagation + if my_global_rank == 0: + error_box = [None] + plans_for_all_ranks = {r: ReshardPlan([], [], []) for r in range(world_size)} + try: + for dst_rank in range(world_size): + dst_rank_params = dst_param_metadata_by_rank.get(dst_rank, {}) + for resolved_name, dst_metadata in dst_rank_params.items(): + src_meta_list = src_param_metadata.get(resolved_name) + if not src_meta_list: + raise RuntimeError( + f"Destination parameter '{resolved_name}' on rank {dst_rank} not found in source model." + ) + # Choose a representative source metadata with DP round-robin balancing + src_metadata = select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank) + sources = _determine_source_ranks_for_dst_param( + resolved_name, src_metadata, dst_metadata, dst_rank + ) + for src_rank, src_slice, dst_slice in sources: + if src_rank == dst_rank and src_metadata.name == dst_metadata.name: + plans_for_all_ranks[dst_rank].local_copy_ops.append( + (dst_metadata.name, None, None, src_slice, dst_slice) + ) + else: + plans_for_all_ranks[dst_rank].recv_ops.append( + TransferOp( + param_name=dst_metadata.name, + peer_rank=src_rank, + is_send=False, + my_slice=dst_slice, + peer_slice=src_slice, + ) + ) + plans_for_all_ranks[src_rank].send_ops.append( + TransferOp( + param_name=src_metadata.name, + peer_rank=dst_rank, + is_send=True, + my_slice=src_slice, + peer_slice=dst_slice, + ) + ) + plans_list = [plans_for_all_ranks[r] for r in range(world_size)] + except Exception as e: + tb = traceback.format_exc() + error_box[0] = { + "rank": my_global_rank, + "param": resolved_name if 'resolved_name' in locals() else None, + "type": type(e).__name__, + "msg": str(e), + "traceback": tb, + } + plans_list = [None] * world_size + dist.broadcast_object_list(error_box, src=0) + else: + error_box = [None] + plans_list = [None] * world_size + dist.broadcast_object_list(error_box, src=0) + if error_box[0] is not None: + err = error_box[0] + print( + f"[Reshard Planner] Aborting due to error on rank {err['rank']} while planning {err['param']}: " + f"{err['type']}: {err['msg']}" + ) + print(err["traceback"]) + sys.stdout.flush() + raise RuntimeError(f"Reshard plan failed on rank {err['rank']} for {err['param']}: {err['msg']}") + torch.distributed.barrier() + torch.distributed.broadcast_object_list(plans_list, src=0) + my_plan = plans_list[my_global_rank] + torch.distributed.barrier() + + # Fill in actual parameter references for local copies + for i, (param_name, _, _, src_slice, dst_slice) in enumerate(my_plan.local_copy_ops): + src_param = my_src_params.get(param_name) + dst_param = my_dst_params.get(param_name) + if src_param is not None and dst_param is not None: + my_plan.local_copy_ops[i] = (param_name, src_param, dst_param, src_slice, dst_slice) + + logger.info( + f"Rank {my_global_rank}: Received plan - {len(my_plan.recv_ops)} recvs, " + f"{len(my_plan.send_ops)} sends, {len(my_plan.local_copy_ops)} local copies" + ) + + return my_plan + + diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py new file mode 100644 index 00000000000..fd4d070e18f --- /dev/null +++ b/megatron/core/resharding/utils.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist + + +# ----------------------------------------------------------------------------- +# Dataclasses used by the planner +# ----------------------------------------------------------------------------- + + +@dataclass +class TransferOp: + param_name: str + peer_rank: int # Who to send to / receive from + is_send: bool # True=send, False=recv + + # Slice information (for when we execute the plan) + my_slice: tuple[slice, ...] # My tensor slice + peer_slice: tuple[slice, ...] # Peer's tensor slice (for reference) + + +@dataclass +class ParameterMetadata: + """Metadata for a parameter (used when param is on different rank).""" + + name: str + shape: tuple[int, ...] + dtype: torch.dtype + element_size: int + + # TP sharding info + is_tp: bool = False + partition_dim: int = 0 + partition_stride: int = 1 + + # EP sharding info (fused/grouped MoE) + is_ep: bool = False + num_experts: Optional[int] = None + + # Which rank owns this param + owner_rank: int = -1 + + tensor_parallel_group_ranks: list[int] | None = None + expert_parallel_group_ranks: list[int] | None = None + data_parallel_group_ranks: list[int] | None = None + pipeline_parallel_group_ranks: list[int] | None = None + + # Canonicalization for EP per-expert params + resolved_name: Optional[str] = None + global_expert_index: Optional[int] = None + + +@dataclass +class ShardingDescriptor: + """Descriptor for a sharded dimension for a parameter.""" + + name: str # "tp" | "ep" | custom label + dim: int + src_stride: int + dst_stride: int + src_dim_ranks: list[int] + dst_dim_ranks: list[int] + + +@dataclass +class ReshardPlan: + """Reshard plan - operations for this rank.""" + + send_ops: list[TransferOp] + recv_ops: list[TransferOp] + local_copy_ops: list[ + tuple[str, torch.nn.Parameter | None, torch.nn.Parameter | None, tuple[slice, ...], tuple[slice, ...]] + ] # (name, src_param, dst_param, src_slice, dst_slice) + + def __str__(self): + return ( + f"ReshardPlan(sends={len(self.send_ops)}, recvs={len(self.recv_ops)}, " + f"local_copies={len(self.local_copy_ops)})" + ) + + +# ----------------------------------------------------------------------------- +# EP + Metadata helpers +# ----------------------------------------------------------------------------- + + +def _get_rank_in_group(global_rank: int, group_ranks: list[int]) -> int: + try: + return group_ranks.index(global_rank) + except ValueError: + raise ValueError( + f"Rank {global_rank} not found in process group {group_ranks}. " + f"This likely indicates a configuration mismatch." + ) + + +def _detect_expert_index_from_param_name(param_name: str) -> Optional[int]: + """Extract expert index from parameter name for TEGroupedMLP per-expert tensors.""" + for part in param_name.split('.'): + if part.startswith('weight') and len(part) > len('weight') and part[len('weight'):].isdigit(): + return int(part[len('weight'):]) + if part.startswith('bias') and len(part) > len('bias') and part[len('bias'):].isdigit(): + return int(part[len('bias'):]) + return None + + +def assign_resolved_name_inplace(meta: ParameterMetadata) -> None: + """ + Compute a canonical resolved_name for EP per-expert parameters, and set global_expert_index. + For non-EP or non-per-expert params, resolved_name defaults to original name. + """ + meta.resolved_name = meta.name + meta.global_expert_index = None + if not meta.is_ep: + return + + local_idx = _detect_expert_index_from_param_name(meta.name) + if local_idx is None: + # Fused experts tensor: leave name as-is; TP planner will handle slicing + return + ep_group = meta.expert_parallel_group_ranks + ep_size = len(ep_group) + ep_local_rank = ep_group.index(meta.owner_rank) + experts_per_rank = meta.num_experts // ep_size + global_idx = ep_local_rank * experts_per_rank + local_idx + meta.global_expert_index = global_idx + + # Replace trailing integer in "weightK"/"biasK" with global_idx + parts = meta.name.split('.') + new_parts = [] + for p in parts: + if (p.startswith('weight') and len(p) > len('weight') and p[len('weight'):].isdigit()): + new_parts.append('weight' + str(global_idx)) + elif (p.startswith('bias') and len(p) > len('bias') and p[len('bias'):].isdigit()): + new_parts.append('bias' + str(global_idx)) + else: + new_parts.append(p) + meta.resolved_name = '.'.join(new_parts) + + +def extract_param_metadata( + param: torch.nn.Parameter, + param_name: str, + owner_rank: int, + pg_collection, + num_experts: Optional[int] = None, +) -> ParameterMetadata: + """Extract metadata from a parameter for cross-rank communication.""" + # TP flags from attributes (set by Megatron linear layers) + is_tp = bool(getattr(param, 'tensor_model_parallel', False)) + partition_dim = int(getattr(param, 'partition_dim', 0)) + partition_stride = int(getattr(param, 'partition_stride', 1)) + # EP detection: Megatron convention - expert params are not allreduced + is_ep = not bool(getattr(param, 'allreduce', True)) + + tensor_parallel_group_ranks: list[int] | None = None + expert_parallel_group_ranks: list[int] | None = None + data_parallel_group_ranks: list[int] | None = None + pipeline_parallel_group_ranks: list[int] | None = None + + if is_ep: + expert_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.ep) + # For MoE params, prefer expert TP group when available, else regular TP + if is_tp and hasattr(pg_collection, 'expt_tp') and pg_collection.expt_tp is not None: + tensor_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.expt_tp) + elif is_tp and hasattr(pg_collection, 'tp') and pg_collection.tp is not None: + tensor_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.tp) + data_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.dp) + elif is_tp: + # Non-EP: use regular TP group + if hasattr(pg_collection, 'tp') and pg_collection.tp is not None: + tensor_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.tp) + data_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.dp) + else: + data_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.dp) + + if hasattr(pg_collection, 'pp') and pg_collection.pp is not None: + pipeline_parallel_group_ranks = dist.get_process_group_ranks(pg_collection.pp) + else: + pipeline_parallel_group_ranks = list(range(dist.get_world_size())) + + meta = ParameterMetadata( + name=param_name, + shape=tuple(param.shape), + dtype=param.dtype, + element_size=param.element_size(), + is_tp=is_tp, + partition_dim=partition_dim, + partition_stride=partition_stride, + is_ep=is_ep, + num_experts=num_experts, + owner_rank=owner_rank, + tensor_parallel_group_ranks=tensor_parallel_group_ranks, + expert_parallel_group_ranks=expert_parallel_group_ranks, + data_parallel_group_ranks=data_parallel_group_ranks, + pipeline_parallel_group_ranks=pipeline_parallel_group_ranks, + ) + assign_resolved_name_inplace(meta) + return meta + + +def select_src_metadata_balanced( + src_meta_list: list[ParameterMetadata], dst_metadata: ParameterMetadata, dst_rank: int +) -> ParameterMetadata: + """Choose representative source metadata using DP round-robin across source DP groups.""" + if not src_meta_list: + raise ValueError("src_meta_list must be non-empty") + groups: dict[tuple[int, ...], list[ParameterMetadata]] = {} + for m in src_meta_list: + key = tuple(m.data_parallel_group_ranks or []) + groups.setdefault(key, []).append(m) + if len(groups) == 1: + return src_meta_list[0] + dst_dp = dst_metadata.data_parallel_group_ranks or [] + if dst_rank in dst_dp and len(dst_dp) > 0: + my_dst_dp_idx = dst_dp.index(dst_rank) + else: + my_dst_dp_idx = 0 + keys_sorted = sorted(groups.keys()) + chosen_key = keys_sorted[my_dst_dp_idx % len(keys_sorted)] + return groups[chosen_key][0] + + +logger = logging.getLogger(__name__) + + diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index e99c042cc6b..8f984e3659b 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -4,6 +4,8 @@ import pytest import torch import torch.distributed as dist +import os +import pytest from tests.unit_tests.test_utilities import Utils from megatron.core.model_refitting import swap_model_weights @@ -11,38 +13,77 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) from megatron.core import parallel_state as mpu from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.model_parallel_config import ModelParallelConfig -from mcore_reshard import reshard_with_general_planner -from typing import Tuple +from mcore_reshard import reshard_with_general_planner, build_centralized_reshard_plan +from mcore_reshard.reshard_planner import _extract_param_metadata, _detect_expert_index_from_param_name, _build_descriptors_for_param, _plan_multi_dim_lcm, _plan_dp_recv +from typing import Tuple, Optional -def _build_pg_collection(tp_size: int, pp_size: int = None) -> ProcessGroupCollection: +def _build_pg_collection(tp_size: int, pp_size: int = None, ep_size: int = 1) -> ProcessGroupCollection: cp_size = mpu.get_context_parallel_world_size() if pp_size is None: pp_size = mpu.get_pipeline_model_parallel_world_size() world_size = dist.get_world_size() - dp_size = world_size // (tp_size * cp_size * pp_size) - assert dp_size >= 1 and (tp_size * cp_size * pp_size * dp_size) == world_size + dp_size = world_size // (tp_size * cp_size * ep_size * pp_size) + assert dp_size >= 1 and (tp_size * cp_size * ep_size * pp_size * dp_size) == world_size - grid = HyperCommGrid([tp_size, cp_size, 1, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + grid = HyperCommGrid([tp_size, cp_size, ep_size, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) tp_group = grid.create_pg("tp") cp_group = grid.create_pg("cp") pp_group = grid.create_pg("pp") ep_group = grid.create_pg("ep") dp_group = grid.create_pg("dp") + # Composite groups required by MoE/router and some utilities + tp_cp_group = grid.create_pg(["tp", "cp"]) + mp_group = grid.create_pg(["tp", "cp", "ep", "pp"]) + tp_ep_group = grid.create_pg(["tp", "ep"]) + tp_ep_pp_group = grid.create_pg(["tp", "ep", "pp"]) + dp_cp_group = grid.create_pg(["cp", "dp"]) + tp_dp_cp_group = grid.create_pg(["tp", "cp", "dp"]) embd_group_ranks = mpu.default_embedding_ranks(dist.get_process_group_ranks(pp_group)) embd_group = dist.new_group(ranks=embd_group_ranks) - return ProcessGroupCollection(tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group, embd=embd_group, dp=dp_group) + pos_embd_group_ranks = mpu.default_position_embedding_ranks(dist.get_process_group_ranks(pp_group)) + pos_embd_group = dist.new_group(ranks=pos_embd_group_ranks) + return ProcessGroupCollection( + tp=tp_group, + cp=cp_group, + pp=pp_group, + ep=ep_group, + embd=embd_group, + pos_embd=pos_embd_group, + dp=dp_group, + tp_cp=tp_cp_group, + mp=mp_group, + expt_tp=tp_group, + expt_dp=dp_group, + tp_ep=tp_ep_group, + tp_ep_pp=tp_ep_pp_group, + dp_cp=dp_cp_group, + tp_dp_cp=tp_dp_cp_group, + ) -def _build_gpt(config: TransformerConfig, vocab_size: int, seq_len: int, pg_collection, parallel_output: bool = True) -> GPTModel: +def _build_gpt( + config: TransformerConfig, + vocab_size: int, + seq_len: int, + pg_collection, + parallel_output: bool = True, + num_moe_experts: Optional[int] = None, +) -> GPTModel: model = GPTModel( config=config, - transformer_layer_spec=get_gpt_layer_local_spec(), + transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, + moe_grouped_gemm=(num_moe_experts is not None), + ), vocab_size=vocab_size, max_sequence_length=seq_len, pre_process=True, @@ -70,27 +111,36 @@ def _set_pg_collection(module, tp_group, dp_group): module.pg_collection = types.SimpleNamespace(tp=tp_group, dp=dp_group, ep=None, pp=None) return module + @pytest.mark.parametrize( - "src_tp,src_pp,dst_tp,dst_pp", + "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [ - (2, 1, 1, 1), # TP2 -> TP1 - (1, 1, 2, 1), # TP1 -> TP2 - (1, 2, 1, 1), # PP2 -> PP1 - (1, 1, 1, 2), # PP1 -> PP2 - (2, 2, 1, 1), # TP2,PP2 -> TP1,PP1 - (1, 1, 2, 2), # TP1,PP1 -> TP2,PP2 - (2, 1, 1, 2), # TP2,PP1 -> TP1,PP2 - (1, 2, 2, 1), # TP1,PP2 -> TP2,PP1 + #TP only changes + (2, 1, 1, 1, 1, 1, None), # TP2 -> TP1 + (1, 1, 1, 2, 1, 1, None), # TP1 -> TP2 + # PP only changes + (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 + (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 + # Both TP and PP change + (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 + (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 + (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 + (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 + (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 + (1, 1, 2, 1, 1, 1, 4), + (1, 1, 1, 1, 1, 2, 4), ], ) -def test_nccl_swap_gpt_parametrized(src_tp: int, src_pp: int, dst_tp: int, dst_pp: int): +def test_nccl_swap_gpt_parametrized( + src_tp: int, src_pp: int, src_ep: int, dst_tp: int, dst_pp: int, dst_ep: int, num_experts: Optional[int] +): # Initialize environment with source MP sizing Utils.initialize_model_parallel(tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=src_pp) # Validate divisibility post-init using the default PG safely world = dist.get_world_size() - if (world % (src_tp * src_pp) != 0) or (world % (dst_tp * dst_pp) != 0): + if (world % (src_tp * src_pp * src_ep) != 0) or (world % (dst_tp * dst_pp * dst_ep) != 0): Utils.destroy_model_parallel() - pytest.skip("WORLD_SIZE must be divisible by both src_tp*src_pp and dst_tp*dst_pp") + pytest.skip("WORLD_SIZE must be divisible by both src_tp*src_pp*src_ep and dst_tp*dst_pp*dst_ep") model_parallel_cuda_manual_seed(1234) torch.manual_seed(1234) @@ -107,14 +157,38 @@ def test_nccl_swap_gpt_parametrized(src_tp: int, src_pp: int, dst_tp: int, dst_p pipeline_dtype=torch.float32, hidden_dropout=0.0, attention_dropout=0.0, + moe_router_dtype="fp64", + moe_token_dispatcher_type="alltoall" ) - # Build PGs and models - src_pgs = ProcessGroupCollection.use_mpu_process_groups() - dst_pgs = _build_pg_collection(tp_size=dst_tp, pp_size=dst_pp) - # Use parallel_output=False to gather vocab-parallel outputs inside model and emit only on last PP stage - src_model = _build_gpt(copy.deepcopy(cfg), vocab_size, seq_len, src_pgs, parallel_output=False).to(device).eval() - dst_model = _build_gpt(copy.deepcopy(cfg), vocab_size, seq_len, dst_pgs, parallel_output=False).to(device).eval() + # Build PGs and models (always use unified PG builder so we can set EP) + src_pgs = _build_pg_collection(tp_size=src_tp, pp_size=src_pp, ep_size=src_ep) + dst_pgs = _build_pg_collection(tp_size=dst_tp, pp_size=dst_pp, ep_size=dst_ep) + # Apply EP configuration to TransformerConfigs when MoE is requested + src_cfg = copy.deepcopy(cfg) + dst_cfg = copy.deepcopy(cfg) + if num_experts is not None: + src_cfg.num_moe_experts = num_experts + dst_cfg.num_moe_experts = num_experts + # Ensure MoE MLP has an intermediate size; __post_init__ won't rerun after manual mutation + src_cfg.moe_ffn_hidden_size = src_cfg.ffn_hidden_size + dst_cfg.moe_ffn_hidden_size = dst_cfg.ffn_hidden_size + src_cfg.expert_model_parallel_size = src_ep + dst_cfg.expert_model_parallel_size = dst_ep + # Force grouped MLP path under Transformer Engine and satisfy requirements + src_cfg.moe_grouped_gemm = True + dst_cfg.moe_grouped_gemm = True + src_cfg.add_bias_linear = False + dst_cfg.add_bias_linear = False + # Require Transformer Engine for TEGroupedMLP; skip if unavailable + try: + import transformer_engine # noqa: F401 + except Exception: + Utils.destroy_model_parallel() + pytest.skip("Transformer Engine not available; skipping TE-grouped MoE test") + # Use parallel_output=False to gather TP logits inside model and emit only on last PP stage + src_model = _build_gpt(src_cfg, vocab_size, seq_len, src_pgs, parallel_output=False, num_moe_experts=num_experts).to(device).eval() + dst_model = _build_gpt(dst_cfg, vocab_size, seq_len, dst_pgs, parallel_output=False, num_moe_experts=num_experts).to(device).eval() # Inputs batch = 2 @@ -149,172 +223,5 @@ def test_nccl_swap_gpt_parametrized(src_tp: int, src_pp: int, dst_tp: int, dst_p # Compare assert ref_logits.shape == dst_logits.shape assert torch.allclose(dst_logits, ref_logits, atol=1e-4, rtol=1e-4), f"Refit src(TP={src_tp},PP={src_pp})->dst(TP={dst_tp},PP={dst_pp}) GPT outputs differ" - dist.barrier() Utils.destroy_model_parallel() - -# def test_nccl_swap_row_parallel_linear_tp2_to_tp1(): -# Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) -# model_parallel_cuda_manual_seed(1234) -# device = torch.device(f"cuda:{torch.cuda.current_device()}") - -# # Build TP=2 source and TP=1 dest groups -# src_pgs = ProcessGroupCollection.use_mpu_process_groups() -# infer_pgs = _build_pg_collection(tp_size=1) - -# in_features = 12 -# out_features = 16 -# cfg = _mp_config() - -# # Source RowParallelLinear (TP=2), input_is_parallel=False so it scatters internally -# src_layer = RowParallelLinear( -# input_size=in_features, -# output_size=out_features, -# config=cfg, -# init_method=torch.nn.init.zeros_, -# bias=False, -# input_is_parallel=False, -# skip_bias_add=True, -# tp_group=src_pgs.tp, -# ).to(device) -# _set_pg_collection(src_layer, src_pgs.tp, src_pgs.dp) -# # Ensure TP metadata is present for planner (row-parallel shards input dim=1) -# src_layer.weight.tensor_model_parallel = True -# src_layer.weight.partition_dim = 1 -# src_layer.weight.partition_stride = 1 - -# # Deterministic per-rank weights (sharded along dim=1) -# rank = dist.get_rank(src_pgs.tp) -# with torch.no_grad(): -# src_layer.weight.copy_( -# torch.arange(src_layer.weight.numel(), device=device, dtype=torch.float32).reshape_as( -# src_layer.weight -# ) -# + rank * 1000.0 -# ) - -# # Destination RowParallelLinear (TP=1) -# dst_layer = RowParallelLinear( -# input_size=in_features, -# output_size=out_features, -# config=_mp_config(), -# init_method=torch.nn.init.zeros_, -# bias=False, -# input_is_parallel=False, -# skip_bias_add=True, -# tp_group=infer_pgs.tp, -# ).to(device) -# _set_pg_collection(dst_layer, infer_pgs.tp, infer_pgs.dp) -# # Destination is unsharded (TP=1) but keep metadata consistent -# dst_layer.weight.tensor_model_parallel = False -# dst_layer.weight.partition_dim = 1 -# dst_layer.weight.partition_stride = 1 - -# # Use layers directly to simplify parameter name matching -# src = src_layer -# dst = dst_layer -# # Attach pg_collection to layers so reshard can find process groups -# src.pg_collection = src_pgs -# dst.pg_collection = infer_pgs - -# # Input and reference (gather master weight along dim=1 from TP=2) -# x = torch.randn(4, in_features, device=device) -# parts = [torch.empty_like(src_layer.weight) for _ in range(dist.get_world_size(src_pgs.tp))] -# dist.all_gather(parts, src_layer.weight.contiguous(), group=src_pgs.tp) -# master_w = torch.cat(parts, dim=1).contiguous() # [out, in] -# ref = x @ master_w.t() - -# # Use resharder directly for per-layer validation and inspect plan -# plan = reshard_with_general_planner(src, dst) -# assert (len(plan.recv_ops) + len(plan.local_copy_ops)) > 0, "No transfers scheduled for RowParallelLinear" -# # Verify weights transferred correctly -# with torch.no_grad(): -# assert dst_layer.weight.shape == master_w.shape -# assert torch.allclose(dst_layer.weight, master_w, atol=1e-6, rtol=1e-6), "RowParallelLinear weights mismatch after transfer" -# y, _ = dst(x) -# assert torch.allclose(y, ref, atol=1e-4, rtol=1e-4), "RowParallelLinear TP2->TP1 mismatch" - -# dist.barrier() -# Utils.destroy_model_parallel() - -# def test_nccl_swap_column_parallel_linear_tp2_to_tp1(): -# Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) -# model_parallel_cuda_manual_seed(1234) -# device = torch.device(f"cuda:{torch.cuda.current_device()}") - -# # Build TP=2 source and TP=1 dest groups -# src_pgs = ProcessGroupCollection.use_mpu_process_groups() -# infer_pgs = _build_pg_collection(tp_size=1) - -# in_features = 12 -# out_features = 16 -# cfg = _mp_config() - -# # Source ColumnParallelLinear (TP=2) -# src_layer = ColumnParallelLinear( -# input_size=in_features, -# output_size=out_features, -# config=cfg, -# init_method=torch.nn.init.zeros_, -# bias=False, -# gather_output=False, -# tp_group=src_pgs.tp, -# ).to(device) -# _set_pg_collection(src_layer, src_pgs.tp, src_pgs.dp) -# # Ensure TP metadata is present for planner -# src_layer.weight.tensor_model_parallel = True -# src_layer.weight.partition_dim = 0 -# src_layer.weight.partition_stride = 1 - -# # Deterministic per-rank weights -# rank = dist.get_rank(src_pgs.tp) -# with torch.no_grad(): -# src_layer.weight.copy_( -# torch.arange(src_layer.weight.numel(), device=device, dtype=torch.float32).reshape_as( -# src_layer.weight -# ) -# + rank * 1000.0 -# ) - -# # Destination ColumnParallelLinear (TP=1) -# dst_layer = ColumnParallelLinear( -# input_size=in_features, -# output_size=out_features, -# config=_mp_config(), -# init_method=torch.nn.init.zeros_, -# bias=False, -# gather_output=False, -# tp_group=infer_pgs.tp, -# ).to(device) -# _set_pg_collection(dst_layer, infer_pgs.tp, infer_pgs.dp) -# # Destination is unsharded (TP=1) but keep metadata consistent -# dst_layer.weight.tensor_model_parallel = False -# dst_layer.weight.partition_dim = 0 -# dst_layer.weight.partition_stride = 1 - -# # Use layers directly to simplify parameter name matching -# src = src_layer -# dst = dst_layer -# # Attach pg_collection to layers so reshard can find process groups -# src.pg_collection = src_pgs -# dst.pg_collection = infer_pgs - -# # Input and reference (gather master weight from TP=2) -# x = torch.randn(4, in_features, device=device) -# parts = [torch.empty_like(src_layer.weight) for _ in range(dist.get_world_size(src_pgs.tp))] -# dist.all_gather(parts, src_layer.weight.contiguous(), group=src_pgs.tp) -# master_w = torch.cat(parts, dim=0).contiguous() # [out, in] -# ref = x @ master_w.t() - -# # Use resharder directly for per-layer validation and inspect plan -# plan = reshard_with_general_planner(src, dst) -# assert (len(plan.recv_ops) + len(plan.local_copy_ops)) > 0, "No transfers scheduled for ColumnParallelLinear" -# # Verify weights transferred correctly -# with torch.no_grad(): -# assert dst_layer.weight.shape == master_w.shape -# assert torch.allclose(dst_layer.weight, master_w, atol=1e-6, rtol=1e-6), "ColumnParallelLinear weights mismatch after transfer" -# y, _ = dst(x) -# assert torch.allclose(y, ref, atol=1e-4, rtol=1e-4), "ColumnParallelLinear TP2->TP1 mismatch" - -# dist.barrier() -# Utils.destroy_model_parallel() \ No newline at end of file From 8a7b5086cd45450ce26fab50dd9d0812724dd73f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 11:24:29 -0800 Subject: [PATCH 03/44] clean up --- megatron/core/resharding/planner.py | 98 +++++++++++------------------ 1 file changed, 36 insertions(+), 62 deletions(-) diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index b134c15cf3d..1e58d97f323 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -265,77 +265,51 @@ def build_centralized_reshard_plan( src_param_metadata[key] = [] src_param_metadata[key].append(metadata) - # Build the plan on global rank 0 and broadcast to all ranks with error propagation + # Build the plan on global rank 0 and broadcast to all ranks if my_global_rank == 0: - error_box = [None] plans_for_all_ranks = {r: ReshardPlan([], [], []) for r in range(world_size)} - try: - for dst_rank in range(world_size): - dst_rank_params = dst_param_metadata_by_rank.get(dst_rank, {}) - for resolved_name, dst_metadata in dst_rank_params.items(): - src_meta_list = src_param_metadata.get(resolved_name) - if not src_meta_list: - raise RuntimeError( - f"Destination parameter '{resolved_name}' on rank {dst_rank} not found in source model." - ) - # Choose a representative source metadata with DP round-robin balancing - src_metadata = select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank) - sources = _determine_source_ranks_for_dst_param( - resolved_name, src_metadata, dst_metadata, dst_rank + for dst_rank in range(world_size): + dst_rank_params = dst_param_metadata_by_rank.get(dst_rank, {}) + for resolved_name, dst_metadata in dst_rank_params.items(): + src_meta_list = src_param_metadata.get(resolved_name) + if not src_meta_list: + raise RuntimeError( + f"Destination parameter '{resolved_name}' on rank {dst_rank} not found in source model." ) - for src_rank, src_slice, dst_slice in sources: - if src_rank == dst_rank and src_metadata.name == dst_metadata.name: - plans_for_all_ranks[dst_rank].local_copy_ops.append( - (dst_metadata.name, None, None, src_slice, dst_slice) - ) - else: - plans_for_all_ranks[dst_rank].recv_ops.append( - TransferOp( - param_name=dst_metadata.name, - peer_rank=src_rank, - is_send=False, - my_slice=dst_slice, - peer_slice=src_slice, - ) + # Choose a representative source metadata with DP round-robin balancing + src_metadata = select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank) + sources = _determine_source_ranks_for_dst_param( + resolved_name, src_metadata, dst_metadata, dst_rank + ) + for src_rank, src_slice, dst_slice in sources: + if src_rank == dst_rank and src_metadata.name == dst_metadata.name: + plans_for_all_ranks[dst_rank].local_copy_ops.append( + (dst_metadata.name, None, None, src_slice, dst_slice) + ) + else: + plans_for_all_ranks[dst_rank].recv_ops.append( + TransferOp( + param_name=dst_metadata.name, + peer_rank=src_rank, + is_send=False, + my_slice=dst_slice, + peer_slice=src_slice, ) - plans_for_all_ranks[src_rank].send_ops.append( - TransferOp( - param_name=src_metadata.name, - peer_rank=dst_rank, - is_send=True, - my_slice=src_slice, - peer_slice=dst_slice, - ) + ) + plans_for_all_ranks[src_rank].send_ops.append( + TransferOp( + param_name=src_metadata.name, + peer_rank=dst_rank, + is_send=True, + my_slice=src_slice, + peer_slice=dst_slice, ) - plans_list = [plans_for_all_ranks[r] for r in range(world_size)] - except Exception as e: - tb = traceback.format_exc() - error_box[0] = { - "rank": my_global_rank, - "param": resolved_name if 'resolved_name' in locals() else None, - "type": type(e).__name__, - "msg": str(e), - "traceback": tb, - } - plans_list = [None] * world_size - dist.broadcast_object_list(error_box, src=0) + ) + plans_list = [plans_for_all_ranks[r] for r in range(world_size)] else: - error_box = [None] plans_list = [None] * world_size - dist.broadcast_object_list(error_box, src=0) - if error_box[0] is not None: - err = error_box[0] - print( - f"[Reshard Planner] Aborting due to error on rank {err['rank']} while planning {err['param']}: " - f"{err['type']}: {err['msg']}" - ) - print(err["traceback"]) - sys.stdout.flush() - raise RuntimeError(f"Reshard plan failed on rank {err['rank']} for {err['param']}: {err['msg']}") - torch.distributed.barrier() torch.distributed.broadcast_object_list(plans_list, src=0) my_plan = plans_list[my_global_rank] - torch.distributed.barrier() # Fill in actual parameter references for local copies for i, (param_name, _, _, src_slice, dst_slice) in enumerate(my_plan.local_copy_ops): From a881fba1294becf413e6cd7524f4f8d6792301dc Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 11:41:47 -0800 Subject: [PATCH 04/44] clean up --- .../refit.py} | 29 ++++++++++++++----- megatron/rl/rl_utils.py | 2 +- .../inference/test_nccl_model_swap.py | 2 +- 3 files changed, 23 insertions(+), 10 deletions(-) rename megatron/core/{model_refitting.py => resharding/refit.py} (83%) diff --git a/megatron/core/model_refitting.py b/megatron/core/resharding/refit.py similarity index 83% rename from megatron/core/model_refitting.py rename to megatron/core/resharding/refit.py index d9aed2c93c4..2d4cb5ba728 100644 --- a/megatron/core/model_refitting.py +++ b/megatron/core/resharding/refit.py @@ -1,22 +1,32 @@ -from megatron.core.models.common.language_module.language_module import LanguageModule +from __future__ import annotations + +from typing import Any, Optional + import torch import torch.distributed as dist -from typing import Any -from megatron.core import parallel_state -from megatron.core.resharding import build_centralized_reshard_plan, execute_reshard_plan -from typing import Any, Optional +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core import parallel_state +from . import build_centralized_reshard_plan, execute_reshard_plan def _unwrap_module(module: LanguageModule) -> Any: - return module.module.module if hasattr(module, 'module') and hasattr(module.module, 'module') else module.module if hasattr(module, 'module') else module + return ( + module.module.module + if hasattr(module, 'module') and hasattr(module.module, 'module') + else module.module + if hasattr(module, 'module') + else module + ) + def swap_model_weights(src_model: LanguageModule, target_model: LanguageModule, refit_method: str): - if refit_method == "nccl": + if refit_method == "nccl": nccl_model_swap(src_model, target_model) else: raise ValueError(f"Invalid refit method: {refit_method}") + def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): # Handle list-wrapped modules used throughout training utils src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model @@ -38,6 +48,7 @@ def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): # Fill missing DP group on the source using Megatron's parallel state if not provided if getattr(src_core.pg_collection, "dp", None) is None: src_core.pg_collection.dp = parallel_state.get_data_parallel_group() + # caching plan for reuse cached_plan: Optional[Any] = getattr(tgt_core, "_cached_reshard_plan", None) if cached_plan is None: @@ -45,4 +56,6 @@ def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): setattr(tgt_core, "_cached_reshard_plan", plan) else: plan = cached_plan - execute_reshard_plan(plan, src_core, tgt_core) \ No newline at end of file + execute_reshard_plan(plan, src_core, tgt_core) + + diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 7e9a2c39418..7bd1a976c78 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -32,6 +32,7 @@ from megatron.core.rerun_state_machine import RerunDataIterator from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord from megatron.core.transformer.utils import toggle_cuda_graphs +from megatron.core.resharding.refit import swap_model_weights from megatron.rl.agent.api import ( EvaluationRequest, EvaluationResponse, @@ -2398,5 +2399,4 @@ def swap_train_to_inference_model(train_model: LanguageModule, inference_model: train_model: The train model to swap to the inference model. inference_model: The inference model to swap to the train model. """ - from megatron.core.model_refitting import swap_model_weights swap_model_weights(train_model, inference_model, refit_method) \ No newline at end of file diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index 8f984e3659b..9ffe8c2de16 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -8,7 +8,7 @@ import pytest from tests.unit_tests.test_utilities import Utils -from megatron.core.model_refitting import swap_model_weights +from megatron.core.resharding.refit import swap_model_weights from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.transformer_config import TransformerConfig From f45c32d27a716a3e9b25275892da973c63369d8d Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 11:57:56 -0800 Subject: [PATCH 05/44] more refactor --- megatron/core/resharding/execution.py | 9 ++++++--- megatron/core/resharding/refit.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index 6a12e01fe08..7b012512eb1 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -7,7 +7,6 @@ import torch.distributed as dist from .utils import ReshardPlan -from .copy_services.nccl_copy_service import NCCLCopyService logger = logging.getLogger(__name__) @@ -17,9 +16,13 @@ def execute_reshard_plan( plan: ReshardPlan, src_module: torch.nn.Module, dst_module: torch.nn.Module, + service: object, ) -> None: - """Execute a reshard plan (from centralized controller).""" - service = NCCLCopyService() + """ + Execute a reshard plan (from centralized controller). + A communication service must be provided to abstract transport. + Expected service API: submit_send(tensor, dest_rank), submit_recv(tensor, src_rank), run(). + """ src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 2d4cb5ba728..60cdb4faa94 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -8,6 +8,7 @@ from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core import parallel_state from . import build_centralized_reshard_plan, execute_reshard_plan +from .copy_services.nccl_copy_service import NCCLCopyService def _unwrap_module(module: LanguageModule) -> Any: @@ -56,6 +57,6 @@ def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): setattr(tgt_core, "_cached_reshard_plan", plan) else: plan = cached_plan - execute_reshard_plan(plan, src_core, tgt_core) + execute_reshard_plan(plan, src_core, tgt_core, service=NCCLCopyService()) From 7ae543d39bd85613fe8c5d717223e3e5efc12df4 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 12:40:25 -0800 Subject: [PATCH 06/44] more cleanup --- .../core/resharding/copy_services/__init__.py | 8 +++++ .../core/resharding/copy_services/base.py | 20 +++++++++++ .../copy_services/nccl_copy_service.py | 3 +- megatron/core/resharding/execution.py | 3 +- megatron/core/resharding/refit.py | 34 +++++++++++++++---- 5 files changed, 59 insertions(+), 9 deletions(-) create mode 100644 megatron/core/resharding/copy_services/__init__.py create mode 100644 megatron/core/resharding/copy_services/base.py diff --git a/megatron/core/resharding/copy_services/__init__.py b/megatron/core/resharding/copy_services/__init__.py new file mode 100644 index 00000000000..eb7133c64b0 --- /dev/null +++ b/megatron/core/resharding/copy_services/__init__.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from .base import CopyService +from .nccl_copy_service import NCCLCopyService + +__all__ = ["CopyService", "NCCLCopyService"] + + diff --git a/megatron/core/resharding/copy_services/base.py b/megatron/core/resharding/copy_services/base.py new file mode 100644 index 00000000000..cab7dc71655 --- /dev/null +++ b/megatron/core/resharding/copy_services/base.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import torch + + +class CopyService(ABC): + @abstractmethod + def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): + ... + + @abstractmethod + def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + ... + + @abstractmethod + def run(self): + ... + + diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index c81a05c80dc..687f967128f 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +from .base import CopyService logger = logging.getLogger(__name__) @@ -23,7 +24,7 @@ class RecvOp: src_rank: int -class NCCLCopyService: +class NCCLCopyService(CopyService): """ Thin wrapper around torch.distributed batch_isend_irecv to submit and execute a batch of point-to-point sends and recvs. diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index 7b012512eb1..f9b950eca82 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -7,6 +7,7 @@ import torch.distributed as dist from .utils import ReshardPlan +from .copy_services.base import CopyService logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def execute_reshard_plan( plan: ReshardPlan, src_module: torch.nn.Module, dst_module: torch.nn.Module, - service: object, + service: CopyService, ) -> None: """ Execute a reshard plan (from centralized controller). diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 60cdb4faa94..776591c9051 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, Union import torch import torch.distributed as dist @@ -8,6 +8,7 @@ from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core import parallel_state from . import build_centralized_reshard_plan, execute_reshard_plan +from .copy_services.base import CopyService from .copy_services.nccl_copy_service import NCCLCopyService @@ -21,14 +22,32 @@ def _unwrap_module(module: LanguageModule) -> Any: ) -def swap_model_weights(src_model: LanguageModule, target_model: LanguageModule, refit_method: str): - if refit_method == "nccl": - nccl_model_swap(src_model, target_model) +def swap_model_weights( + src_model: LanguageModule, + target_model: LanguageModule, + refit_method: Union[str, CopyService], +): + """ + Orchestrate weight swap/refit. + - refit_method can be a string backend name ('nccl') or a CopyService instance. + """ + if isinstance(refit_method, CopyService): + service = refit_method + elif isinstance(refit_method, str): + if refit_method == "nccl": + service = NCCLCopyService() + else: + raise ValueError(f"Unknown refit_method '{refit_method}'") else: - raise ValueError(f"Invalid refit method: {refit_method}") + raise TypeError("refit_method must be a str backend name or a CopyService instance") + nccl_model_swap(src_model, target_model, service=service) -def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): +def nccl_model_swap( + src_model: LanguageModule, + target_model: LanguageModule, + service: CopyService, +): # Handle list-wrapped modules used throughout training utils src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model tgt_lm = target_model[0] if isinstance(target_model, (list, tuple)) else target_model @@ -51,12 +70,13 @@ def nccl_model_swap(src_model: LanguageModule, target_model: LanguageModule): src_core.pg_collection.dp = parallel_state.get_data_parallel_group() # caching plan for reuse + # TODO(Peter): Is there a better place to cache this? cached_plan: Optional[Any] = getattr(tgt_core, "_cached_reshard_plan", None) if cached_plan is None: plan = build_centralized_reshard_plan(src_core, tgt_core, num_experts=num_experts) setattr(tgt_core, "_cached_reshard_plan", plan) else: plan = cached_plan - execute_reshard_plan(plan, src_core, tgt_core, service=NCCLCopyService()) + execute_reshard_plan(plan, src_core, tgt_core, service=service) From 52b8b8dcdb7488f918030f7154547a5f68a3040e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 13:11:59 -0800 Subject: [PATCH 07/44] clean up --- megatron/core/resharding/__init__.py | 3 +++ megatron/core/resharding/execution.py | 1 - megatron/core/resharding/planner.py | 10 ++-------- megatron/core/resharding/refit.py | 16 ++++++++++------ 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/megatron/core/resharding/__init__.py b/megatron/core/resharding/__init__.py index cb06ddebe2e..f84b0665696 100644 --- a/megatron/core/resharding/__init__.py +++ b/megatron/core/resharding/__init__.py @@ -1,5 +1,6 @@ from .planner import build_centralized_reshard_plan from .execution import execute_reshard_plan +from .refit import swap_model_weights, reshard_model_weights from .utils import ( ParameterMetadata, ShardingDescriptor, @@ -10,6 +11,8 @@ __all__ = [ "build_centralized_reshard_plan", "execute_reshard_plan", + "swap_model_weights", + "reshard_model_weights", "ParameterMetadata", "ShardingDescriptor", "TransferOp", diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index f9b950eca82..a710210308d 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -63,7 +63,6 @@ def execute_reshard_plan( service.run() #TODO(Peter) remove this eventually? dist.barrier() - torch.cuda.synchronize() # Write back received buffers into their destination parameter slices for recv_buffer, dst_param, dst_slice in recv_writebacks: diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index 1e58d97f323..b7ee3cb6e55 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -2,8 +2,6 @@ import logging import math -import sys -import traceback from typing import Optional import torch @@ -212,7 +210,6 @@ def build_centralized_reshard_plan( src_module: torch.nn.Module, dst_module: torch.nn.Module, num_experts: int = None, - validate_config: bool = True, ) -> ReshardPlan: """ Centralized planning: Rank 0 builds complete plan for all ranks, then scatters. @@ -228,19 +225,16 @@ def build_centralized_reshard_plan( if not hasattr(src_pg, 'dp'): raise ValueError("src_pg must have dp process group") - src_num_experts = num_experts - dst_num_experts = num_experts - # Gather param metadata from all ranks my_src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} my_dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} my_src_metadata = [ - extract_param_metadata(p, name, my_global_rank, src_pg, num_experts=src_num_experts) + extract_param_metadata(p, name, my_global_rank, src_pg, num_experts=num_experts) for name, p in my_src_params.items() ] my_dst_metadata = [ - extract_param_metadata(p, name, my_global_rank, dst_pg, num_experts=dst_num_experts) + extract_param_metadata(p, name, my_global_rank, dst_pg, num_experts=num_experts) for name, p in my_dst_params.items() ] diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 776591c9051..f7f76d00f33 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -1,10 +1,12 @@ from __future__ import annotations +""" +High-level refit/reshard orchestration: +- swap_model_weights: public API; accepts a backend name or CopyService and delegates. +- reshard_model_weights: transport-agnostic core; builds/caches plan and executes. +""" from typing import Any, Optional, Union -import torch -import torch.distributed as dist - from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core import parallel_state from . import build_centralized_reshard_plan, execute_reshard_plan @@ -33,17 +35,19 @@ def swap_model_weights( """ if isinstance(refit_method, CopyService): service = refit_method + reshard_model_weights(src_model, target_model, service=service) elif isinstance(refit_method, str): if refit_method == "nccl": service = NCCLCopyService() + reshard_model_weights(src_model, target_model, service=service) else: raise ValueError(f"Unknown refit_method '{refit_method}'") else: raise TypeError("refit_method must be a str backend name or a CopyService instance") - nccl_model_swap(src_model, target_model, service=service) + -def nccl_model_swap( +def reshard_model_weights( src_model: LanguageModule, target_model: LanguageModule, service: CopyService, @@ -70,7 +74,7 @@ def nccl_model_swap( src_core.pg_collection.dp = parallel_state.get_data_parallel_group() # caching plan for reuse - # TODO(Peter): Is there a better place to cache this? + # TODO(Peter): Find a better place to cache this. cached_plan: Optional[Any] = getattr(tgt_core, "_cached_reshard_plan", None) if cached_plan is None: plan = build_centralized_reshard_plan(src_core, tgt_core, num_experts=num_experts) From f475327987ae6a16b912c203e490df2b0d7ac091 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 13:20:45 -0800 Subject: [PATCH 08/44] more tests --- tests/unit_tests/inference/test_nccl_model_swap.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index 9ffe8c2de16..72f4923df6d 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -129,6 +129,7 @@ def _set_pg_collection(module, tp_group, dp_group): (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 (1, 1, 2, 1, 1, 1, 4), (1, 1, 1, 1, 1, 2, 4), + (1, 1, 2, 1, 2, 2, 4), ], ) def test_nccl_swap_gpt_parametrized( From 2fbd44d1bc8d5b9e155de82877c8441dbc15c495 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Dec 2025 13:08:16 -0800 Subject: [PATCH 09/44] end2end --- megatron/core/process_groups_config.py | 126 ++++++++++++++++++ .../copy_services/gloo_copy_service.py | 84 ++++++++++++ megatron/core/resharding/refit.py | 6 + megatron/rl/inference/megatron.py | 3 +- megatron/rl/rl_utils.py | 11 +- megatron/training/arguments.py | 5 +- megatron/training/training.py | 34 ++++- train_rl.py | 31 ++++- 8 files changed, 281 insertions(+), 19 deletions(-) create mode 100644 megatron/core/resharding/copy_services/gloo_copy_service.py diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index ef8f31ea150..74f078de431 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -249,6 +249,132 @@ def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None): return cls(**init_dict) + def get_pipeline_model_parallel_world_size(self) -> int: + """Return PP world size using the PP process group.""" + return torch.distributed.get_world_size(self.pp) + + def get_tensor_model_parallel_rank(self) -> int: + """Return this rank's TP rank within the TP group.""" + global_rank = torch.distributed.get_rank() + tp_ranks = torch.distributed.get_process_group_ranks(self.tp) + return tp_ranks.index(global_rank) + + def get_pipeline_model_parallel_rank(self) -> int: + """Return this rank's PP rank within the PP group.""" + global_rank = torch.distributed.get_rank() + pp_ranks = torch.distributed.get_process_group_ranks(self.pp) + return pp_ranks.index(global_rank) + + def is_pipeline_first_stage(self, ignore_virtual: bool = True, vp_stage: Optional[int] = None) -> bool: + """Return True if this rank is on the first PP stage. + + By default, ignores virtual pipeline (matches legacy interface). If you need VP-aware + behavior, pass ignore_virtual=False and specify vp_stage. + """ + pp_ranks = torch.distributed.get_process_group_ranks(self.pp) + global_rank = torch.distributed.get_rank() + is_pp_first = (pp_ranks[0] == global_rank) + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + if ignore_virtual or vp_size in (None, 1) or vp_stage is None: + return is_pp_first + is_vp_first = (vp_stage == 0) + return is_vp_first and is_pp_first + + def is_pipeline_last_stage(self, ignore_virtual: bool = True, vp_stage: Optional[int] = None) -> bool: + """Return True if this rank is on the last PP stage. + + By default, ignores virtual pipeline (matches legacy interface). If you need VP-aware + behavior, pass ignore_virtual=False and specify vp_stage. + """ + pp_ranks = torch.distributed.get_process_group_ranks(self.pp) + global_rank = torch.distributed.get_rank() + is_pp_last = (pp_ranks[-1] == global_rank) + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + if ignore_virtual or vp_size in (None, 1) or vp_stage is None: + return is_pp_last + is_vp_last = (vp_stage == (vp_size - 1)) + return is_vp_last and is_pp_last + + def get_data_parallel_rank(self) -> int: + """Return this rank's DP rank within the DP group.""" + global_rank = torch.distributed.get_rank() + dp_ranks = torch.distributed.get_process_group_ranks(self.dp) + return dp_ranks.index(global_rank) + + def get_context_parallel_rank(self) -> int: + """Return this rank's CP rank within the CP group, or 0 if no CP.""" + if not hasattr(self, 'cp') or self.cp is None: + return 0 + global_rank = torch.distributed.get_rank() + cp_ranks = torch.distributed.get_process_group_ranks(self.cp) + return cp_ranks.index(global_rank) + + def get_data_parallel_group(self, with_context_parallel: bool = False, partial_data_parallel: bool = False): + """Return the DP/DP+CP process group, optionally partial.""" + if with_context_parallel: + if partial_data_parallel: + # Prefer intra_dp_cp if available, else fallback to dp_cp + if hasattr(self, 'intra_dp_cp') and self.intra_dp_cp is not None: + return self.intra_dp_cp + if hasattr(self, 'dp_cp') and self.dp_cp is not None: + return self.dp_cp + return self.dp + else: + if hasattr(self, 'dp_cp') and self.dp_cp is not None: + return self.dp_cp + return self.dp + return self.dp + + def get_tensor_model_parallel_group(self): + return self.tp + + def get_pipeline_model_parallel_group(self): + return self.pp + + def get_model_parallel_group(self): + return self.mp + + def get_model_parallel_world_size(self) -> int: + """Return MP world size using the MP process group.""" + return torch.distributed.get_world_size(self.mp) + + def get_model_parallel_src_rank(self) -> int: + """Return the source (leader) global rank for the MP group.""" + ranks = torch.distributed.get_process_group_ranks(self.mp) + return ranks[0] + + def get_context_parallel_group(self): + return getattr(self, 'cp', None) + + def get_expert_model_parallel_group(self): + return getattr(self, 'ep', None) + + def get_expert_data_parallel_group(self, partial_expert_data_parallel: bool = False): + if partial_expert_data_parallel: + return getattr(self, 'intra_expt_dp', None) + return getattr(self, 'expt_dp', None) + + def get_inter_distributed_optimizer_instance_group(self): + return getattr(self, 'inter_dist_opt', None) + + def get_intra_distributed_optimizer_instance_group(self): + return getattr(self, 'intra_dist_opt', None) + + def get_data_parallel_src_rank(self, with_context_parallel: bool = False, partial_data_parallel: bool = False) -> int: + """Return the source (leader) global rank for the selected DP group.""" + group = self.get_data_parallel_group( + with_context_parallel=with_context_parallel, partial_data_parallel=partial_data_parallel + ) + ranks = torch.distributed.get_process_group_ranks(group) + return ranks[0] + + def get_data_parallel_world_size(self, with_context_parallel: bool = False, partial_data_parallel: bool = False) -> int: + """Return world size of the selected DP group.""" + group = self.get_data_parallel_group( + with_context_parallel=with_context_parallel, partial_data_parallel=partial_data_parallel + ) + return torch.distributed.get_world_size(group) + @staticmethod def setup_process_groups_for_optimizer( pg_collection: Optional['ProcessGroupCollection'], diff --git a/megatron/core/resharding/copy_services/gloo_copy_service.py b/megatron/core/resharding/copy_services/gloo_copy_service.py new file mode 100644 index 00000000000..42b72be1954 --- /dev/null +++ b/megatron/core/resharding/copy_services/gloo_copy_service.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import List, Tuple + +import torch +import torch.distributed as dist + +from .base import CopyService + + +logger = logging.getLogger(__name__) + + +@dataclass +class SendOp: + tensor: torch.Tensor + dest_rank: int + + +@dataclass +class RecvOp: + tensor: torch.Tensor + src_rank: int + + +class GlooCopyService(CopyService): + """ + CopyService implementation that routes refit traffic over a CPU/Gloo + process group instead of NCCL. + """ + + def __init__(self): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.gloo_pg = dist.new_group(backend="gloo") + self.send_ops: List[SendOp] = [] + self.recv_ops: List[Tuple[RecvOp, torch.Tensor]] = [] + logger.info(f"GlooCopyService initialized on rank {self.rank} with {self.world_size} ranks") + + def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): + self.send_ops.append(SendOp(tensor=src_tensor, dest_rank=dest_rank)) + + def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + # Allocate a CPU buffer that matches the destination view; we'll + # copy into dest_tensor after the Gloo recv completes. + cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() + self.recv_ops.append((RecvOp(tensor=cpu_buffer, src_rank=src_rank), dest_tensor)) + + def run(self): + total_ops = len(self.send_ops) + len(self.recv_ops) + logger.info( + f"GlooCopyService rank {self.rank}: executing batched communication: " + f"{len(self.send_ops)} sends + {len(self.recv_ops)} recvs = {total_ops} ops" + ) + + p2p_ops: List[dist.P2POp] = [] + + # Build Gloo P2P ops over CPU tensors. For sends we clone to CPU; + # for recvs we use the preallocated CPU buffers. + for op in self.send_ops: + cpu_tensor = op.tensor.detach().to("cpu").contiguous() + p2p_ops.append(dist.P2POp(dist.isend, cpu_tensor, op.dest_rank, group=self.gloo_pg)) + for recv, _dst_tensor in self.recv_ops: + p2p_ops.append(dist.P2POp(dist.irecv, recv.tensor, recv.src_rank, group=self.gloo_pg)) + + if p2p_ops: + reqs = dist.batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + # Copy received CPU buffers back into the original destination tensors. + for recv, dst_tensor in self.recv_ops: + if dst_tensor.is_cuda: + dst_tensor.copy_(recv.tensor.to(dst_tensor.device)) + else: + dst_tensor.copy_(recv.tensor) + + logger.info("GlooCopyService: batched communication completed") + self.send_ops.clear() + self.recv_ops.clear() + + diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index f7f76d00f33..95c0203b5f8 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -12,6 +12,7 @@ from . import build_centralized_reshard_plan, execute_reshard_plan from .copy_services.base import CopyService from .copy_services.nccl_copy_service import NCCLCopyService +from .copy_services.gloo_copy_service import GlooCopyService def _unwrap_module(module: LanguageModule) -> Any: @@ -40,6 +41,10 @@ def swap_model_weights( if refit_method == "nccl": service = NCCLCopyService() reshard_model_weights(src_model, target_model, service=service) + elif refit_method == "gloo": + # Debug / fallback backend: run refit over CPU/Gloo instead of NCCL. + service = GlooCopyService() + reshard_model_weights(src_model, target_model, service=service) else: raise ValueError(f"Unknown refit_method '{refit_method}'") else: @@ -81,6 +86,7 @@ def reshard_model_weights( setattr(tgt_core, "_cached_reshard_plan", plan) else: plan = cached_plan + execute_reshard_plan(plan, src_core, tgt_core, service=service) diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index 2be354717ee..12dcfbb15df 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -8,6 +8,7 @@ import torch.distributed as dist from megatron.core import parallel_state +from megatron.core.utils import get_attr_wrapped_model from megatron.core.inference.inference_client import InferenceClient from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext from megatron.core.inference.engines.abstract_engine import AbstractEngine @@ -227,6 +228,7 @@ async def base_generate(self, request: InferenceRequest): async def launch(cls, model: GPTModel, **kwargs): args = get_args() tokenizer = get_tokenizer() + rank = dist.get_rank() if tokenizer.bos is None: log_single_rank( @@ -246,7 +248,6 @@ async def launch(cls, model: GPTModel, **kwargs): if metrics_writer is None: log_single_rank(logger, logging.WARNING, "WARNING: --rl-inference-logging-step-interval is set but no metrics writer " "wandb module is available. Inference logging will be disabled.") - # TODO(Peter) We need to pass the pg_collection to the coordinator, but like where is the coordinator even defined inference_engine: DynamicInferenceEngine = get_dynamic_inference_engine(args, model, inference_logging_step_interval, metrics_writer) await inference_engine.start_listening_to_data_parallel_coordinator(inference_coordinator_port=41521, launch_inference_coordinator=True) if dist.get_rank() == 0: diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 32644656e1d..6da35a228be 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -55,7 +55,7 @@ ) from megatron.training.tokenizer.tokenizer import CustomTikTokenizer, _HuggingFaceTokenizer from megatron.training.utils import get_ltor_masks_and_position_ids, get_nvtx_range, print_rank_0 - +from megatron.training.utils import unwrap_model logger = logging.getLogger(__name__) # Global variable to store packing context for forward_step @@ -1983,7 +1983,7 @@ def setup_grpo_data_iterator( args = get_args() if inference_model is not None: - inference_mpu = inference_model.pg_collection + inference_mpu = unwrap_model(inference_model[0]).pg_collection else: inference_mpu = mpu @@ -2274,7 +2274,6 @@ def megatron_rl_inference_mode( loop = get_asyncio_loop() nvtx_range = get_nvtx_range() - print(f"[{dist.get_rank()}:DP] Entering inference mode") # If we get a lower precision wrapper, we go one object deeper. lang_module = model[0].module.module if hasattr(model[0].module, "module") else model[0].module @@ -2326,8 +2325,6 @@ def megatron_rl_inference_mode( inference_interface._inference_engine.create_cuda_graphs(reset_context=True) loop.run_until_complete(inference_interface.resume()) - - print(f"[{dist.get_rank()}:DP] Entered inference mode") yield inference_interface with nvtx_range("suspend-engine"): @@ -2405,11 +2402,11 @@ def get_sequence_packing_tensorboard_metrics(args): metrics['consumed-bins'] = args.consumed_train_bins return metrics -def swap_train_to_inference_model(train_model: LanguageModule, inference_model: LanguageModule, refit_method: str): +def swap_train_to_inference_model(train_model: LanguageModule, inference_model: list[LanguageModule], refit_method: str): """Swap the train model to the inference model. Args: train_model: The train model to swap to the inference model. - inference_model: The inference model to swap to the train model. + inference_model: The inference model (list) to swap to the train model. """ swap_model_weights(train_model, inference_model, refit_method) \ No newline at end of file diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 91fa61b3746..6fbfc6613f3 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1986,10 +1986,9 @@ def _add_rl_args(parser): 'round-robin: distribute bins cyclically across ranks for better load balancing') group.add_argument('--rl-inference-tensor-model-parallel-size', type=int, default=None, help='Degree of tensor model parallelism for inference for RL.') - group.add_argument('--refit-method', type=str, default='naive', - choices=['naive', 'nccl'], + group.add_argument('--refit-method', type=str, default='nccl', + choices=['nccl', 'gloo'], help=('Method to refit the model weights between training and inference models during RL. ' - 'naive: naive method to refit the model weights between training and inference models during RL. ' 'nccl: use NCCLCopyService to refit the model weights between training and inference models during RL.')) return parser diff --git a/megatron/training/training.py b/megatron/training/training.py index 9fb19b94359..e6f30a73820 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -694,7 +694,7 @@ def pretrain( print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') config = get_model_config(model[0]) - # Build a separate inference model for RL if requested. + # Build a separate inference model for RL if requested. inference_model = None if args.perform_rl_step: pg_collection = None @@ -708,17 +708,45 @@ def pretrain( assert dp_size >= 1 and (tp_size * cp_size * pp_size * dp_size) == args.world_size, \ "World size must be divisible by tp*cp*pp for inference PG layout" + # TODO(Peter) We need to pass the expert parallel correctly here grid = HyperCommGrid([tp_size, cp_size, 1, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) tp_group = grid.create_pg("tp") cp_group = grid.create_pg("cp") pp_group = grid.create_pg("pp") ep_group = grid.create_pg("ep") dp_group = grid.create_pg("dp") + # Composite groups required by MoE/router and some utilities + tp_cp_group = grid.create_pg(["tp", "cp"]) + mp_group = grid.create_pg(["tp", "cp", "ep", "pp"]) + tp_ep_group = grid.create_pg(["tp", "ep"]) + tp_ep_pp_group = grid.create_pg(["tp", "ep", "pp"]) + dp_cp_group = grid.create_pg(["cp", "dp"]) + tp_dp_cp_group = grid.create_pg(["tp", "cp", "dp"]) embd_group_ranks = mpu.default_embedding_ranks( torch.distributed.get_process_group_ranks(pp_group) ) embd_group = torch.distributed.new_group(ranks=embd_group_ranks) - inference_pg_collection = ProcessGroupCollection(tp=tp_group, cp=cp_group, pp=pp_group, ep=ep_group, embd=embd_group, dp=dp_group) + pos_embd_group_ranks = mpu.default_position_embedding_ranks( + torch.distributed.get_process_group_ranks(pp_group) + ) + pos_embd_group = torch.distributed.new_group(ranks=pos_embd_group_ranks) + inference_pg_collection = ProcessGroupCollection( + tp=tp_group, + cp=cp_group, + pp=pp_group, + ep=ep_group, + embd=embd_group, + pos_embd=pos_embd_group, + dp=dp_group, + tp_cp=tp_cp_group, + mp=mp_group, + expt_tp=tp_group, + expt_dp=dp_group, + tp_ep=tp_ep_group, + tp_ep_pp=tp_ep_pp_group, + dp_cp=dp_cp_group, + tp_dp_cp=tp_dp_cp_group, + ) # Build an isolated inference config so training config remains unchanged inference_config = copy.deepcopy(config) @@ -922,7 +950,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap args = get_args() args.model_type = model_type if pg_collection is None: - pg_collection = mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() if has_nvidia_modelopt: diff --git a/train_rl.py b/train_rl.py index 33fca0cb840..d73c5f3536f 100644 --- a/train_rl.py +++ b/train_rl.py @@ -22,7 +22,7 @@ stimer = StragglerDetector() -def _gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): +def _gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): # TODO(Peter): This is a hack to get around the fact that we are activation recomputation for training but not # for inference with cuda graphs. Without out this the post checks in the transformer config will assert error. if config is None: @@ -54,7 +54,14 @@ def _gpt_builder(args, pre_process, post_process, vp_stage=None, config=None): ) with build_model_context(**build_model_context_args): - return gpt_builder(args, pre_process, post_process, vp_stage=vp_stage, config=config) + return gpt_builder( + args, + pre_process, + post_process, + vp_stage=vp_stage, + config=config, + pg_collection=pg_collection, + ) # define spiky loss as a variation of 20% or more @@ -363,11 +370,25 @@ def __getitem__(self, idx): # Temporary for transition to core datasets train_valid_test_datasets_provider.is_distributed = True - def _model_builder(args, pre_process, post_process, vp_stage=None): + def _model_builder(args, pre_process, post_process, vp_stage=None, pg_collection=None, config=None): if getattr(args, "is_hybrid_model", False): - return mamba_builder(args, pre_process, post_process, vp_stage) + return mamba_builder( + args, + pre_process, + post_process, + vp_stage, + config=config, + pg_collection=pg_collection, + ) else: - return _gpt_builder(args, pre_process, post_process, vp_stage) + return _gpt_builder( + args, + pre_process, + post_process, + vp_stage, + config=config, + pg_collection=pg_collection, + ) pretrain( None, # we don't need to build any datasets for RL training From 2d93e49a74ed360f9d3d6a7c6f0aaf756b275b44 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Tue, 2 Dec 2025 16:53:12 -0800 Subject: [PATCH 10/44] clean up --- megatron/core/resharding/planner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index b7ee3cb6e55..f40722ea680 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -222,8 +222,6 @@ def build_centralized_reshard_plan( dst_pg = getattr(dst_module, "pg_collection", None) if src_pg is None or dst_pg is None: raise ValueError("Both modules must have pg_collection") - if not hasattr(src_pg, 'dp'): - raise ValueError("src_pg must have dp process group") # Gather param metadata from all ranks my_src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} From 0960061b154c052ab890b3b1401b595354348ab6 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Tue, 2 Dec 2025 17:01:09 -0800 Subject: [PATCH 11/44] add tests --- .../golden_values_dev_dgx_h100.json | 287 ++++++++++++++++++ .../model_config.yaml | 79 +++++ tests/test_utils/recipes/gpt-grpo.yaml | 5 + 3 files changed, 371 insertions(+) create mode 100644 tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/golden_values_dev_dgx_h100.json create mode 100644 tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/golden_values_dev_dgx_h100.json new file mode 100644 index 00000000000..1ea946d1587 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/golden_values_dev_dgx_h100.json @@ -0,0 +1,287 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 0.0, + "2": 0.04415, + "3": 0.0378, + "4": 0.02944, + "5": 0.0, + "6": 0.0, + "7": 0.0, + "8": 0.08111, + "9": 0.0, + "10": 0.0, + "11": 0.0, + "12": 0.0, + "13": 0.0, + "14": 0.05935, + "15": 0.0, + "16": 0.05496, + "17": 0.0, + "18": 0.0, + "19": 0.0, + "20": 0.04534, + "21": 0.0, + "22": 0.0, + "23": 0.0, + "24": 0.0, + "25": 0.0, + "26": 0.0, + "27": 0.0, + "28": 0.0, + "29": 0.0, + "30": 0.0, + "31": 0.0, + "32": 0.0, + "33": 0.0, + "34": 0.0, + "35": 0.0, + "36": 0.0, + "37": 0.0099, + "38": 0.0, + "39": 0.0, + "40": 0.0, + "41": 0.03221, + "42": 0.0, + "43": 0.0, + "44": 0.0, + "45": 0.0, + "46": 0.0, + "47": 0.0, + "48": 0.0, + "49": 0.0, + "50": 0.0 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 583687296.0, + "2": 0.0, + "3": 0.0, + "4": 49.0, + "5": 583687296.0, + "6": 583687296.0, + "7": 583687296.0, + "8": 12.0, + "9": 583687296.0, + "10": 583687296.0, + "11": 583687296.0, + "12": 583687296.0, + "13": 583687296.0, + "14": 6.0, + "15": 583687296.0, + "16": 62.0, + "17": 583687296.0, + "18": 583687296.0, + "19": 583687296.0, + "20": 23.0, + "21": 583687296.0, + "22": 583687296.0, + "23": 583687296.0, + "24": 583687296.0, + "25": 583687296.0, + "26": 583687296.0, + "27": 583687296.0, + "28": 583687296.0, + "29": 583687296.0, + "30": 583687296.0, + "31": 583687296.0, + "32": 583687296.0, + "33": 583687296.0, + "34": 583687296.0, + "35": 583687296.0, + "36": 583687296.0, + "37": 37.0, + "38": 583687296.0, + "39": 583687296.0, + "40": 583687296.0, + "41": 53.0, + "42": 583687296.0, + "43": 583687296.0, + "44": 583687296.0, + "45": 583687296.0, + "46": 583687296.0, + "47": 583687296.0, + "48": 583687296.0, + "49": 583687296.0, + "50": 583687296.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 55320928256.0, + "2": 55319695360.0, + "3": 55319674880.0, + "4": 55319638016.0, + "5": 55319638016.0, + "6": 55319638016.0, + "7": 55319633920.0, + "8": 55319625728.0, + "9": 55319621632.0, + "10": 55319625728.0, + "11": 55319625728.0, + "12": 55319629824.0, + "13": 55319547904.0, + "14": 55319552000.0, + "15": 55319552000.0, + "16": 55319552000.0, + "17": 55319552000.0, + "18": 55319552000.0, + "19": 55319556096.0, + "20": 55319556096.0, + "21": 55319556096.0, + "22": 55319556096.0, + "23": 55319556096.0, + "24": 55319560192.0, + "25": 55319560192.0, + "26": 55319560192.0, + "27": 55319560192.0, + "28": 55319552000.0, + "29": 55319552000.0, + "30": 55319552000.0, + "31": 55319552000.0, + "32": 55319552000.0, + "33": 55319552000.0, + "34": 55319556096.0, + "35": 55319556096.0, + "36": 55319556096.0, + "37": 55319560192.0, + "38": 55319560192.0, + "39": 55319560192.0, + "40": 55319556096.0, + "41": 55319552000.0, + "42": 55319552000.0, + "43": 55319552000.0, + "44": 55319552000.0, + "45": 55319552000.0, + "46": 55319552000.0, + "47": 55319556096.0, + "48": 55319556096.0, + "49": 55319556096.0, + "50": 55319552000.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 64753942528.0, + "2": 69804253184.0, + "3": 69804253184.0, + "4": 69804253184.0, + "5": 69804253184.0, + "6": 69804253184.0, + "7": 69804253184.0, + "8": 69804253184.0, + "9": 69804253184.0, + "10": 69804253184.0, + "11": 69804253184.0, + "12": 69804253184.0, + "13": 69804253184.0, + "14": 69804253184.0, + "15": 69804253184.0, + "16": 69804253184.0, + "17": 69804253184.0, + "18": 69804253184.0, + "19": 69804253184.0, + "20": 69804253184.0, + "21": 69804253184.0, + "22": 69804253184.0, + "23": 69804253184.0, + "24": 69804253184.0, + "25": 69804253184.0, + "26": 69804253184.0, + "27": 69804253184.0, + "28": 69804253184.0, + "29": 69804253184.0, + "30": 69804253184.0, + "31": 69804253184.0, + "32": 69804253184.0, + "33": 69804253184.0, + "34": 69804253184.0, + "35": 69804253184.0, + "36": 69804253184.0, + "37": 69804253184.0, + "38": 69804253184.0, + "39": 69804253184.0, + "40": 69804253184.0, + "41": 69804253184.0, + "42": 69804253184.0, + "43": 69804253184.0, + "44": 69804253184.0, + "45": 69804253184.0, + "46": 69804253184.0, + "47": 69804253184.0, + "48": 69804253184.0, + "49": 69804253184.0, + "50": 69804253184.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 74.35665, + "2": 5.25731, + "3": 5.75582, + "4": 4.02061, + "5": 3.8529, + "6": 3.91732, + "7": 4.14616, + "8": 3.83737, + "9": 3.75158, + "10": 3.91902, + "11": 3.96073, + "12": 3.83611, + "13": 3.86989, + "14": 3.88658, + "15": 4.46432, + "16": 3.90389, + "17": 3.8143, + "18": 3.86593, + "19": 3.78307, + "20": 3.90922, + "21": 3.82247, + "22": 3.76037, + "23": 4.00863, + "24": 3.74678, + "25": 3.86492, + "26": 3.83492, + "27": 3.86387, + "28": 3.99894, + "29": 3.85812, + "30": 4.34066, + "31": 3.88411, + "32": 3.80617, + "33": 3.90347, + "34": 3.7771, + "35": 3.84701, + "36": 3.81111, + "37": 3.75554, + "38": 3.99552, + "39": 3.87227, + "40": 3.81079, + "41": 3.83039, + "42": 3.74567, + "43": 3.82531, + "44": 3.78258, + "45": 3.73294, + "46": 4.579, + "47": 3.72516, + "48": 3.8117, + "49": 3.80651, + "50": 3.78283 + } + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml new file mode 100644 index 00000000000..7a1db6d3427 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml @@ -0,0 +1,79 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 +TEST_TYPE: frozen-start +MODE: rl +MODEL_ARGS: + --tiktoken-pattern: v2 + --use-mcore-models: true + --tokenizer-type: TikTokenizer + --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json + --load: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/ + --auto-detect-ckpt-format: true + --max-tokens-to-oom: 3600000 + --inference-max-seq-length: 4096 + --attention-backend: flash + --mock-data: true + --micro-batch-size: 1 + --no-load-optim: true + --no-use-tokenizer-model-from-checkpoint-args: true + --timing-log-level: 0 + --distributed-backend: nccl + --log-interval: 1 + --log-progress: true + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --ckpt-format: torch_dist + --bf16: true + --log-memory-to-tensorboard: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --num-layers: 24 + --hidden-size: 1152 + --num-attention-heads: 16 + --max-position-embeddings: 1024 + --seq-length: 1024 + --timing-log-option: minmax + --log-throughput: true + --no-create-attention-mask-in-dataloader: true + --straggler-minmax-count: 16 + --tensorboard-log-interval: 1 + --empty-unused-memory-level: 2 + --langrl-inference-server-type: inplace_megatron + --seed: 42 + --calculate-per-token-loss: true + --rl-use-sequence-packing: true + --rl-sequence-packing-bin-size: 8192 + --rl-sequence-packing-algo: fifo + --rl-offload-optimizer-during-inference: true + --timing-log-level: 1 + --log-timers-to-tensorboard: true + --cuda-graph-impl: local + --micro-batch-size: 1 + --global-batch-size: 16 + --grpo-group-size: 2 + --grpo-prompts-per-step: 8 + --grpo-iterations: 1 + --grpo-clamp-eps-lower: 0.2 + --grpo-clamp-eps-upper: 0.2 + --grpo-kl-beta: 0.0 + --grpo-entropy-term-weight: 0.0 + --langrl-env-config: examples/rl/environment_configs/countdown.yaml + --rl-partial-rollouts: true + --lr: 0.000001 + --lr-warmup-samples: 0 + --clip-grad: 1.0 + --use-checkpoint-args: true + --dist-ckpt-strictness: log_unexpected + --perform-rl-step: true + --train-samples: 48828125 + --exit-interval: 50 + --tensorboard-dir: ${TENSORBOARD_PATH} + --save-interval: 1000000 + --eval-interval: 1000000 + --finetune: true + --rl-inference-tensor-model-parallel-size: 2 \ No newline at end of file diff --git a/tests/test_utils/recipes/gpt-grpo.yaml b/tests/test_utils/recipes/gpt-grpo.yaml index 7849128de04..049618ce18e 100644 --- a/tests/test_utils/recipes/gpt-grpo.yaml +++ b/tests/test_utils/recipes/gpt-grpo.yaml @@ -60,3 +60,8 @@ products: - environment: [dev] scope: [mr, mr-github] platforms: [dgx_h100] + - test_case: [gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] From 2190b222535877c9b9be596b30c2dda27a3e6205 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Dec 2025 17:44:56 -0800 Subject: [PATCH 12/44] refactor --- .../data_parallel_inference_coordinator.py | 2 +- .../core/inference/engines/dynamic_engine.py | 5 +- megatron/core/process_groups_config.py | 40 +++++-- megatron/core/resharding/__init__.py | 11 +- .../core/resharding/copy_services/__init__.py | 2 - .../core/resharding/copy_services/base.py | 8 +- .../copy_services/gloo_copy_service.py | 7 +- .../copy_services/nccl_copy_service.py | 13 ++- megatron/core/resharding/execution.py | 9 +- megatron/core/resharding/planner.py | 24 ++--- megatron/core/resharding/refit.py | 26 ++--- megatron/core/resharding/utils.py | 29 +++-- megatron/core/transformer/cuda_graphs.py | 2 +- .../inference/test_nccl_model_swap.py | 102 +++++++++++++----- 14 files changed, 173 insertions(+), 107 deletions(-) diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index 9364374cdcd..4ad36015e50 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -25,7 +25,7 @@ except: HAVE_MSGPACK = False -#TODO We need to see where the process group collection is used. +# TODO We need to see where the process group collection is used. # Register faulthandler to emit stack traces upon process kill. faulthandler.enable() faulthandler.register(signal.SIGTERM, all_threads=False, chain=True) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index efbabede0d6..5bc055d4a97 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -40,9 +40,10 @@ TextGenerationController, ) from megatron.core.inference.utils import Counter, await_process_event +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.utils import get_asyncio_loop, internal_api, trace_async_exceptions -from megatron.core.process_groups_config import ProcessGroupCollection + try: from tqdm import tqdm @@ -401,7 +402,7 @@ async def start_listening_to_data_parallel_coordinator( if launch_inference_coordinator and self.is_dp_coordinator: spawn_context = multiprocessing.get_context('spawn') coordinator_ready_event = spawn_context.Event() - #TODO(Peter) We need to pass the correct data parallel world size here + # TODO(Peter) We need to pass the correct data parallel world size here self.inference_coordinator_process = spawn_context.Process( target=DataParallelInferenceCoordinator.entrypoint, args=( diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index 74f078de431..f704e814970 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -265,7 +265,9 @@ def get_pipeline_model_parallel_rank(self) -> int: pp_ranks = torch.distributed.get_process_group_ranks(self.pp) return pp_ranks.index(global_rank) - def is_pipeline_first_stage(self, ignore_virtual: bool = True, vp_stage: Optional[int] = None) -> bool: + def is_pipeline_first_stage( + self, ignore_virtual: bool = True, vp_stage: Optional[int] = None + ) -> bool: """Return True if this rank is on the first PP stage. By default, ignores virtual pipeline (matches legacy interface). If you need VP-aware @@ -273,14 +275,16 @@ def is_pipeline_first_stage(self, ignore_virtual: bool = True, vp_stage: Optiona """ pp_ranks = torch.distributed.get_process_group_ranks(self.pp) global_rank = torch.distributed.get_rank() - is_pp_first = (pp_ranks[0] == global_rank) + is_pp_first = pp_ranks[0] == global_rank vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() if ignore_virtual or vp_size in (None, 1) or vp_stage is None: return is_pp_first - is_vp_first = (vp_stage == 0) + is_vp_first = vp_stage == 0 return is_vp_first and is_pp_first - def is_pipeline_last_stage(self, ignore_virtual: bool = True, vp_stage: Optional[int] = None) -> bool: + def is_pipeline_last_stage( + self, ignore_virtual: bool = True, vp_stage: Optional[int] = None + ) -> bool: """Return True if this rank is on the last PP stage. By default, ignores virtual pipeline (matches legacy interface). If you need VP-aware @@ -288,11 +292,11 @@ def is_pipeline_last_stage(self, ignore_virtual: bool = True, vp_stage: Optional """ pp_ranks = torch.distributed.get_process_group_ranks(self.pp) global_rank = torch.distributed.get_rank() - is_pp_last = (pp_ranks[-1] == global_rank) + is_pp_last = pp_ranks[-1] == global_rank vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() if ignore_virtual or vp_size in (None, 1) or vp_stage is None: return is_pp_last - is_vp_last = (vp_stage == (vp_size - 1)) + is_vp_last = vp_stage == (vp_size - 1) return is_vp_last and is_pp_last def get_data_parallel_rank(self) -> int: @@ -309,7 +313,9 @@ def get_context_parallel_rank(self) -> int: cp_ranks = torch.distributed.get_process_group_ranks(self.cp) return cp_ranks.index(global_rank) - def get_data_parallel_group(self, with_context_parallel: bool = False, partial_data_parallel: bool = False): + def get_data_parallel_group( + self, with_context_parallel: bool = False, partial_data_parallel: bool = False + ): """Return the DP/DP+CP process group, optionally partial.""" if with_context_parallel: if partial_data_parallel: @@ -326,41 +332,51 @@ def get_data_parallel_group(self, with_context_parallel: bool = False, partial_d return self.dp def get_tensor_model_parallel_group(self): + """Return the tensor model parallel (TP) process group.""" return self.tp def get_pipeline_model_parallel_group(self): + """Return the pipeline model parallel (PP) process group.""" return self.pp def get_model_parallel_group(self): + """Return the combined model parallel (MP) process group.""" return self.mp - + def get_model_parallel_world_size(self) -> int: """Return MP world size using the MP process group.""" return torch.distributed.get_world_size(self.mp) - + def get_model_parallel_src_rank(self) -> int: """Return the source (leader) global rank for the MP group.""" ranks = torch.distributed.get_process_group_ranks(self.mp) return ranks[0] def get_context_parallel_group(self): + """Return the context parallel (CP) process group, if configured.""" return getattr(self, 'cp', None) def get_expert_model_parallel_group(self): + """Return the expert model parallel (EP) process group, if configured.""" return getattr(self, 'ep', None) def get_expert_data_parallel_group(self, partial_expert_data_parallel: bool = False): + """Return the expert data parallel (ExDP) process group.""" if partial_expert_data_parallel: return getattr(self, 'intra_expt_dp', None) return getattr(self, 'expt_dp', None) def get_inter_distributed_optimizer_instance_group(self): + """Return the inter-node distributed optimizer instance process group, if configured.""" return getattr(self, 'inter_dist_opt', None) def get_intra_distributed_optimizer_instance_group(self): + """Return the intra-node distributed optimizer instance process group, if configured.""" return getattr(self, 'intra_dist_opt', None) - def get_data_parallel_src_rank(self, with_context_parallel: bool = False, partial_data_parallel: bool = False) -> int: + def get_data_parallel_src_rank( + self, with_context_parallel: bool = False, partial_data_parallel: bool = False + ) -> int: """Return the source (leader) global rank for the selected DP group.""" group = self.get_data_parallel_group( with_context_parallel=with_context_parallel, partial_data_parallel=partial_data_parallel @@ -368,7 +384,9 @@ def get_data_parallel_src_rank(self, with_context_parallel: bool = False, partia ranks = torch.distributed.get_process_group_ranks(group) return ranks[0] - def get_data_parallel_world_size(self, with_context_parallel: bool = False, partial_data_parallel: bool = False) -> int: + def get_data_parallel_world_size( + self, with_context_parallel: bool = False, partial_data_parallel: bool = False + ) -> int: """Return world size of the selected DP group.""" group = self.get_data_parallel_group( with_context_parallel=with_context_parallel, partial_data_parallel=partial_data_parallel diff --git a/megatron/core/resharding/__init__.py b/megatron/core/resharding/__init__.py index f84b0665696..ef530c8bd10 100644 --- a/megatron/core/resharding/__init__.py +++ b/megatron/core/resharding/__init__.py @@ -1,12 +1,7 @@ -from .planner import build_centralized_reshard_plan from .execution import execute_reshard_plan -from .refit import swap_model_weights, reshard_model_weights -from .utils import ( - ParameterMetadata, - ShardingDescriptor, - TransferOp, - ReshardPlan, -) +from .planner import build_centralized_reshard_plan +from .refit import reshard_model_weights, swap_model_weights +from .utils import ParameterMetadata, ReshardPlan, ShardingDescriptor, TransferOp __all__ = [ "build_centralized_reshard_plan", diff --git a/megatron/core/resharding/copy_services/__init__.py b/megatron/core/resharding/copy_services/__init__.py index eb7133c64b0..91ba6e4d267 100644 --- a/megatron/core/resharding/copy_services/__init__.py +++ b/megatron/core/resharding/copy_services/__init__.py @@ -4,5 +4,3 @@ from .nccl_copy_service import NCCLCopyService __all__ = ["CopyService", "NCCLCopyService"] - - diff --git a/megatron/core/resharding/copy_services/base.py b/megatron/core/resharding/copy_services/base.py index cab7dc71655..13d2a348985 100644 --- a/megatron/core/resharding/copy_services/base.py +++ b/megatron/core/resharding/copy_services/base.py @@ -1,20 +1,24 @@ from __future__ import annotations from abc import ABC, abstractmethod + import torch class CopyService(ABC): + """Abstract interface for submitting and executing batched P2P copy operations.""" + @abstractmethod def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): + """Register a tensor send from the current rank to ``dest_rank``.""" ... @abstractmethod def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + """Register a tensor receive into ``dest_tensor`` from ``src_rank``.""" ... @abstractmethod def run(self): + """Execute all previously submitted send/recv operations as a single batch.""" ... - - diff --git a/megatron/core/resharding/copy_services/gloo_copy_service.py b/megatron/core/resharding/copy_services/gloo_copy_service.py index 42b72be1954..4ebc5bf22b5 100644 --- a/megatron/core/resharding/copy_services/gloo_copy_service.py +++ b/megatron/core/resharding/copy_services/gloo_copy_service.py @@ -9,18 +9,21 @@ from .base import CopyService - logger = logging.getLogger(__name__) @dataclass class SendOp: + """Simple container describing a single send operation.""" + tensor: torch.Tensor dest_rank: int @dataclass class RecvOp: + """Simple container describing a single receive operation.""" + tensor: torch.Tensor src_rank: int @@ -80,5 +83,3 @@ def run(self): logger.info("GlooCopyService: batched communication completed") self.send_ops.clear() self.recv_ops.clear() - - diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index 687f967128f..3a50456176e 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -14,12 +14,16 @@ @dataclass class SendOp: + """Simple container describing a single NCCL send operation.""" + tensor: torch.Tensor dest_rank: int @dataclass class RecvOp: + """Simple container describing a single NCCL receive operation.""" + tensor: torch.Tensor src_rank: int @@ -45,7 +49,12 @@ def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): def run(self): total_ops = len(self.send_ops) + len(self.recv_ops) - logger.info(f"Executing batched communication: {len(self.send_ops)} sends + {len(self.recv_ops)} recvs = {total_ops} ops") + logger.info( + "Executing batched communication: %d sends + %d recvs = %d ops", + len(self.send_ops), + len(self.recv_ops), + total_ops, + ) p2p_ops = [] for op in self.send_ops: @@ -61,5 +70,3 @@ def run(self): logger.info("Batched communication completed") self.send_ops.clear() self.recv_ops.clear() - - diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index a710210308d..2850d35c49a 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -6,9 +6,8 @@ import torch import torch.distributed as dist -from .utils import ReshardPlan from .copy_services.base import CopyService - +from .utils import ReshardPlan logger = logging.getLogger(__name__) @@ -28,7 +27,7 @@ def execute_reshard_plan( src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} - #TODO(Peter) do this on like a separate stream? + # TODO(Peter) do this on like a separate stream? # Execute local copies for param_name, src_param, dst_param, src_slice, dst_slice in plan.local_copy_ops: if src_param is None: @@ -61,7 +60,7 @@ def execute_reshard_plan( # Execute logger.info(f"Executing {len(plan.send_ops)} sends + {len(plan.recv_ops)} recvs") service.run() - #TODO(Peter) remove this eventually? + # TODO(Peter) remove this eventually? dist.barrier() # Write back received buffers into their destination parameter slices @@ -70,5 +69,3 @@ def execute_reshard_plan( dst_param.data[dst_slice].copy_(recv_buffer) logger.info("Reshard complete") - - diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index f40722ea680..a5ac87b81b5 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -2,28 +2,25 @@ import logging import math -from typing import Optional import torch import torch.distributed as dist from .utils import ( ParameterMetadata, + ReshardPlan, ShardingDescriptor, TransferOp, - ReshardPlan, _get_rank_in_group, extract_param_metadata, select_src_metadata_balanced, ) - logger = logging.getLogger(__name__) def _build_descriptors_for_param( - src_metadata: ParameterMetadata, - dst_metadata: ParameterMetadata, + src_metadata: ParameterMetadata, dst_metadata: ParameterMetadata ) -> list[ShardingDescriptor]: """Construct sharding descriptors (currently TP) for this parameter based on actual layout. Guard TP descriptor with size conservation so we don't mis-classify replicated tensors. @@ -83,7 +80,9 @@ def _plan_multi_dim_lcm( if not descriptors: return [] if len(descriptors) != 1: - raise NotImplementedError(f"{param_name}: _plan_multi_dim_lcm supports TP-only (one descriptor)") + raise NotImplementedError( + f"{param_name}: _plan_multi_dim_lcm supports TP-only (one descriptor)" + ) if descriptors[0].name != "tp": raise NotImplementedError(f"{param_name}: _plan_multi_dim_lcm expects TP descriptor") d = descriptors[0] @@ -99,7 +98,8 @@ def _plan_multi_dim_lcm( if src_world * src_local != dst_world * dst_local: raise RuntimeError( f"{param_name}: size mismatch on TP dim{dim} " - f"(src_world={src_world}, src_local={src_local}, dst_world={dst_world}, dst_local={dst_local})" + f"(src_world={src_world}, src_local={src_local}, " + f"dst_world={dst_world}, dst_local={dst_local})" ) # LCM tiling with strides Ns = src_world * max(1, d.src_stride) @@ -138,6 +138,7 @@ def _plan_multi_dim_lcm( src_slice[dim] = slice(src_start, src_start + unit) dst_slice[dim] = slice(dst_start, dst_start + unit) ops.append((src_global_rank, tuple(src_slice), tuple(dst_slice))) + # Stable order by destination offset def dst_key(op): _, _, dsl = op @@ -207,9 +208,7 @@ def _determine_source_ranks_for_dst_param( def build_centralized_reshard_plan( - src_module: torch.nn.Module, - dst_module: torch.nn.Module, - num_experts: int = None, + src_module: torch.nn.Module, dst_module: torch.nn.Module, num_experts: int = None ) -> ReshardPlan: """ Centralized planning: Rank 0 builds complete plan for all ranks, then scatters. @@ -266,7 +265,8 @@ def build_centralized_reshard_plan( src_meta_list = src_param_metadata.get(resolved_name) if not src_meta_list: raise RuntimeError( - f"Destination parameter '{resolved_name}' on rank {dst_rank} not found in source model." + f"Destination parameter '{resolved_name}' on rank {dst_rank} " + "not found in source model." ) # Choose a representative source metadata with DP round-robin balancing src_metadata = select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank) @@ -316,5 +316,3 @@ def build_centralized_reshard_plan( ) return my_plan - - diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 95c0203b5f8..8e917280c20 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -1,4 +1,5 @@ from __future__ import annotations + """ High-level refit/reshard orchestration: - swap_model_weights: public API; accepts a backend name or CopyService and delegates. @@ -7,28 +8,25 @@ from typing import Any, Optional, Union -from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core import parallel_state +from megatron.core.models.common.language_module.language_module import LanguageModule + from . import build_centralized_reshard_plan, execute_reshard_plan from .copy_services.base import CopyService -from .copy_services.nccl_copy_service import NCCLCopyService from .copy_services.gloo_copy_service import GlooCopyService +from .copy_services.nccl_copy_service import NCCLCopyService def _unwrap_module(module: LanguageModule) -> Any: return ( module.module.module if hasattr(module, 'module') and hasattr(module.module, 'module') - else module.module - if hasattr(module, 'module') - else module + else module.module if hasattr(module, 'module') else module ) def swap_model_weights( - src_model: LanguageModule, - target_model: LanguageModule, - refit_method: Union[str, CopyService], + src_model: LanguageModule, target_model: LanguageModule, refit_method: Union[str, CopyService] ): """ Orchestrate weight swap/refit. @@ -49,14 +47,12 @@ def swap_model_weights( raise ValueError(f"Unknown refit_method '{refit_method}'") else: raise TypeError("refit_method must be a str backend name or a CopyService instance") - def reshard_model_weights( - src_model: LanguageModule, - target_model: LanguageModule, - service: CopyService, + src_model: LanguageModule, target_model: LanguageModule, service: CopyService ): + """Reshard and copy model weights from ``src_model`` to ``target_model`` using ``service``.""" # Handle list-wrapped modules used throughout training utils src_lm = src_model[0] if isinstance(src_model, (list, tuple)) else src_model tgt_lm = target_model[0] if isinstance(target_model, (list, tuple)) else target_model @@ -73,7 +69,7 @@ def reshard_model_weights( if not hasattr(tgt_core, "pg_collection") or tgt_core.pg_collection is None: raise RuntimeError("Target model missing pg_collection required for NCCL reshard") - #TODO(Peter): We should figure out why this happens. Seems like a bug in Orthotope. + # TODO(Peter): We should figure out why this happens. Seems like a bug in Orthotope. # Fill missing DP group on the source using Megatron's parallel state if not provided if getattr(src_core.pg_collection, "dp", None) is None: src_core.pg_collection.dp = parallel_state.get_data_parallel_group() @@ -86,7 +82,5 @@ def reshard_model_weights( setattr(tgt_core, "_cached_reshard_plan", plan) else: plan = cached_plan - - execute_reshard_plan(plan, src_core, tgt_core, service=service) - + execute_reshard_plan(plan, src_core, tgt_core, service=service) diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index fd4d070e18f..c8f95186998 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist - # ----------------------------------------------------------------------------- # Dataclasses used by the planner # ----------------------------------------------------------------------------- @@ -15,6 +14,8 @@ @dataclass class TransferOp: + """Single logical send/recv operation used in a reshard plan.""" + param_name: str peer_rank: int # Who to send to / receive from is_send: bool # True=send, False=recv @@ -74,7 +75,13 @@ class ReshardPlan: send_ops: list[TransferOp] recv_ops: list[TransferOp] local_copy_ops: list[ - tuple[str, torch.nn.Parameter | None, torch.nn.Parameter | None, tuple[slice, ...], tuple[slice, ...]] + tuple[ + str, + torch.nn.Parameter | None, + torch.nn.Parameter | None, + tuple[slice, ...], + tuple[slice, ...], + ] ] # (name, src_param, dst_param, src_slice, dst_slice) def __str__(self): @@ -102,10 +109,14 @@ def _get_rank_in_group(global_rank: int, group_ranks: list[int]) -> int: def _detect_expert_index_from_param_name(param_name: str) -> Optional[int]: """Extract expert index from parameter name for TEGroupedMLP per-expert tensors.""" for part in param_name.split('.'): - if part.startswith('weight') and len(part) > len('weight') and part[len('weight'):].isdigit(): - return int(part[len('weight'):]) - if part.startswith('bias') and len(part) > len('bias') and part[len('bias'):].isdigit(): - return int(part[len('bias'):]) + if ( + part.startswith('weight') + and len(part) > len('weight') + and part[len('weight') :].isdigit() + ): + return int(part[len('weight') :]) + if part.startswith('bias') and len(part) > len('bias') and part[len('bias') :].isdigit(): + return int(part[len('bias') :]) return None @@ -134,9 +145,9 @@ def assign_resolved_name_inplace(meta: ParameterMetadata) -> None: parts = meta.name.split('.') new_parts = [] for p in parts: - if (p.startswith('weight') and len(p) > len('weight') and p[len('weight'):].isdigit()): + if p.startswith('weight') and len(p) > len('weight') and p[len('weight') :].isdigit(): new_parts.append('weight' + str(global_idx)) - elif (p.startswith('bias') and len(p) > len('bias') and p[len('bias'):].isdigit()): + elif p.startswith('bias') and len(p) > len('bias') and p[len('bias') :].isdigit(): new_parts.append('bias' + str(global_idx)) else: new_parts.append(p) @@ -227,5 +238,3 @@ def select_src_metadata_balanced( logger = logging.getLogger(__name__) - - diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 64b05d5e8f4..bd2ce9656c0 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -58,7 +58,7 @@ _IS_GRAPH_CAPTURING = False logger = logging.getLogger(__name__) -#TODO(Peter) We have changes needed in this for refit to work properly. +# TODO(Peter) We have changes needed in this for refit to work properly. # Freeze GC during capture. # TODO (@lmcafee): remove all freeze-GC code once most users are on PyTorch 2.9+. diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index bfd5d15bc27..b19f956615b 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -1,30 +1,31 @@ -import os import copy +import os import types +from typing import Optional, Tuple + import pytest import torch import torch.distributed as dist -import os -import pytest -from tests.unit_tests.test_utilities import Utils -from megatron.core.resharding.refit import swap_model_weights +from megatron.core import parallel_state as mpu from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, ) -from megatron.core import parallel_state as mpu -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.resharding.refit import swap_model_weights from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.model_parallel_config import ModelParallelConfig -from typing import Tuple, Optional +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils -def _build_pg_collection(tp_size: int, pp_size: int = None, ep_size: int = 1) -> ProcessGroupCollection: +def _build_pg_collection( + tp_size: int, pp_size: int = None, ep_size: int = 1 +) -> ProcessGroupCollection: cp_size = mpu.get_context_parallel_world_size() if pp_size is None: pp_size = mpu.get_pipeline_model_parallel_world_size() @@ -32,7 +33,9 @@ def _build_pg_collection(tp_size: int, pp_size: int = None, ep_size: int = 1) -> dp_size = world_size // (tp_size * cp_size * ep_size * pp_size) assert dp_size >= 1 and (tp_size * cp_size * ep_size * pp_size * dp_size) == world_size - grid = HyperCommGrid([tp_size, cp_size, ep_size, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + grid = HyperCommGrid( + [tp_size, cp_size, ep_size, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"] + ) tp_group = grid.create_pg("tp") cp_group = grid.create_pg("cp") pp_group = grid.create_pg("pp") @@ -47,7 +50,9 @@ def _build_pg_collection(tp_size: int, pp_size: int = None, ep_size: int = 1) -> tp_dp_cp_group = grid.create_pg(["tp", "cp", "dp"]) embd_group_ranks = mpu.default_embedding_ranks(dist.get_process_group_ranks(pp_group)) embd_group = dist.new_group(ranks=embd_group_ranks) - pos_embd_group_ranks = mpu.default_position_embedding_ranks(dist.get_process_group_ranks(pp_group)) + pos_embd_group_ranks = mpu.default_position_embedding_ranks( + dist.get_process_group_ranks(pp_group) + ) pos_embd_group = dist.new_group(ranks=pos_embd_group_ranks) return ProcessGroupCollection( tp=tp_group, @@ -79,8 +84,7 @@ def _build_gpt( model = GPTModel( config=config, transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, - moe_grouped_gemm=(num_moe_experts is not None), + num_experts=num_moe_experts, moe_grouped_gemm=(num_moe_experts is not None) ), vocab_size=vocab_size, max_sequence_length=seq_len, @@ -113,7 +117,7 @@ def _set_pg_collection(module, tp_group, dp_group): @pytest.mark.parametrize( "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [ - #TP only changes + # TP only changes (2, 1, 1, 1, 1, 1, None), # TP2 -> TP1 (1, 1, 1, 2, 1, 1, None), # TP1 -> TP2 # PP only changes @@ -124,22 +128,32 @@ def _set_pg_collection(module, tp_group, dp_group): (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 - (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 + (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 (1, 1, 2, 1, 1, 1, 4), (1, 1, 1, 1, 1, 2, 4), (1, 1, 2, 1, 2, 2, 4), ], ) def test_nccl_swap_gpt_parametrized( - src_tp: int, src_pp: int, src_ep: int, dst_tp: int, dst_pp: int, dst_ep: int, num_experts: Optional[int] + src_tp: int, + src_pp: int, + src_ep: int, + dst_tp: int, + dst_pp: int, + dst_ep: int, + num_experts: Optional[int], ): # Initialize environment with source MP sizing - Utils.initialize_model_parallel(tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=src_pp) + Utils.initialize_model_parallel( + tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=src_pp + ) # Validate divisibility post-init using the default PG safely world = dist.get_world_size() if (world % (src_tp * src_pp * src_ep) != 0) or (world % (dst_tp * dst_pp * dst_ep) != 0): Utils.destroy_model_parallel() - pytest.skip("WORLD_SIZE must be divisible by both src_tp*src_pp*src_ep and dst_tp*dst_pp*dst_ep") + pytest.skip( + "WORLD_SIZE must be divisible by both src_tp*src_pp*src_ep and dst_tp*dst_pp*dst_ep" + ) model_parallel_cuda_manual_seed(1234) torch.manual_seed(1234) @@ -157,7 +171,7 @@ def test_nccl_swap_gpt_parametrized( hidden_dropout=0.0, attention_dropout=0.0, moe_router_dtype="fp64", - moe_token_dispatcher_type="alltoall" + moe_token_dispatcher_type="alltoall", ) # Build PGs and models (always use unified PG builder so we can set EP) @@ -186,13 +200,39 @@ def test_nccl_swap_gpt_parametrized( Utils.destroy_model_parallel() pytest.skip("Transformer Engine not available; skipping TE-grouped MoE test") # Use parallel_output=False to gather TP logits inside model and emit only on last PP stage - src_model = _build_gpt(src_cfg, vocab_size, seq_len, src_pgs, parallel_output=False, num_moe_experts=num_experts).to(device).eval() - dst_model = _build_gpt(dst_cfg, vocab_size, seq_len, dst_pgs, parallel_output=False, num_moe_experts=num_experts).to(device).eval() + src_model = ( + _build_gpt( + src_cfg, + vocab_size, + seq_len, + src_pgs, + parallel_output=False, + num_moe_experts=num_experts, + ) + .to(device) + .eval() + ) + dst_model = ( + _build_gpt( + dst_cfg, + vocab_size, + seq_len, + dst_pgs, + parallel_output=False, + num_moe_experts=num_experts, + ) + .to(device) + .eval() + ) # Inputs batch = 2 - tokens = torch.randint(low=0, high=vocab_size, size=(batch, seq_len), device=device, dtype=torch.long) - position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1) + tokens = torch.randint( + low=0, high=vocab_size, size=(batch, seq_len), device=device, dtype=torch.long + ) + position_ids = ( + torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1) + ) attention_mask = torch.ones((batch, 1, seq_len, seq_len), device=device, dtype=torch.bool) # Collect source reference logits (parallel_output=False ensures full vocab on last PP stage) @@ -214,13 +254,17 @@ def test_nccl_swap_gpt_parametrized( dst_pp_ranks = dist.get_process_group_ranks(dst_pgs.pp) dst_last_pp_rank = dst_pp_ranks[-1] with torch.no_grad(): - dst_out = dst_model(tokens, position_ids, attention_mask) # last stage returns tensor, others return None + dst_out = dst_model( + tokens, position_ids, attention_mask + ) # last stage returns tensor, others return None if dist.get_rank() == dst_last_pp_rank: dst_logits.copy_(dst_out) # [b, s, vocab] dist.broadcast(dst_logits, src=dst_last_pp_rank, group=dst_pgs.pp) # Compare assert ref_logits.shape == dst_logits.shape - assert torch.allclose(dst_logits, ref_logits, atol=1e-4, rtol=1e-4), f"Refit src(TP={src_tp},PP={src_pp})->dst(TP={dst_tp},PP={dst_pp}) GPT outputs differ" + assert torch.allclose( + dst_logits, ref_logits, atol=1e-4, rtol=1e-4 + ), f"Refit src(TP={src_tp},PP={src_pp})->dst(TP={dst_tp},PP={dst_pp}) GPT outputs differ" dist.barrier() Utils.destroy_model_parallel() From 9641c38a6a2dde146109bf81ba33d03ede95383b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 3 Dec 2025 08:29:45 -0800 Subject: [PATCH 13/44] fix tests --- .../unit_tests/dist_checkpointing/test_optimizer.py | 5 +++++ tests/unit_tests/dist_checkpointing/utils.py | 10 ++++++++++ .../transformer/test_multi_latent_attention.py | 12 +++++++++--- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py index df501366f74..4c5bfd9b32e 100644 --- a/tests/unit_tests/dist_checkpointing/test_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py @@ -276,6 +276,11 @@ def initialize_real_model( virtual_pipeline_model_parallel_size=None, **config_kwargs, ): + # These kwargs are passed through training.get_model for model construction, + # but are not part of TransformerConfig; strip them before building config. + config_kwargs.pop("pg_collection", None) + config_kwargs.pop("config", None) + torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index e722ebe79ca..ee3b0d75164 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -24,6 +24,11 @@ def initialize_gpt_model( pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs ): + # These kwargs are passed through training.get_model for model construction, + # but are not part of TransformerConfig; strip them before building config. + config_kwargs.pop("pg_collection", None) + config_kwargs.pop("config", None) + torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) @@ -61,6 +66,11 @@ def initialize_moe_model( use_grouped_mlp=False, **config_kwargs, ): + # These kwargs are passed through training.get_model for model construction, + # but are not part of TransformerConfig; strip them before building config. + config_kwargs.pop("pg_collection", None) + config_kwargs.pop("config", None) + torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) expert_num = 8 diff --git a/tests/unit_tests/transformer/test_multi_latent_attention.py b/tests/unit_tests/transformer/test_multi_latent_attention.py index 8ade4b6bcb8..93976a24ac6 100644 --- a/tests/unit_tests/transformer/test_multi_latent_attention.py +++ b/tests/unit_tests/transformer/test_multi_latent_attention.py @@ -1082,7 +1082,13 @@ def test_parallel_multi_latent_attention_correctness( hidden_size = 128 # Model initialization function - def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=None): + def initialize_gpt_model( + pre_process=True, + post_process=True, + vp_stage=None, + pg_collection=None, + config=None, + ): layer_spec = get_gpt_layer_with_transformer_engine_spec(multi_latent_attention=True) gpt_model = GPTModel( config=config, @@ -1142,7 +1148,7 @@ def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=N mock_args.context_parallel_size = 1 mock_args.sequence_parallel = 1 gpt_model = unwrap_model( - get_model(partial(initialize_gpt_model, config=transformer_config)) + get_model(initialize_gpt_model, config=transformer_config) ) # Initialize args and save checkpoint @@ -1179,7 +1185,7 @@ def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=N mock_args.context_parallel_size = cp mock_args.sequence_parallel = sp gpt_model = unwrap_model( - get_model(partial(initialize_gpt_model, config=transformer_config)) + get_model(initialize_gpt_model, config=transformer_config) ) with mock.patch('megatron.training.checkpointing.check_checkpoint_args'): with mock.patch('megatron.training.checkpointing.update_num_microbatches'): From d5d4c47ee225db341849de0b4469c2fd1735c14d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 3 Dec 2025 20:36:38 -0800 Subject: [PATCH 14/44] check changes --- .../data_parallel_inference_coordinator.py | 1 - .../core/inference/engines/dynamic_engine.py | 1 - .../copy_services/nccl_copy_service.py | 26 +++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index 4ad36015e50..e1fe7b21566 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -25,7 +25,6 @@ except: HAVE_MSGPACK = False -# TODO We need to see where the process group collection is used. # Register faulthandler to emit stack traces upon process kill. faulthandler.enable() faulthandler.register(signal.SIGTERM, all_threads=False, chain=True) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 5bc055d4a97..a37cad4ab31 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -402,7 +402,6 @@ async def start_listening_to_data_parallel_coordinator( if launch_inference_coordinator and self.is_dp_coordinator: spawn_context = multiprocessing.get_context('spawn') coordinator_ready_event = spawn_context.Event() - # TODO(Peter) We need to pass the correct data parallel world size here self.inference_coordinator_process = spawn_context.Process( target=DataParallelInferenceCoordinator.entrypoint, args=( diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index 3a50456176e..29effff9e61 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from dataclasses import dataclass from typing import List @@ -56,6 +57,31 @@ def run(self): total_ops, ) + # Dump only lightweight tensor metadata for repro: shapes, dtypes, and ranks. + dump_filename = os.path.join(os.getcwd(), f"nccl_copy_service_rank{self.rank}.pt") + payload = { + "rank": self.rank, + "world_size": self.world_size, + "send_ops": [ + { + "shape": tuple(op.tensor.size()), + "dtype": str(op.tensor.dtype).replace("torch.", ""), + "dest_rank": op.dest_rank, + } + for op in self.send_ops + ], + "recv_ops": [ + { + "shape": tuple(op.tensor.size()), + "dtype": str(op.tensor.dtype).replace("torch.", ""), + "src_rank": op.src_rank, + } + for op in self.recv_ops + ], + } + torch.save(payload, dump_filename) + logger.info(f"NCCLCopyService dumped tensor metadata to {dump_filename}") + p2p_ops = [] for op in self.send_ops: p2p_ops.append(dist.P2POp(dist.isend, op.tensor, op.dest_rank)) From 53298e76c9ae3cc4b7112b97404af711c8234dd0 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 3 Dec 2025 20:37:05 -0800 Subject: [PATCH 15/44] check changes for hao --- .../dist_checkpointing/test_optimizer.py | 4 +++- .../transformer/test_multi_latent_attention.py | 14 +++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py index 4c5bfd9b32e..a5da1cd47f8 100644 --- a/tests/unit_tests/dist_checkpointing/test_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py @@ -18,7 +18,9 @@ ) from megatron.core.dist_checkpointing.utils import add_prefix_for_sharding, extract_sharded_tensors from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_decoder_block_spec, +) from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, ) diff --git a/tests/unit_tests/transformer/test_multi_latent_attention.py b/tests/unit_tests/transformer/test_multi_latent_attention.py index 93976a24ac6..1d00c704d26 100644 --- a/tests/unit_tests/transformer/test_multi_latent_attention.py +++ b/tests/unit_tests/transformer/test_multi_latent_attention.py @@ -1083,11 +1083,7 @@ def test_parallel_multi_latent_attention_correctness( # Model initialization function def initialize_gpt_model( - pre_process=True, - post_process=True, - vp_stage=None, - pg_collection=None, - config=None, + pre_process=True, post_process=True, vp_stage=None, pg_collection=None, config=None ): layer_spec = get_gpt_layer_with_transformer_engine_spec(multi_latent_attention=True) gpt_model = GPTModel( @@ -1147,9 +1143,7 @@ def initialize_gpt_model( init_basic_mock_args(mock_args, 1, 1, bf16=True) mock_args.context_parallel_size = 1 mock_args.sequence_parallel = 1 - gpt_model = unwrap_model( - get_model(initialize_gpt_model, config=transformer_config) - ) + gpt_model = unwrap_model(get_model(initialize_gpt_model, config=transformer_config)) # Initialize args and save checkpoint init_checkpointing_mock_args(mock_args, ckpt_dir, False) @@ -1184,9 +1178,7 @@ def initialize_gpt_model( init_basic_mock_args(mock_args, tp, 1, bf16=True) mock_args.context_parallel_size = cp mock_args.sequence_parallel = sp - gpt_model = unwrap_model( - get_model(initialize_gpt_model, config=transformer_config) - ) + gpt_model = unwrap_model(get_model(initialize_gpt_model, config=transformer_config)) with mock.patch('megatron.training.checkpointing.check_checkpoint_args'): with mock.patch('megatron.training.checkpointing.update_num_microbatches'): load_checkpoint(gpt_model, None, None) From c72ec6bd10e01fe287d0ef7039c0df3a176bf132 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Thu, 4 Dec 2025 12:01:47 -0800 Subject: [PATCH 16/44] cleanup logging --- .../copy_services/nccl_copy_service.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index 29effff9e61..3a50456176e 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os from dataclasses import dataclass from typing import List @@ -57,31 +56,6 @@ def run(self): total_ops, ) - # Dump only lightweight tensor metadata for repro: shapes, dtypes, and ranks. - dump_filename = os.path.join(os.getcwd(), f"nccl_copy_service_rank{self.rank}.pt") - payload = { - "rank": self.rank, - "world_size": self.world_size, - "send_ops": [ - { - "shape": tuple(op.tensor.size()), - "dtype": str(op.tensor.dtype).replace("torch.", ""), - "dest_rank": op.dest_rank, - } - for op in self.send_ops - ], - "recv_ops": [ - { - "shape": tuple(op.tensor.size()), - "dtype": str(op.tensor.dtype).replace("torch.", ""), - "src_rank": op.src_rank, - } - for op in self.recv_ops - ], - } - torch.save(payload, dump_filename) - logger.info(f"NCCLCopyService dumped tensor metadata to {dump_filename}") - p2p_ops = [] for op in self.send_ops: p2p_ops.append(dist.P2POp(dist.isend, op.tensor, op.dest_rank)) From 940615bf0162f0b0125f2a426c23204b4379030a Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Dec 2025 12:20:20 -0800 Subject: [PATCH 17/44] clean up --- megatron/core/transformer/cuda_graphs.py | 1 - megatron/training/training.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index bd2ce9656c0..7b81eb723ed 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -58,7 +58,6 @@ _IS_GRAPH_CAPTURING = False logger = logging.getLogger(__name__) -# TODO(Peter) We have changes needed in this for refit to work properly. # Freeze GC during capture. # TODO (@lmcafee): remove all freeze-GC code once most users are on PyTorch 2.9+. diff --git a/megatron/training/training.py b/megatron/training/training.py index b1ba4604e33..cdef4b2e94b 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -702,14 +702,15 @@ def pretrain( print_rank_0(f"Setting tensor model parallel size to {args.rl_inference_tensor_model_parallel_size} for inference model") # Build custom process groups for inference with a different TP size, keeping CP and PP the same as training tp_size = args.rl_inference_tensor_model_parallel_size + #TODO(peter): Get these from args when we want to support other parallelism changes cp_size = mpu.get_context_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() + ep_size = mpu.get_expert_model_parallel_world_size() dp_size = args.world_size // (tp_size * cp_size * pp_size) assert dp_size >= 1 and (tp_size * cp_size * pp_size * dp_size) == args.world_size, \ "World size must be divisible by tp*cp*pp for inference PG layout" - # TODO(Peter) We need to pass the expert parallel correctly here - grid = HyperCommGrid([tp_size, cp_size, 1, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + grid = HyperCommGrid([tp_size, cp_size, ep_size, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) tp_group = grid.create_pg("tp") cp_group = grid.create_pg("cp") pp_group = grid.create_pg("pp") From 12dc7ae19c7aa512a629046e3d3ae88055e5e5d0 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Dec 2025 12:28:07 -0800 Subject: [PATCH 18/44] add copyyright --- megatron/core/resharding/__init__.py | 1 + megatron/core/resharding/copy_services/__init__.py | 1 + megatron/core/resharding/copy_services/base.py | 1 + megatron/core/resharding/copy_services/gloo_copy_service.py | 1 + megatron/core/resharding/copy_services/nccl_copy_service.py | 1 + megatron/core/resharding/execution.py | 1 + megatron/core/resharding/planner.py | 1 + megatron/core/resharding/refit.py | 1 + megatron/core/resharding/utils.py | 1 + tests/unit_tests/inference/test_nccl_model_swap.py | 1 + 10 files changed, 10 insertions(+) diff --git a/megatron/core/resharding/__init__.py b/megatron/core/resharding/__init__.py index ef530c8bd10..d06484eef37 100644 --- a/megatron/core/resharding/__init__.py +++ b/megatron/core/resharding/__init__.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from .execution import execute_reshard_plan from .planner import build_centralized_reshard_plan from .refit import reshard_model_weights, swap_model_weights diff --git a/megatron/core/resharding/copy_services/__init__.py b/megatron/core/resharding/copy_services/__init__.py index 91ba6e4d267..15986e4d28e 100644 --- a/megatron/core/resharding/copy_services/__init__.py +++ b/megatron/core/resharding/copy_services/__init__.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations from .base import CopyService diff --git a/megatron/core/resharding/copy_services/base.py b/megatron/core/resharding/copy_services/base.py index 13d2a348985..d7b9205ba83 100644 --- a/megatron/core/resharding/copy_services/base.py +++ b/megatron/core/resharding/copy_services/base.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations from abc import ABC, abstractmethod diff --git a/megatron/core/resharding/copy_services/gloo_copy_service.py b/megatron/core/resharding/copy_services/gloo_copy_service.py index 4ebc5bf22b5..af70c33d5bd 100644 --- a/megatron/core/resharding/copy_services/gloo_copy_service.py +++ b/megatron/core/resharding/copy_services/gloo_copy_service.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations import logging diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index 3a50456176e..fe02d108550 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations import logging diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index 2850d35c49a..d1b9962facf 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations import logging diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index a5ac87b81b5..2deb5f0ec6d 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations import logging diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 8e917280c20..15e5ba758b1 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations """ diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index c8f95186998..b2b622860ee 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations import logging diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index b19f956615b..695c44f70b0 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import copy import os import types From 8a13950288673f8737447c986ccd113f3c0b4617 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Dec 2025 18:27:25 -0800 Subject: [PATCH 19/44] fix merge --- .../core/inference/engines/dynamic_engine.py | 5 - megatron/core/process_groups_config.py | 144 ------------------ megatron/rl/inference/megatron.py | 3 +- megatron/rl/rl_utils.py | 3 +- .../dist_checkpointing/test_optimizer.py | 4 +- 5 files changed, 4 insertions(+), 155 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 36f6acbb156..ed107af9bfb 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -154,11 +154,6 @@ def __init__( ), f"context must be a DynamicInferenceContext, got {type(context)}" assert isinstance(random_seed, int), f"random_seed must be an int, got {type(random_seed)}" - if pg_collection is not None: - self.pg_collection = pg_collection - else: - self.pg_collection = ProcessGroupCollection.use_mpu_process_groups() - # Deprecate `enable_cuda_graph`. if enable_cuda_graph is not None: warnings.warn( diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index f704e814970..ef8f31ea150 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -249,150 +249,6 @@ def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None): return cls(**init_dict) - def get_pipeline_model_parallel_world_size(self) -> int: - """Return PP world size using the PP process group.""" - return torch.distributed.get_world_size(self.pp) - - def get_tensor_model_parallel_rank(self) -> int: - """Return this rank's TP rank within the TP group.""" - global_rank = torch.distributed.get_rank() - tp_ranks = torch.distributed.get_process_group_ranks(self.tp) - return tp_ranks.index(global_rank) - - def get_pipeline_model_parallel_rank(self) -> int: - """Return this rank's PP rank within the PP group.""" - global_rank = torch.distributed.get_rank() - pp_ranks = torch.distributed.get_process_group_ranks(self.pp) - return pp_ranks.index(global_rank) - - def is_pipeline_first_stage( - self, ignore_virtual: bool = True, vp_stage: Optional[int] = None - ) -> bool: - """Return True if this rank is on the first PP stage. - - By default, ignores virtual pipeline (matches legacy interface). If you need VP-aware - behavior, pass ignore_virtual=False and specify vp_stage. - """ - pp_ranks = torch.distributed.get_process_group_ranks(self.pp) - global_rank = torch.distributed.get_rank() - is_pp_first = pp_ranks[0] == global_rank - vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() - if ignore_virtual or vp_size in (None, 1) or vp_stage is None: - return is_pp_first - is_vp_first = vp_stage == 0 - return is_vp_first and is_pp_first - - def is_pipeline_last_stage( - self, ignore_virtual: bool = True, vp_stage: Optional[int] = None - ) -> bool: - """Return True if this rank is on the last PP stage. - - By default, ignores virtual pipeline (matches legacy interface). If you need VP-aware - behavior, pass ignore_virtual=False and specify vp_stage. - """ - pp_ranks = torch.distributed.get_process_group_ranks(self.pp) - global_rank = torch.distributed.get_rank() - is_pp_last = pp_ranks[-1] == global_rank - vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() - if ignore_virtual or vp_size in (None, 1) or vp_stage is None: - return is_pp_last - is_vp_last = vp_stage == (vp_size - 1) - return is_vp_last and is_pp_last - - def get_data_parallel_rank(self) -> int: - """Return this rank's DP rank within the DP group.""" - global_rank = torch.distributed.get_rank() - dp_ranks = torch.distributed.get_process_group_ranks(self.dp) - return dp_ranks.index(global_rank) - - def get_context_parallel_rank(self) -> int: - """Return this rank's CP rank within the CP group, or 0 if no CP.""" - if not hasattr(self, 'cp') or self.cp is None: - return 0 - global_rank = torch.distributed.get_rank() - cp_ranks = torch.distributed.get_process_group_ranks(self.cp) - return cp_ranks.index(global_rank) - - def get_data_parallel_group( - self, with_context_parallel: bool = False, partial_data_parallel: bool = False - ): - """Return the DP/DP+CP process group, optionally partial.""" - if with_context_parallel: - if partial_data_parallel: - # Prefer intra_dp_cp if available, else fallback to dp_cp - if hasattr(self, 'intra_dp_cp') and self.intra_dp_cp is not None: - return self.intra_dp_cp - if hasattr(self, 'dp_cp') and self.dp_cp is not None: - return self.dp_cp - return self.dp - else: - if hasattr(self, 'dp_cp') and self.dp_cp is not None: - return self.dp_cp - return self.dp - return self.dp - - def get_tensor_model_parallel_group(self): - """Return the tensor model parallel (TP) process group.""" - return self.tp - - def get_pipeline_model_parallel_group(self): - """Return the pipeline model parallel (PP) process group.""" - return self.pp - - def get_model_parallel_group(self): - """Return the combined model parallel (MP) process group.""" - return self.mp - - def get_model_parallel_world_size(self) -> int: - """Return MP world size using the MP process group.""" - return torch.distributed.get_world_size(self.mp) - - def get_model_parallel_src_rank(self) -> int: - """Return the source (leader) global rank for the MP group.""" - ranks = torch.distributed.get_process_group_ranks(self.mp) - return ranks[0] - - def get_context_parallel_group(self): - """Return the context parallel (CP) process group, if configured.""" - return getattr(self, 'cp', None) - - def get_expert_model_parallel_group(self): - """Return the expert model parallel (EP) process group, if configured.""" - return getattr(self, 'ep', None) - - def get_expert_data_parallel_group(self, partial_expert_data_parallel: bool = False): - """Return the expert data parallel (ExDP) process group.""" - if partial_expert_data_parallel: - return getattr(self, 'intra_expt_dp', None) - return getattr(self, 'expt_dp', None) - - def get_inter_distributed_optimizer_instance_group(self): - """Return the inter-node distributed optimizer instance process group, if configured.""" - return getattr(self, 'inter_dist_opt', None) - - def get_intra_distributed_optimizer_instance_group(self): - """Return the intra-node distributed optimizer instance process group, if configured.""" - return getattr(self, 'intra_dist_opt', None) - - def get_data_parallel_src_rank( - self, with_context_parallel: bool = False, partial_data_parallel: bool = False - ) -> int: - """Return the source (leader) global rank for the selected DP group.""" - group = self.get_data_parallel_group( - with_context_parallel=with_context_parallel, partial_data_parallel=partial_data_parallel - ) - ranks = torch.distributed.get_process_group_ranks(group) - return ranks[0] - - def get_data_parallel_world_size( - self, with_context_parallel: bool = False, partial_data_parallel: bool = False - ) -> int: - """Return world size of the selected DP group.""" - group = self.get_data_parallel_group( - with_context_parallel=with_context_parallel, partial_data_parallel=partial_data_parallel - ) - return torch.distributed.get_world_size(group) - @staticmethod def setup_process_groups_for_optimizer( pg_collection: Optional['ProcessGroupCollection'], diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index 41f022108bc..54acc112dd9 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -114,7 +114,6 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen mamba_inference_state_config = get_mamba_inference_state_config_from_model(model) - # DynamicInferenceContext must use the inference model's TP size, not the # training TP size from global args. The inference model may have a custom # ProcessGroupCollection with a different TP size. @@ -231,7 +230,6 @@ async def base_generate(self, request: InferenceRequest): async def launch(cls, model: GPTModel, **kwargs): args = get_args() tokenizer = get_tokenizer() - rank = dist.get_rank() if tokenizer.bos is None: log_single_rank( @@ -251,6 +249,7 @@ async def launch(cls, model: GPTModel, **kwargs): if metrics_writer is None: log_single_rank(logger, logging.WARNING, "WARNING: --rl-inference-logging-step-interval is set but no metrics writer " "wandb module is available. Inference logging will be disabled.") + inference_engine: DynamicInferenceEngine = get_dynamic_inference_engine(args, model, inference_logging_step_interval, metrics_writer) await inference_engine.start_listening_to_data_parallel_coordinator(inference_coordinator_port=41521, launch_inference_coordinator=True) if dist.get_rank() == 0: diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index cdd91e5b4cc..80dffaf0206 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -57,6 +57,7 @@ from megatron.training.tokenizer.tokenizer import CustomTikTokenizer, _HuggingFaceTokenizer from megatron.training.utils import get_ltor_masks_and_position_ids, get_nvtx_range, print_rank_0 from megatron.training.utils import unwrap_model +from megatron.core.utils import get_pg_size logger = logging.getLogger(__name__) # Global variable to store packing context for forward_step @@ -2158,7 +2159,7 @@ def setup_grpo_data_iterator( if bin_idx.item() < len(my_bin_seq_indices) ) # Estimate global sequences for this step - est_global_sequences = step_sequences * inference_mpu.get_data_parallel_world_size() + est_global_sequences = step_sequences * get_pg_size(inference_mpu.dp) print_rank_0( f"[Sequence Packing] Optimizer step {plan['current_step']}/{plan['total_steps']}: " f"processing {len(step_bin_indices)} bins (~{est_global_sequences} sequences globally)" diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py index a5da1cd47f8..4c5bfd9b32e 100644 --- a/tests/unit_tests/dist_checkpointing/test_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py @@ -18,9 +18,7 @@ ) from megatron.core.dist_checkpointing.utils import add_prefix_for_sharding, extract_sharded_tensors from megatron.core.dist_checkpointing.validation import StrictHandling -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_decoder_block_spec, -) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, ) From 70272da9719189d4af7dd9bd176b5ae0eca2e9d3 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Dec 2025 18:45:56 -0800 Subject: [PATCH 20/44] cleanup merge --- megatron/rl/rl_utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 80dffaf0206..4c36819ecc0 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -57,7 +57,7 @@ from megatron.training.tokenizer.tokenizer import CustomTikTokenizer, _HuggingFaceTokenizer from megatron.training.utils import get_ltor_masks_and_position_ids, get_nvtx_range, print_rank_0 from megatron.training.utils import unwrap_model -from megatron.core.utils import get_pg_size +from megatron.core.utils import get_pg_size, get_attr_wrapped_model logger = logging.getLogger(__name__) # Global variable to store packing context for forward_step @@ -662,9 +662,9 @@ def get_environment_rollouts( else: inference_model = model - #TODO(peter): We need to get the models process group collection and use that for these checks + inference_pg_collection = get_attr_wrapped_model(inference_model[0], "pg_collection") assert ( - n_prompts % mpu.get_expert_data_parallel_world_size() == 0 + n_prompts % get_pg_size(inference_pg_collection.ep) == 0 ), "n_prompts must be divisible by data_parallel_world_size" with nvtx_range("rollout-collection"): @@ -710,8 +710,7 @@ def get_environment_rollouts( torch.distributed.broadcast_object_list(rollouts, src=0) print(f"Got rollouts on rank {rank}") - #TODO(Peter): We need to use the proper models MPU here. - if lang_rl_log_dir and rank == get_tensor_model_parallel_src_rank(): + if lang_rl_log_dir and rank == get_pg_rank(inference_pg_collection.tp): with open( lang_rl_log_dir + f'/rollouts_rank{rank}_iteration{args.curr_iteration}_' @@ -2117,9 +2116,9 @@ def setup_grpo_data_iterator( args = get_args() if inference_model is not None: - inference_mpu = unwrap_model(inference_model[0]).pg_collection + inference_pg_collection = unwrap_model(inference_model[0]).pg_collection else: - inference_mpu = mpu + inference_pg_collection = ProcessGroupCollection.use_mpu_process_groups() # We collect new rollouts when we've gone over the collected data 'grpo_iterations' times. if ( @@ -2159,7 +2158,7 @@ def setup_grpo_data_iterator( if bin_idx.item() < len(my_bin_seq_indices) ) # Estimate global sequences for this step - est_global_sequences = step_sequences * get_pg_size(inference_mpu.dp) + est_global_sequences = step_sequences * get_pg_size(inference_pg_collection.dp) print_rank_0( f"[Sequence Packing] Optimizer step {plan['current_step']}/{plan['total_steps']}: " f"processing {len(step_bin_indices)} bins (~{est_global_sequences} sequences globally)" @@ -2527,8 +2526,7 @@ def get_sequence_packing_tensorboard_metrics(args): """Get tensorboard metrics for sequence packing mode.""" metrics = {} if args.consumed_train_bins > 0: - # TODO(Peter) We need to use the proper models MPU for refitting. If you forget you probably need to change this all over this - # file + # TODO(Peter) We need to use the proper models MPU for refitting. bin_batch_size = ( mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() ) From cc5b44bebf2a496ffe206a5a3f408396574a707f Mon Sep 17 00:00:00 2001 From: root Date: Thu, 4 Dec 2025 19:03:55 -0800 Subject: [PATCH 21/44] remove unwrap --- megatron/core/resharding/refit.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 15e5ba758b1..29a433ecc81 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -11,6 +11,7 @@ from megatron.core import parallel_state from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.utils import unwrap_model from . import build_centralized_reshard_plan, execute_reshard_plan from .copy_services.base import CopyService @@ -18,14 +19,6 @@ from .copy_services.nccl_copy_service import NCCLCopyService -def _unwrap_module(module: LanguageModule) -> Any: - return ( - module.module.module - if hasattr(module, 'module') and hasattr(module.module, 'module') - else module.module if hasattr(module, 'module') else module - ) - - def swap_model_weights( src_model: LanguageModule, target_model: LanguageModule, refit_method: Union[str, CopyService] ): @@ -61,8 +54,8 @@ def reshard_model_weights( num_experts = src_lm.config.num_moe_experts # Unwrap to get owning modules (with parameters and pg_collection) - src_core = _unwrap_module(src_lm) - tgt_core = _unwrap_module(tgt_lm) + src_core = unwrap_model(src_lm) + tgt_core = unwrap_model(tgt_lm) # Ensure pg_collection exists if not hasattr(src_core, "pg_collection") or src_core.pg_collection is None: From 16281efea5d71070a8baf69fd8b3f0cbbab247cc Mon Sep 17 00:00:00 2001 From: root Date: Fri, 5 Dec 2025 09:33:34 -0800 Subject: [PATCH 22/44] simplify dp round robin --- megatron/core/resharding/utils.py | 46 ++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index b2b622860ee..54b05beee63 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -219,23 +219,43 @@ def extract_param_metadata( def select_src_metadata_balanced( src_meta_list: list[ParameterMetadata], dst_metadata: ParameterMetadata, dst_rank: int ) -> ParameterMetadata: - """Choose representative source metadata using DP round-robin across source DP groups.""" + """Choose a representative source `ParameterMetadata` for a destination rank. + + Multiple source data-parallel (DP) groups may hold the same logical parameter. + To avoid always reading from the same group, we: + - bucket `src_meta_list` by their DP group (tuple of ranks) + - if there is only one bucket, just return the first entry + - otherwise, map the destination rank's DP index to one of the source + DP groups in a round-robin fashion, and pick the first metadata in it. + """ if not src_meta_list: raise ValueError("src_meta_list must be non-empty") - groups: dict[tuple[int, ...], list[ParameterMetadata]] = {} - for m in src_meta_list: - key = tuple(m.data_parallel_group_ranks or []) - groups.setdefault(key, []).append(m) - if len(groups) == 1: + + # Group source metadata by their DP group layout so we can balance across groups. + # (dp_rank0, dp_rank1, ...) -> [ParameterMetadata for that DP group] + grouped_by_dp: dict[tuple[int, ...], list[ParameterMetadata]] = {} + for meta in src_meta_list: + dp_group = tuple(meta.data_parallel_group_ranks or []) + grouped_by_dp.setdefault(dp_group, []).append(meta) + + # Fast path: only one DP layout present; no balancing necessary. + if len(grouped_by_dp) == 1: return src_meta_list[0] - dst_dp = dst_metadata.data_parallel_group_ranks or [] - if dst_rank in dst_dp and len(dst_dp) > 0: - my_dst_dp_idx = dst_dp.index(dst_rank) + + # Determine this destination rank's index within its DP group (if any). + dst_dp_ranks = dst_metadata.data_parallel_group_ranks or [] + if dst_dp_ranks and dst_rank in dst_dp_ranks: + dst_dp_index = dst_dp_ranks.index(dst_rank) else: - my_dst_dp_idx = 0 - keys_sorted = sorted(groups.keys()) - chosen_key = keys_sorted[my_dst_dp_idx % len(keys_sorted)] - return groups[chosen_key][0] + # Fallback: treat as the first DP index. + dst_dp_index = 0 + + # Use a stable ordering of DP groups so that round-robin is deterministic. + sorted_dp_groups = sorted(grouped_by_dp.keys()) + chosen_group = sorted_dp_groups[dst_dp_index % len(sorted_dp_groups)] + + # Within the chosen group, any representative metadata works; use the first. + return grouped_by_dp[chosen_group][0] logger = logging.getLogger(__name__) From 16dc9119ace6c42ff2128f5ae7f2c61b1a1a5477 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 5 Dec 2025 10:24:05 -0800 Subject: [PATCH 23/44] Address comments --- megatron/core/resharding/refit.py | 13 ++++++++++--- megatron/rl/rl_utils.py | 14 ++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 29a433ecc81..b55d7e2e13e 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -7,7 +7,7 @@ - reshard_model_weights: transport-agnostic core; builds/caches plan and executes. """ -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union from megatron.core import parallel_state from megatron.core.models.common.language_module.language_module import LanguageModule @@ -18,13 +18,20 @@ from .copy_services.gloo_copy_service import GlooCopyService from .copy_services.nccl_copy_service import NCCLCopyService +# Supported refit backend names +RefitBackendName = Literal["nccl", "gloo"] + def swap_model_weights( - src_model: LanguageModule, target_model: LanguageModule, refit_method: Union[str, CopyService] + src_model: LanguageModule, + target_model: LanguageModule, + refit_method: Union[RefitBackendName, CopyService], ): """ Orchestrate weight swap/refit. - - refit_method can be a string backend name ('nccl') or a CopyService instance. + - refit_method can be: + * a string backend name (one of the supported refit backends), or + * a CopyService instance. """ if isinstance(refit_method, CopyService): service = refit_method diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 4c36819ecc0..8cd17e430cd 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -658,7 +658,7 @@ def get_environment_rollouts( # If we have seperate training and inference models we to refit weights from the training model to the inference model. if inference_model is not None: - swap_train_to_inference_model(model, inference_model, args.refit_method) + swap_model_weights(model, inference_model, args.refit_method) else: inference_model = model @@ -2407,6 +2407,7 @@ def megatron_rl_inference_mode( loop = get_asyncio_loop() nvtx_range = get_nvtx_range() + print(f"[{dist.get_rank()}:DP] Entering inference mode") # If we get a lower precision wrapper, we go one object deeper. lang_module = model[0].module.module if hasattr(model[0].module, "module") else model[0].module @@ -2458,6 +2459,8 @@ def megatron_rl_inference_mode( inference_interface._inference_engine.create_cuda_graphs(reset_context=True) loop.run_until_complete(inference_interface.resume()) + + print(f"[{dist.get_rank()}:DP] Exited inference mode") yield inference_interface with nvtx_range("suspend-engine"): @@ -2533,12 +2536,3 @@ def get_sequence_packing_tensorboard_metrics(args): metrics['bin-batch-size'] = bin_batch_size metrics['consumed-bins'] = args.consumed_train_bins return metrics - -def swap_train_to_inference_model(train_model: LanguageModule, inference_model: list[LanguageModule], refit_method: str): - """Swap the train model to the inference model. - - Args: - train_model: The train model to swap to the inference model. - inference_model: The inference model (list) to swap to the train model. - """ - swap_model_weights(train_model, inference_model, refit_method) \ No newline at end of file From 32d8d9f88ed01e56b353868f81acc6b1206cd500 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 5 Dec 2025 11:28:55 -0800 Subject: [PATCH 24/44] add fix --- megatron/rl/rl_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 8cd17e430cd..6d49433ab30 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -58,6 +58,7 @@ from megatron.training.utils import get_ltor_masks_and_position_ids, get_nvtx_range, print_rank_0 from megatron.training.utils import unwrap_model from megatron.core.utils import get_pg_size, get_attr_wrapped_model +from megatron.core.process_groups_config import ProcessGroupCollection logger = logging.getLogger(__name__) # Global variable to store packing context for forward_step From 6cd248b6166b6625b881b11005e90fb280b4ea78 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 7 Dec 2025 17:26:35 -0800 Subject: [PATCH 25/44] test --- .../core/resharding/copy_services/__init__.py | 3 +- .../copy_services/nvshmem_copy_service.py | 160 +++++++ megatron/core/resharding/execution.py | 13 +- .../nvshmem_copy_service/__init__.py | 34 ++ .../nvshmem_copy_service/core/__init__.py | 9 + .../core/gpu_resource_manager.py | 194 ++++++++ .../core/kernel_launcher.py | 145 ++++++ .../core/pipeline_executor.py | 289 ++++++++++++ .../kernels/chunked_kernel.cu | 103 ++++ .../resharding/nvshmem_copy_service/logger.py | 199 ++++++++ .../nvshmem_copy_service/memory/__init__.py | 8 + .../memory/double_buffer_manager.py | 76 +++ .../memory/tensor_pointer_utils.py | 33 ++ .../nvshmem_copy_service/nvshmem_types.py | 61 +++ .../nvshmem_copy_service/planning/__init__.py | 15 + .../planning/communication_scheduler.py | 214 +++++++++ .../planning/gpu_execution_planner.py | 248 ++++++++++ .../planning/task_segmenter.py | 97 ++++ .../planning/workload_packer.py | 107 +++++ .../nvshmem_copy_service/service.py | 441 ++++++++++++++++++ .../nvshmem_copy_service/validation.py | 154 ++++++ megatron/core/resharding/planner.py | 14 +- megatron/core/resharding/refit.py | 6 +- megatron/core/resharding/utils.py | 5 + megatron/training/arguments.py | 6 +- .../inference/test_nccl_model_swap.py | 4 +- 26 files changed, 2630 insertions(+), 8 deletions(-) create mode 100644 megatron/core/resharding/copy_services/nvshmem_copy_service.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/__init__.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/core/__init__.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu create mode 100644 megatron/core/resharding/nvshmem_copy_service/logger.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/memory/__init__.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/planning/__init__.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/service.py create mode 100644 megatron/core/resharding/nvshmem_copy_service/validation.py diff --git a/megatron/core/resharding/copy_services/__init__.py b/megatron/core/resharding/copy_services/__init__.py index 15986e4d28e..447588f7b3a 100644 --- a/megatron/core/resharding/copy_services/__init__.py +++ b/megatron/core/resharding/copy_services/__init__.py @@ -3,5 +3,6 @@ from .base import CopyService from .nccl_copy_service import NCCLCopyService +from .nvshmem_copy_service import NVSHMEMCopyService -__all__ = ["CopyService", "NCCLCopyService"] +__all__ = ["CopyService", "NCCLCopyService", "NVSHMEMCopyService"] diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py new file mode 100644 index 00000000000..e5abb31dbd5 --- /dev/null +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +""" +NVSHMEM-based implementation of the CopyService interface. + +This wraps the higher-level RemoteCopyService so it can be used anywhere a +CopyService is expected (e.g., refit/reshard execution). + +NOTE: This is a first, minimal wiring. It currently mirrors the point-to-point +semantics of execute_reshard_plan by treating each send/recv pair as an +independent NVSHMEM "task" defined over contiguous slices. +""" + +import logging +from typing import Dict, List, Tuple + +import torch +import torch.distributed as dist + +from ..nvshmem_copy_service import RemoteCopyService +from .base import CopyService + +logger = logging.getLogger(__name__) + + +class NVSHMEMCopyService(CopyService): + """CopyService implementation backed by NVSHMEM RemoteCopyService.""" + + def __init__(self): + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before NVSHMEMCopyService()" + ) + + self._remote = RemoteCopyService() + # Lazily initialized on first use to avoid side effects at import time + self._initialized = False + + # Internal bookkeeping of registration calls before schedule/run + self._next_task_id: int = 0 + self._registered_pairs: List[Tuple[int, torch.Tensor, torch.Tensor, int]] = [] + + logger.info("NVSHMEMCopyService constructed") + + def _ensure_initialized(self): + if not self._initialized: + self._remote.init(log_level="INFO") + self._initialized = True + logger.info( + "NVSHMEMCopyService initialized: PE %d / %d", + self._remote.my_pe, + self._remote.n_pes, + ) + + def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): + """ + Basic CopyService API is not rich enough to drive the NVSHMEM planner + (it lacks a globally shared task identifier), so this method is kept + only for interface compatibility and should not be used directly. + + The resharding path calls into NVSHMEMCopyService via the + submit_send_with_id/submit_recv_with_id helpers instead. + """ + raise RuntimeError( + "NVSHMEMCopyService.submit_send() is not supported; " + "use submit_send_with_id(...) from execute_reshard_plan." + ) + + def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + raise RuntimeError( + "NVSHMEMCopyService.submit_recv() is not supported; " + "use submit_recv_with_id(...) from execute_reshard_plan." + ) + + # + # New helper API used from execute_reshard_plan via monkey-patching: + # we avoid changing the existing execute_reshard_plan signature by adding + # a small adapter layer that batches up matched send/recv slices. + # + + def submit_send_with_id( + self, + task_id: int, + src_tensor: torch.Tensor, + dest_rank: int, + ): + """Register a send with an explicit, globally shared task_id.""" + self._ensure_initialized() + + if not src_tensor.is_contiguous(): + src_tensor = src_tensor.contiguous() + + num_bytes = src_tensor.numel() * src_tensor.element_size() + src_bytes = src_tensor.view(torch.uint8) + + logger.debug( + "NVSHMEMCopyService: register_send task_id=%d, %d bytes (%d → %d)", + task_id, + num_bytes, + dist.get_rank(), + dest_rank, + ) + + # Use public API on RemoteCopyService + self._remote.register_send( + task_id=task_id, + src_tensor=src_bytes, + src_pos=0, + size=num_bytes, + dest_pe=dest_rank, + ) + + def submit_recv_with_id( + self, + task_id: int, + dest_tensor: torch.Tensor, + src_rank: int, + ): + """Register a recv with an explicit, globally shared task_id.""" + self._ensure_initialized() + + if not dest_tensor.is_contiguous(): + dest_tensor = dest_tensor.contiguous() + + num_bytes = dest_tensor.numel() * dest_tensor.element_size() + dst_bytes = dest_tensor.view(torch.uint8) + + logger.debug( + "NVSHMEMCopyService: register_recv task_id=%d, %d bytes (%d ← %d)", + task_id, + num_bytes, + dist.get_rank(), + src_rank, + ) + + self._remote.register_receive( + task_id=task_id, + dest_tensor=dst_bytes, + dest_pos=0, + size=num_bytes, + src_pe=src_rank, + ) + + def run(self): + """ + Execute all registered transfer pairs via NVSHMEM. + + This converts the registered pairs into RemoteCopyService send/receive + requests, builds a schedule, runs the pipelined NVSHMEM transfer, and + then clears internal state. + """ + # Execute schedule built from submit_send_with_id/submit_recv_with_id + self._ensure_initialized() + logger.info("NVSHMEMCopyService: building NVSHMEM schedule and executing") + self._remote.schedule() + self._remote.run() + self._remote.clear_requests() + logger.info("NVSHMEMCopyService: NVSHMEM transfers complete") + + diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index d1b9962facf..eb7e7dfef83 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -8,6 +8,7 @@ import torch.distributed as dist from .copy_services.base import CopyService +from .copy_services.nvshmem_copy_service import NVSHMEMCopyService from .utils import ReshardPlan logger = logging.getLogger(__name__) @@ -41,12 +42,17 @@ def execute_reshard_plan( dst_view = dst_param.data[dst_slice] dst_view.copy_(src_view) + is_nvshmem = isinstance(service, NVSHMEMCopyService) + # Submit sends for op in plan.send_ops: src_param = src_params.get(op.param_name) if src_param is not None: src_view = src_param.data[op.my_slice].contiguous() - service.submit_send(src_view, op.peer_rank) + if is_nvshmem and op.task_id is not None: + service.submit_send_with_id(op.task_id, src_view, op.peer_rank) + else: + service.submit_send(src_view, op.peer_rank) # Submit recvs recv_writebacks: List[Tuple[torch.Tensor, torch.nn.Parameter, tuple[slice, ...]]] = [] @@ -55,7 +61,10 @@ def execute_reshard_plan( if dst_param is not None: dst_slice_view = dst_param.data[op.my_slice] recv_buffer = torch.empty_like(dst_slice_view.contiguous()) - service.submit_recv(recv_buffer, op.peer_rank) + if is_nvshmem and op.task_id is not None: + service.submit_recv_with_id(op.task_id, recv_buffer, op.peer_rank) + else: + service.submit_recv(recv_buffer, op.peer_rank) recv_writebacks.append((recv_buffer, dst_param, op.my_slice)) # Execute diff --git a/megatron/core/resharding/nvshmem_copy_service/__init__.py b/megatron/core/resharding/nvshmem_copy_service/__init__.py new file mode 100644 index 00000000000..2019c518039 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/__init__.py @@ -0,0 +1,34 @@ +""" +NVSHMEM-based remote copy service and supporting components. + +This package is an in-tree integration of the standalone +`nvshmem_copy_service/python` implementation so that Megatron +can use it without relying on an external library. +""" + +from .service import RemoteCopyService +from . import nvshmem_types +from .core import GPUResourceManager, KernelLauncher, PipelineExecutor +from .memory import DoubleBufferManager, TensorPointerExtractor +from .planning import ( + CommunicationScheduler, + GPUExecutionPlanner, + TaskSegmenter, + WorkloadPacker, +) + +__all__ = [ + "RemoteCopyService", + "nvshmem_types", + "GPUResourceManager", + "KernelLauncher", + "PipelineExecutor", + "DoubleBufferManager", + "TensorPointerExtractor", + "CommunicationScheduler", + "GPUExecutionPlanner", + "TaskSegmenter", + "WorkloadPacker", +] + + diff --git a/megatron/core/resharding/nvshmem_copy_service/core/__init__.py b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py new file mode 100644 index 00000000000..41ca4bad9b6 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py @@ -0,0 +1,9 @@ +"""Core execution components for NVSHMEM operations.""" + +from .gpu_resource_manager import GPUResourceManager +from .kernel_launcher import KernelLauncher +from .pipeline_executor import PipelineExecutor + +__all__ = ["GPUResourceManager", "KernelLauncher", "PipelineExecutor"] + + diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py new file mode 100644 index 00000000000..000ddebdb4f --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -0,0 +1,194 @@ +""" +GPU resource management for NVSHMEM operations. + +Handles NVSHMEM initialization, CUDA device setup, stream management, +and event lifecycle. +""" + +from typing import Optional, Dict + +import nvshmem.core +import torch +import torch.distributed as dist +from cuda.core.experimental import Device, system # type: ignore[attr-defined] + + +class GPUResourceManager: + """Manages GPU resources including NVSHMEM, streams, and events.""" + + def __init__(self): + self.device: Optional[Device] = None + self.my_pe: int = -1 + self.n_pes: int = -1 + self.initialized: bool = False + + # Dedicated torch.distributed process group for NVSHMEM collectives. + # This isolates NVSHMEM's use of collectives from the default WORLD + # group that Megatron and the test harness use for their own ops. + self.pg: Optional[dist.ProcessGroup] = None + + # CUDA streams (cuda.core.experimental) + self.pack_stream = None + self.unpack_stream = None + self.send_stream = None + self.copy_stream = None + + # PyTorch stream wrappers + self.torch_pack_stream = None + self.torch_unpack_stream = None + self.torch_send_stream = None + self.torch_copy_stream = None + + # Stream name to PyTorch stream mapping + self._torch_streams: Dict[str, torch.cuda.ExternalStream] = {} + + def init(self) -> None: + """ + Initialize NVSHMEM, CUDA device, and streams. + + Expects torch.distributed to be already initialized. + """ + if self.initialized: + return + + # torch.distributed must be initialized before calling this + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before " + "GPUResourceManager.init()" + ) + + # Get current CUDA device (already set by caller based on LOCAL_RANK) + local_rank = torch.cuda.current_device() + + # nvshmem4py requires a cuda.core Device at init time + self.device = Device(local_rank) + self.device.set_current() + + # Extract rank, nranks from process group + num_ranks = dist.get_world_size() + rank_id = dist.get_rank() + + # Create a dedicated process group for NVSHMEM collectives. + # Using a private group avoids interfering with Megatron's own + # WORLD-group collectives (e.g., during test setup/teardown), + # which can otherwise trigger "collective mismatch" runtime errors. + self.pg = dist.new_group(ranks=list(range(num_ranks))) + + # Create/Broadcast UniqueID using broadcast_object_list + uniqueid = nvshmem.core.get_unique_id(empty=True) + if rank_id == 0: + uniqueid = nvshmem.core.get_unique_id() + broadcast_objects = [uniqueid] + else: + broadcast_objects = [None] + + # Broadcast ID to all ranks + dist.broadcast_object_list(broadcast_objects, src=0, group=self.pg) + + # Barrier to ensure everyone has the ID before NVSHMEM init + dist.barrier(group=self.pg) + + # Initialize NVSHMEM with the broadcasted UID + nvshmem.core.init( + device=self.device, + uid=broadcast_objects[0], + rank=rank_id, + nranks=num_ranks, + initializer_method="uid", + ) + + print("NVSHMEM initialized") + + self.my_pe = nvshmem.core.my_pe() + self.n_pes = nvshmem.core.n_pes() + + # Create CUDA streams + self.pack_stream = self.device.create_stream() + self.unpack_stream = self.device.create_stream() + self.send_stream = self.device.create_stream() + self.copy_stream = self.device.create_stream() + + # Get stream pointers and create PyTorch wrappers + _, pack_stream_ptr = self.pack_stream.__cuda_stream__() + _, unpack_stream_ptr = self.unpack_stream.__cuda_stream__() + _, send_stream_ptr = self.send_stream.__cuda_stream__() + _, copy_stream_ptr = self.copy_stream.__cuda_stream__() + + self.torch_pack_stream = torch.cuda.ExternalStream(pack_stream_ptr) + self.torch_unpack_stream = torch.cuda.ExternalStream(unpack_stream_ptr) + self.torch_send_stream = torch.cuda.ExternalStream(send_stream_ptr) + self.torch_copy_stream = torch.cuda.ExternalStream(copy_stream_ptr) + + # Build stream mapping + self._torch_streams = { + "pack": self.torch_pack_stream, + "unpack": self.torch_unpack_stream, + "send": self.torch_send_stream, + "copy": self.torch_copy_stream, + } + + print("Stream mapping built") + + self.initialized = True + + # Initial barrier to ensure all PEs are ready + nvshmem.core.barrier_all(stream=self.send_stream) + + def get_stream(self, name: str): + """ + Get CUDA stream by name. + + Args: + name: Stream name ('pack', 'unpack', 'send', 'copy') + + Returns: + CUDA stream object + """ + streams = { + "pack": self.pack_stream, + "unpack": self.unpack_stream, + "send": self.send_stream, + "copy": self.copy_stream, + } + return streams.get(name) + + def get_torch_stream(self, name: str) -> Optional[torch.cuda.ExternalStream]: + """ + Get PyTorch ExternalStream by name. + + Args: + name: Stream name ('pack', 'unpack', 'send', 'copy') + + Returns: + PyTorch ExternalStream + """ + return self._torch_streams.get(name) + + def create_events(self, num_events: int = 2): + """ + Create double-buffered CUDA events for pack and unpack operations. + + Args: + num_events: Number of events to create for each type + (default: 2 for double buffering) + + Returns: + tuple: (pack_events, unpack_events) lists of torch.cuda.Event + """ + pack_events = [ + torch.cuda.Event(enable_timing=False) for _ in range(num_events) + ] + unpack_events = [ + torch.cuda.Event(enable_timing=False) for _ in range(num_events) + ] + return pack_events, unpack_events + + def finalize(self) -> None: + """Cleanup resources (streams are automatically managed by CUDA).""" + self.initialized = False + self.my_pe = -1 + self.n_pes = -1 + # Streams are automatically cleaned up when objects are deleted + + diff --git a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py new file mode 100644 index 00000000000..a67564fc659 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py @@ -0,0 +1,145 @@ +""" +CUDA kernel management and launching for pack/unpack operations. + +Handles kernel compilation, launching, and stream coordination. +""" + +import os +from typing import Tuple, Optional, Any + +import cupy as cp +import torch +import torch.cuda.nvtx as nvtx + + +class KernelLauncher: + """Manages CUDA kernel loading and launching for data pack/unpack operations.""" + + def __init__(self): + self.chunked_copy_kernel: Optional[cp.RawKernel] = None + # Cached CuPy stream wrappers for efficient kernel launching + self.cp_pack_stream: Optional[cp.cuda.ExternalStream] = None + self.cp_unpack_stream: Optional[cp.cuda.ExternalStream] = None + + def load_kernels(self) -> None: + """Load and compile CUDA kernels from source.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + kernel_path = os.path.join( + current_dir, + "..", + "kernels", + "chunked_kernel.cu", + ) + + with open(kernel_path, "r") as f: + kernel_source = f.read() + + self.chunked_copy_kernel = cp.RawKernel( + kernel_source, + "chunked_batched_copy_kernel", + options=("-std=c++11",), + ) + + def set_streams(self, pack_stream, unpack_stream) -> None: + """ + Cache CuPy stream wrappers for kernel launching. + + This eliminates per-launch overhead of stream pointer extraction + and CuPy ExternalStream creation. + + Args: + pack_stream: CUDA stream for pack operations + unpack_stream: CUDA stream for unpack operations + """ + _, pack_stream_ptr = pack_stream.__cuda_stream__() + _, unpack_stream_ptr = unpack_stream.__cuda_stream__() + self.cp_pack_stream = cp.cuda.ExternalStream(pack_stream_ptr) + self.cp_unpack_stream = cp.cuda.ExternalStream(unpack_stream_ptr) + + def launch_pack( + self, + gpu_plan: Tuple[Any, Any, Any, int], + pack_stream, + torch_pack_stream: torch.cuda.ExternalStream, + pack_event: torch.cuda.Event, + ) -> None: + """ + Launch pack kernel to copy data from user tensors to send buffer. + + Args: + gpu_plan: Tuple of (cp_src_addrs, cp_dst_addrs, cp_sizes, num_chunks) + as CuPy arrays + pack_stream: CUDA stream (cuda.core.experimental.Stream) - unused, + kept for compatibility + torch_pack_stream: PyTorch external stream wrapper + pack_event: CUDA event to record after kernel launch + """ + nvtx.range_push("Launch Pack Kernel") + if not gpu_plan: + nvtx.range_pop() + return + + # Unpack cached CuPy arrays from gpu_plan + cp_src, cp_dst, cp_sizes, num_chunks = gpu_plan + + # Grid/Block configuration + threads_per_block = 1024 + num_blocks = 75 + + # Launch kernel using cached CuPy stream + assert self.chunked_copy_kernel is not None + assert self.cp_pack_stream is not None + self.chunked_copy_kernel( + (num_blocks,), + (threads_per_block,), + (cp_src, cp_dst, cp_sizes, num_chunks), + stream=self.cp_pack_stream, + ) + nvtx.range_pop() + # Record event on PyTorch stream + pack_event.record(stream=torch_pack_stream) + + def launch_unpack( + self, + gpu_plan: Tuple[Any, Any, Any, int], + unpack_stream, + torch_unpack_stream: torch.cuda.ExternalStream, + unpack_event: torch.cuda.Event, + ) -> None: + """ + Launch unpack kernel to copy data from receive buffer to user tensors. + + Args: + gpu_plan: Tuple of (cp_src_addrs, cp_dst_addrs, cp_sizes, num_chunks) + as CuPy arrays + unpack_stream: CUDA stream (cuda.core.experimental.Stream) - unused, + kept for compatibility + torch_unpack_stream: PyTorch external stream wrapper + unpack_event: CUDA event to record after kernel launch + """ + nvtx.range_push("Launch Unpack Kernel") + if not gpu_plan: + nvtx.range_pop() + return + + # Unpack cached CuPy arrays from gpu_plan + cp_src, cp_dst, cp_sizes, num_chunks = gpu_plan + + # Grid/Block configuration + threads_per_block = 1024 + num_blocks = 75 + + # Launch kernel using cached CuPy stream + assert self.chunked_copy_kernel is not None + assert self.cp_unpack_stream is not None + self.chunked_copy_kernel( + (num_blocks,), + (threads_per_block,), + (cp_src, cp_dst, cp_sizes, num_chunks), + stream=self.cp_unpack_stream, + ) + nvtx.range_pop() + # Record event on PyTorch stream + unpack_event.record(stream=torch_unpack_stream) + + diff --git a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py new file mode 100644 index 00000000000..bcd43ea1da2 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py @@ -0,0 +1,289 @@ +""" +Pipelined communication execution engine. + +Orchestrates the pack/send/unpack pipeline with double-buffering +and proper stream synchronization. +""" + +from typing import List, Dict, Optional + +import nvshmem.core +import torch + +from ..logger import PELogger +from ..nvshmem_types import SendRequest, ReceiveRequest, ScheduledBatch +from .kernel_launcher import KernelLauncher +from ..memory.double_buffer_manager import DoubleBufferManager + + +class PipelineExecutor: + """Executes pipelined NVSHMEM communication with pack/send/unpack overlap.""" + + def __init__( + self, + kernel_launcher: KernelLauncher, + buffer_manager: DoubleBufferManager, + my_pe: int, + ): + """ + Initialize pipeline executor. + + Args: + kernel_launcher: KernelLauncher instance for pack/unpack kernels + buffer_manager: DoubleBufferManager for send/recv buffers + my_pe: This PE's rank + """ + self.kernel_launcher = kernel_launcher + self.buffer_manager = buffer_manager + self.my_pe = my_pe + + # Streams (will be set by service) + self.pack_stream = None + self.unpack_stream = None + self.send_stream = None + self.copy_stream = None + + self.torch_pack_stream = None + self.torch_unpack_stream = None + self.torch_copy_stream = None + + # Events for double-buffered synchronization + self.pack_events = [] + self.unpack_events = [] + + def set_streams( + self, + pack_stream, + unpack_stream, + send_stream, + copy_stream, + torch_pack_stream, + torch_unpack_stream, + torch_copy_stream, + ): + """Set CUDA streams for execution.""" + self.pack_stream = pack_stream + self.unpack_stream = unpack_stream + self.send_stream = send_stream + self.copy_stream = copy_stream + + self.torch_pack_stream = torch_pack_stream + self.torch_unpack_stream = torch_unpack_stream + self.torch_copy_stream = torch_copy_stream + + def set_events(self, pack_events: List, unpack_events: List): + """Set double-buffered CUDA events.""" + self.pack_events = pack_events + self.unpack_events = unpack_events + + def execute_pipeline( + self, + iter_schedules: List[Dict[str, Optional[ScheduledBatch]]], + num_iterations: int, + ) -> None: + """ + Execute pipelined communication. + + Pipeline stages: + 1. Pack NEXT iteration (async) + 2. Unpack PRIOR iteration (async) + 3. Send CURRENT iteration (sync) + 4. Barrier + 5. Wait for async pack/unpack to complete + + Args: + iter_schedules: List of iteration schedules + num_iterations: Total number of iterations + """ + PELogger.info(f"Executing pipeline: {num_iterations} iterations") + + # Priming: Pack iteration 0 and WAIT for completion + if num_iterations > 0 and iter_schedules[0]["send"]: + torch.cuda.nvtx.range_push("Priming") + PELogger.debug("Priming: Packing iteration 0") + self._launch_pack(0, iter_schedules[0]["send"]) + self.pack_events[0].synchronize() + torch.cuda.nvtx.range_pop() + + for i in range(num_iterations): + torch.cuda.nvtx.range_push(f"Iteration {i}") + has_send = iter_schedules[i]["send"] is not None + has_recv = iter_schedules[i]["recv"] is not None + has_next_send = ( + i + 1 < num_iterations + and iter_schedules[i + 1]["send"] is not None + ) + has_prior_recv = i > 0 and iter_schedules[i - 1]["recv"] is not None + + slot = i % 2 + + # Log iteration start + send_info = ( + f" → PE {iter_schedules[i]['send'].dest_pe} " + f"({iter_schedules[i]['send'].total_size} bytes)" + if has_send + else "" + ) + recv_info = ( + f" ← PE {iter_schedules[i]['recv'].src_pe} " + f"({iter_schedules[i]['recv'].total_size} bytes)" + if has_recv + else "" + ) + PELogger.debug( + f"Iteration {i}/{num_iterations}: slot={slot}{send_info}{recv_info}" + ) + + # Step 1: Pack NEXT iteration (async) + if has_next_send: + torch.cuda.nvtx.range_push("Step 1: Pack Next") + next_batch = iter_schedules[i + 1]["send"] + assert next_batch is not None + PELogger.debug( + f" Pack next (iter {i+1}): {len(next_batch.tasks)} tasks " + f"→ PE {next_batch.dest_pe}" + ) + self._launch_pack(i + 1, next_batch) + torch.cuda.nvtx.range_pop() + + # Step 2: Unpack PRIOR iteration (async) + if has_prior_recv: + torch.cuda.nvtx.range_push("Step 2: Unpack Prior") + prior_batch = iter_schedules[i - 1]["recv"] + assert prior_batch is not None + PELogger.debug( + f" Unpack prior (iter {i-1}): {prior_batch.total_size} bytes " + f"← PE {prior_batch.src_pe}" + ) + self._launch_unpack(i - 1, prior_batch) + torch.cuda.nvtx.range_pop() + + # Step 3: Send CURRENT iteration + if has_send: + torch.cuda.nvtx.range_push("Step 3: Send Current") + batch = iter_schedules[i]["send"] + assert batch is not None + transfer_size = batch.total_size + PELogger.debug( + f" Send current: {transfer_size} bytes → PE {batch.dest_pe}" + ) + + nvshmem.core.put( + self.buffer_manager.recv_slots[slot][0:transfer_size], + self.buffer_manager.send_slots[slot][0:transfer_size], + batch.dest_pe, + stream=self.send_stream, + ) + torch.cuda.nvtx.range_pop() + + # Ensure send completes + self.send_stream.sync() + nvshmem.core.quiet(stream=self.send_stream) + + # Step 4: Global barrier + torch.cuda.nvtx.range_push("Step 4: Barrier") + nvshmem.core.barrier_all(stream=self.send_stream) + self.send_stream.sync() + torch.cuda.nvtx.range_pop() + + # Step 5: Wait for async pack/unpack to complete + torch.cuda.nvtx.range_push("Step 5: Wait Async") + if has_prior_recv: + self.unpack_events[(i - 1) % 2].synchronize() + if has_next_send: + self.pack_events[(i + 1) % 2].synchronize() + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_pop() + + # Final unpack for last iteration + if num_iterations > 0 and iter_schedules[num_iterations - 1]["recv"]: + torch.cuda.nvtx.range_push("Final Unpack") + PELogger.debug(f"Final unpack: iteration {num_iterations-1}") + last_recv = iter_schedules[num_iterations - 1]["recv"] + assert last_recv is not None + self._launch_unpack(num_iterations - 1, last_recv) + self.unpack_events[(num_iterations - 1) % 2].synchronize() + torch.cuda.nvtx.range_pop() + + PELogger.info(f"Pipeline complete: {num_iterations} iterations") + + def _launch_pack(self, iteration: int, batch: ScheduledBatch) -> None: + """Launch pack kernel for given iteration.""" + if not batch.gpu_plan: + return + + self.kernel_launcher.launch_pack( + batch.gpu_plan, + self.pack_stream, + self.torch_pack_stream, + self.pack_events[iteration % 2], + ) + + def _launch_unpack(self, iteration: int, batch: ScheduledBatch) -> None: + """Launch unpack kernel for given iteration.""" + if not batch.gpu_plan: + return + + self.kernel_launcher.launch_unpack( + batch.gpu_plan, + self.unpack_stream, + self.torch_unpack_stream, + self.unpack_events[iteration % 2], + ) + + def process_self_moves( + self, + send_requests: List[SendRequest], + receive_requests: List[ReceiveRequest], + ) -> None: + """ + Handle same-PE transfers (where src_pe == dest_pe == my_pe). + + Uses PyTorch copy on the copy stream for efficiency. + + Args: + send_requests: List of send requests + receive_requests: List of receive requests + """ + # Match send/recv requests where src_pe == dest_pe == my_pe + local_sends = { + r.task_id: r for r in send_requests if r.dest_pe == self.my_pe + } + local_recvs = [ + r for r in receive_requests if r.src_pe == self.my_pe + ] + + if local_recvs: + PELogger.debug(f"Processing {len(local_recvs)} self-moves") + + num_processed = 0 + with torch.cuda.stream(self.torch_copy_stream): + for recv_req in local_recvs: + if recv_req.task_id in local_sends: + send_req = local_sends[recv_req.task_id] + PELogger.debug( + " Self-move: task_id=%d, size=%d bytes", + recv_req.task_id, + send_req.size, + ) + + # Create views of the tensors with offsets + src_view = send_req.src_tensor[ + send_req.src_pos : send_req.src_pos + send_req.size + ] + dest_view = recv_req.dest_tensor[ + recv_req.dest_pos : recv_req.dest_pos + send_req.size + ] + + # Async copy on the copy stream + dest_view.copy_(src_view, non_blocking=True) + num_processed += 1 + + # Synchronize the PyTorch stream + self.torch_copy_stream.synchronize() + + if num_processed > 0: + PELogger.info("Self-moves complete: %d transfers", num_processed) + + diff --git a/megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu b/megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu new file mode 100644 index 00000000000..e5b8fcc9a85 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu @@ -0,0 +1,103 @@ + +#include + +// CUDA-compatible types (no C++ standard library headers for NVRTC) +typedef unsigned char uint8_t; +typedef unsigned long long uint64_t; +typedef uint64_t uintptr_t; + +// ============================================================================ +// Kernel Configuration Constants (from ChunkedKernel.h) +// ============================================================================ + +constexpr int CHUNK_SIZE = 128 * 1024; // 128KB per chunk +constexpr int NUM_BLOCKS = 75; // Fixed grid size +constexpr int THREADS_PER_BLOCK = 1024; // Fixed block size +constexpr int FLOAT4_SIZE = 16; // 16 bytes per float4 +constexpr int MAX_CHUNKS_PER_BLOCK = 512; // Max chunks per block for shared memory + +extern "C" { + +/** + * Chunked batched copy kernel implementation + * + * This kernel performs efficient batched memory copies using: + * 1. Contiguous block assignment for better load balancing + * 2. Shared memory prefetching of chunk metadata + * 3. Vectorized float4 (16-byte) copies for aligned data + * 4. Byte-by-byte fallback for unaligned or small data + */ +__global__ void chunked_batched_copy_kernel( + uint8_t** src_addrs, + uint8_t** dst_addrs, + size_t* sizes, + int total_chunks +) { + // Shared memory for metadata prefetching + __shared__ uint8_t* s_src_addrs[MAX_CHUNKS_PER_BLOCK]; + __shared__ uint8_t* s_dst_addrs[MAX_CHUNKS_PER_BLOCK]; + __shared__ size_t s_sizes[MAX_CHUNKS_PER_BLOCK]; + + // Contiguous block assignment: block i processes chunks [start_chunk, end_chunk) + int chunks_per_block = (total_chunks + gridDim.x - 1) / gridDim.x; // Ceiling division + int start_chunk = blockIdx.x * chunks_per_block; + int end_chunk = start_chunk + chunks_per_block; + if (end_chunk > total_chunks) { + end_chunk = total_chunks; + } + int num_chunks_this_block = end_chunk - start_chunk; + + // Phase 1: Cooperative loading of metadata to shared memory + // All 1024 threads cooperate to load metadata from global memory + for (int i = threadIdx.x; i < num_chunks_this_block; i += blockDim.x) { + int global_chunk_id = start_chunk + i; + s_src_addrs[i] = src_addrs[global_chunk_id]; + s_dst_addrs[i] = dst_addrs[global_chunk_id]; + s_sizes[i] = sizes[global_chunk_id]; + } + __syncthreads(); + + // Phase 2: Process each chunk using metadata from shared memory + for (int chunk_id = 0; chunk_id < num_chunks_this_block; chunk_id++) { + uint8_t* src = s_src_addrs[chunk_id]; + uint8_t* dst = s_dst_addrs[chunk_id]; + size_t size = s_sizes[chunk_id]; + + // Check if both src and dst are aligned to 16 bytes for float4 access + uintptr_t src_addr = (uintptr_t)src; + uintptr_t dst_addr = (uintptr_t)dst; + bool is_aligned = ((src_addr % FLOAT4_SIZE) == 0) && ((dst_addr % FLOAT4_SIZE) == 0); + + if (is_aligned && size >= FLOAT4_SIZE) { + // Fast path: vectorized float4 copies + size_t aligned_size = (size / FLOAT4_SIZE) * FLOAT4_SIZE; + + // All 1024 threads cooperate on float4 copies + #pragma unroll 4 + for (size_t offset = threadIdx.x * FLOAT4_SIZE; + offset < aligned_size; + offset += blockDim.x * FLOAT4_SIZE) { + // Vectorized 16-byte load and store + float4 data = *((float4*)(src + offset)); + *((float4*)(dst + offset)) = data; + } + + // Handle remaining bytes (< 16 bytes) with byte-by-byte copy + for (size_t offset = aligned_size + threadIdx.x; + offset < size; + offset += blockDim.x) { + dst[offset] = src[offset]; + } + } else { + // Fallback path: byte-by-byte copy for unaligned addresses + // Still use all threads for parallelism + for (size_t offset = threadIdx.x; offset < size; offset += blockDim.x) { + dst[offset] = src[offset]; + } + } + } +} + +} + + diff --git a/megatron/core/resharding/nvshmem_copy_service/logger.py b/megatron/core/resharding/nvshmem_copy_service/logger.py new file mode 100644 index 00000000000..d4516c5761f --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/logger.py @@ -0,0 +1,199 @@ +""" +Per-PE Logger with colored console and file output. + +Copied in-tree from the standalone nvshmem_copy_service implementation. +""" + +import logging +import os +from datetime import datetime +from typing import Optional + + +class ColoredFormatter(logging.Formatter): + """Custom formatter that adds color codes for console output.""" + + def __init__(self, fmt: str, pe_id: int, use_color: bool = True): + super().__init__(fmt) + self.pe_id = pe_id + self.use_color = use_color + + # ANSI color codes matching C++ implementation + self.colors = { + 0: "\033[31m", # Red + 1: "\033[32m", # Green + 2: "\033[33m", # Yellow + 3: "\033[34m", # Blue + 4: "\033[35m", # Magenta + 5: "\033[36m", # Cyan + 6: "\033[91m", # Bright Red + 7: "\033[92m", # Bright Green + } + self.reset = "\033[0m" + + def formatTime(self, record, datefmt=None): + ct = self.converter(record.created) + if datefmt: + s = datetime.fromtimestamp(record.created).strftime(datefmt) + # For file logs, replace %f with milliseconds + if "%f" in datefmt: + s = s.replace("%f", f"{int(record.msecs):03d}") + else: + s = datetime.fromtimestamp(record.created).strftime("%H:%M:%S") + s = f"{s}.{int(record.msecs):03d}" + return s + + def format(self, record): + # Save original message + original_msg = record.msg + + if self.use_color and self.pe_id >= 0: + color = self.colors.get(self.pe_id, "\033[37m") # White for others + record.msg = f"{color}{record.msg}{self.reset}" + + result = super().format(record) + + # Restore original message for other handlers + record.msg = original_msg + + return result + + +class PELogger: + """Per-PE logger with colored console and file output.""" + + _logger: Optional[logging.Logger] = None + _pe_id: int = -1 + _level: int = logging.INFO + + @classmethod + def init(cls, pe_id: int, level: str = "INFO", logs_dir: str = "logs"): + """ + Initialize logger for this PE. + + Args: + pe_id: Process element ID + level: Log level (TRACE, DEBUG, INFO, WARN, ERROR) + logs_dir: Directory for log files + """ + cls._pe_id = pe_id + + # Convert level string to logging level + level_map = { + "TRACE": logging.DEBUG - 5, # Custom level below DEBUG + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "SUMMARY": logging.INFO, + "WARN": logging.WARNING, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + cls._level = level_map.get(level.upper(), logging.INFO) + + # Create logs directory if it doesn't exist + os.makedirs(logs_dir, exist_ok=True) + + # Create logger + logger_name = f"PE_{pe_id}" + cls._logger = logging.getLogger(logger_name) + cls._logger.setLevel(cls._level) + cls._logger.propagate = False + + # Remove existing handlers to avoid duplicates + cls._logger.handlers.clear() + + # 1. Console handler with color + console_handler = logging.StreamHandler() + console_handler.setLevel(cls._level) + console_format = "[PE %d] [%%(asctime)s] [%%(levelname)s] %%(message)s" % pe_id + console_formatter = ColoredFormatter(console_format, pe_id, use_color=True) + console_handler.setFormatter(console_formatter) + cls._logger.addHandler(console_handler) + + # 2. File handler without color + log_filename = os.path.join(logs_dir, f"pe_{pe_id}.log") + file_handler = logging.FileHandler(log_filename, mode="w") + file_handler.setLevel(cls._level) + file_format = "[PE %d] [%%(asctime)s] [%%(levelname)s] %%(message)s" % pe_id + file_formatter = ColoredFormatter(file_format, pe_id, use_color=False) + file_handler.setFormatter(file_formatter) + cls._logger.addHandler(file_handler) + + @classmethod + def set_level(cls, level: str): + """Set the logging level.""" + level_map = { + "TRACE": logging.DEBUG - 5, + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "SUMMARY": logging.INFO, + "WARN": logging.WARNING, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + cls._level = level_map.get(level.upper(), logging.INFO) + if cls._logger: + cls._logger.setLevel(cls._level) + for handler in cls._logger.handlers: + handler.setLevel(cls._level) + + @classmethod + def trace(cls, msg: str): + """Log at TRACE level (most detailed).""" + if cls._logger: + cls._logger.log(logging.DEBUG - 5, msg) + + @classmethod + def debug(cls, msg: str): + """Log at DEBUG level.""" + if cls._logger: + cls._logger.debug(msg) + + @classmethod + def info(cls, msg: str): + """Log at INFO level.""" + if cls._logger: + cls._logger.info(msg) + + @classmethod + def summary(cls, msg: str): + """Log summary information (INFO level with [SUMMARY] prefix).""" + if cls._logger: + cls._logger.info(f"[SUMMARY] {msg}") + + @classmethod + def warn(cls, msg: str): + """Log at WARNING level.""" + if cls._logger: + cls._logger.warning(msg) + + @classmethod + def warning(cls, msg: str): + """Log at WARNING level (alias for warn).""" + cls.warn(msg) + + @classmethod + def error(cls, msg: str): + """Log at ERROR level.""" + if cls._logger: + cls._logger.error(msg) + + @classmethod + def critical(cls, msg: str): + """Log at CRITICAL level.""" + if cls._logger: + cls._logger.critical(msg) + + @classmethod + def shutdown(cls): + """Shutdown the logger and flush all handlers.""" + if cls._logger: + for handler in cls._logger.handlers: + handler.flush() + handler.close() + cls._logger.handlers.clear() + cls._logger = None + + diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py new file mode 100644 index 00000000000..5c9f8b573f4 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py @@ -0,0 +1,8 @@ +"""Memory management utilities for NVSHMEM operations.""" + +from .double_buffer_manager import DoubleBufferManager +from .tensor_pointer_utils import TensorPointerExtractor + +__all__ = ["DoubleBufferManager", "TensorPointerExtractor"] + + diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py new file mode 100644 index 00000000000..314db889385 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py @@ -0,0 +1,76 @@ +""" +Double buffer management for NVSHMEM symmetric memory. + +Manages send and receive buffers with double-buffering for pipelined communication. +""" + +import nvshmem.core.interop.torch +import torch + +from ..nvshmem_types import MAX_SEGMENT_SIZE + + +class DoubleBufferManager: + """Manages double-buffered NVSHMEM symmetric buffers for send/receive operations.""" + + def __init__(self, slot_size: int = MAX_SEGMENT_SIZE): + """ + Initialize buffer manager. + + Args: + slot_size: Size of each buffer slot in bytes (default: 256MB) + """ + self.slot_size = slot_size + self.send_slots = [None, None] + self.recv_slots = [None, None] + + def allocate(self) -> None: + """Allocate NVSHMEM symmetric buffers for double-buffering.""" + for i in range(2): + self.send_slots[i] = nvshmem.core.interop.torch.bytetensor( + (self.slot_size,), + dtype=torch.uint8, + ) + self.recv_slots[i] = nvshmem.core.interop.torch.bytetensor( + (self.slot_size,), + dtype=torch.uint8, + ) + # Zero out buffers + self.send_slots[i].zero_() + self.recv_slots[i].zero_() + + def get_send_slot(self, iteration: int): + """ + Get send buffer for given iteration. + + Args: + iteration: Iteration number + + Returns: + NVSHMEM tensor for sending + """ + return self.send_slots[iteration % 2] + + def get_recv_slot(self, iteration: int): + """ + Get receive buffer for given iteration. + + Args: + iteration: Iteration number + + Returns: + NVSHMEM tensor for receiving + """ + return self.recv_slots[iteration % 2] + + def free(self) -> None: + """Free NVSHMEM symmetric buffers.""" + for i in range(2): + if self.send_slots[i] is not None: + nvshmem.core.interop.torch.free_tensor(self.send_slots[i]) + self.send_slots[i] = None + if self.recv_slots[i] is not None: + nvshmem.core.interop.torch.free_tensor(self.recv_slots[i]) + self.recv_slots[i] = None + + diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py new file mode 100644 index 00000000000..60e651aa998 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py @@ -0,0 +1,33 @@ +""" +Utilities for extracting data pointers from different tensor types. + +Supports PyTorch tensors, CuPy arrays, and raw integer pointers. +""" + +from typing import Any + +import torch + + +class TensorPointerExtractor: + """Extract memory pointers from various tensor types.""" + + @staticmethod + def get_pointer(tensor: Any) -> int: + """ + Extract the data pointer from a tensor. + + Args: + tensor: Can be torch.Tensor, CuPy array, or raw int pointer + + Returns: + int: Memory address of the tensor data + """ + if isinstance(tensor, torch.Tensor): + return tensor.data_ptr() + elif hasattr(tensor, "data"): # CuPy array + return tensor.data.ptr + else: # Assume raw integer pointer + return int(tensor) + + diff --git a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py new file mode 100644 index 00000000000..e83dbc51d60 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass, field +from typing import List, Any + +# Constants +MAX_SEGMENT_SIZE = 256 * 1024 * 1024 # 256MB +MAX_TASKS_PER_BATCH = 10000 + + +@dataclass +class SendRequest: + task_id: int + src_tensor: Any # cupy.ndarray or pointer + src_pos: int + size: int + dest_pe: int + + +@dataclass +class ReceiveRequest: + task_id: int + dest_tensor: Any # cupy.ndarray or pointer + dest_pos: int + size: int + src_pe: int + + +@dataclass +class WorkloadGroup: + dest_pe: int + tasks: List[SendRequest] = field(default_factory=list) + total_size: int = 0 + + +@dataclass +class ScheduledBatch: + src_pe: int + dest_pe: int + batch_index: int + iteration: int + # Metadata for GPU execution + gpu_plan: Any = None # Placeholder for GPU-resident plan + tasks: List[SendRequest] = field(default_factory=list) + total_size: int = 0 + tasks_summary: Any = None # WorkloadSummary + + +@dataclass +class WorkloadSummary: + total_size: int + task_ids: List[int] + task_sizes: List[int] + + +@dataclass +class TransferMetadata: + ptrs: Any # cupy array of uint64 (pointers) + sizes: Any # cupy array of uint64 (sizes) + num_tasks: int + total_size: int + + diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py new file mode 100644 index 00000000000..d00914b6ef0 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py @@ -0,0 +1,15 @@ +"""Planning components for task segmentation, workload packing, and scheduling.""" + +from .communication_scheduler import CommunicationScheduler +from .gpu_execution_planner import GPUExecutionPlanner +from .task_segmenter import TaskSegmenter +from .workload_packer import WorkloadPacker + +__all__ = [ + "CommunicationScheduler", + "GPUExecutionPlanner", + "TaskSegmenter", + "WorkloadPacker", +] + + diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py new file mode 100644 index 00000000000..6399cc2a393 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -0,0 +1,214 @@ +from typing import Dict, List, Tuple + +import torch + +from ..logger import PELogger +from ..nvshmem_types import WorkloadGroup, ScheduledBatch, WorkloadSummary + + +class CommunicationScheduler: + """ + Builds a conflict-free, iteration-based schedule for communication. + Ensures that in any given iteration, a PE is not overloaded. + """ + + def __init__(self): + self.num_iterations = 0 + + def build_schedule( + self, + workloads: Dict[int, List[WorkloadGroup]], + my_pe: int, + n_pes: int, + group=None, + ) -> Tuple[Dict[int, List[ScheduledBatch]], Dict[Tuple[int, int, int], WorkloadSummary]]: + """ + Main scheduling method. + 1. Exchanges workload info with other PEs. + 2. Assigns batches to iterations. + 3. Returns: + - local schedule (iteration -> list of batches) + - global workload summaries (key: (src, dest, batch_idx) -> summary) + """ + total_local_batches = sum(len(groups) for groups in workloads.values()) + PELogger.info( + f"Building schedule: {total_local_batches} local batches, {n_pes} PEs" + ) + + # Step 1: Collect all batches across all PE pairs + PELogger.debug("Collecting batches from all PEs...") + all_batches = self._collect_all_batches(workloads, my_pe, n_pes, group) + PELogger.debug(f"Collected {len(all_batches)} total batches globally") + + # Step 2: Assign batches to iterations using conflict-free algorithm + PELogger.debug("Assigning batches to iterations...") + self._assign_iterations(all_batches) + PELogger.info(f"Schedule built: {self.num_iterations} iterations") + + # Step 3: Exchange detailed workload summaries (Task IDs/Sizes) + # This is needed for receivers to know what tasks are in each batch + PELogger.debug("Exchanging workload summaries...") + global_summaries = self._exchange_workload_summaries( + workloads, + my_pe, + n_pes, + group, + ) + PELogger.debug(f"Exchanged {len(global_summaries)} workload summaries") + + # Step 4: Build schedule map for this PE + my_batches = [ + b for b in all_batches if b.src_pe == my_pe or b.dest_pe == my_pe + ] + my_batches.sort(key=lambda x: x.iteration) + + final_schedule: Dict[int, List[ScheduledBatch]] = {} + for b in my_batches: + final_schedule.setdefault(b.iteration, []).append(b) + + return final_schedule, global_summaries + + def _collect_all_batches( + self, + workloads: Dict[int, List[WorkloadGroup]], + my_pe: int, + n_pes: int, + group=None, + ) -> List[ScheduledBatch]: + """ + Exchanges batch counts and details with all PEs to build a global view. + Uses torch.distributed for reliable communication. + """ + import torch.distributed as dist + + # Build local batch list + local_batches: List[Tuple[int, int, int]] = [] + for dest_pe, groups in workloads.items(): + if dest_pe == my_pe: + continue + for i, _ in enumerate(groups): + local_batches.append((my_pe, dest_pe, i)) # (src, dest, batch_idx) + + PELogger.debug(f" Local batch count: {len(local_batches)}") + PELogger.debug(f" Local batches: {local_batches}") + + # Gather all batches from all PEs using torch.distributed + all_batches_list: List[List[Tuple[int, int, int]] | None] = [None] * n_pes + dist.all_gather_object(all_batches_list, local_batches, group=group) + + # Flatten into global batch list + global_batches: List[ScheduledBatch] = [] + for pe_batches in all_batches_list: + if pe_batches is None: + continue + for src, dest, idx in pe_batches: + global_batches.append( + ScheduledBatch( + src_pe=src, + dest_pe=dest, + batch_index=idx, + iteration=-1, + ) + ) + + PELogger.debug(f" Global batches collected: {len(global_batches)} total") + + # Group by source for readability + batches_by_src: Dict[int, List[Tuple[int, int]]] = {} + for b in global_batches: + batches_by_src.setdefault(b.src_pe, []).append((b.dest_pe, b.batch_index)) + for src_pe in sorted(batches_by_src.keys()): + PELogger.debug(f" PE {src_pe} sends to: {batches_by_src[src_pe]}") + + return global_batches + + def _assign_iterations(self, batches: List[ScheduledBatch]): + self.num_iterations = 0 + batches.sort(key=lambda x: (x.src_pe, x.dest_pe, x.batch_index)) + + for batch in batches: + iteration = 0 + assigned = False + while not assigned: + if not self._has_conflict(batch, iteration, batches): + batch.iteration = iteration + self.num_iterations = max(self.num_iterations, iteration + 1) + assigned = True + PELogger.debug( + f" Assigned batch ({batch.src_pe} → {batch.dest_pe}, " + f"idx={batch.batch_index}) to iteration {iteration}" + ) + else: + iteration += 1 + + def _has_conflict( + self, + batch: ScheduledBatch, + iteration: int, + all_batches: List[ScheduledBatch], + ) -> bool: + for other in all_batches: + if other.iteration == iteration and other is not batch: + if other.src_pe == batch.src_pe or other.dest_pe == batch.dest_pe: + return True + return False + + def _exchange_workload_summaries( + self, + workloads: Dict[int, List[WorkloadGroup]], + my_pe: int, + n_pes: int, + group=None, + ) -> Dict[Tuple[int, int, int], WorkloadSummary]: + """ + Exchange detailed workload content using torch.distributed. + Simple and reliable - no NVSHMEM symmetric memory issues. + """ + import torch.distributed as dist + + # Build local summaries as a simple dict: + # (src, dest, batch_idx) -> {total_size, task_ids, task_sizes} + local_summaries: Dict[Tuple[int, int, int], Dict[str, object]] = {} + batch_count = 0 + total_tasks = 0 + + for dest_pe, groups in workloads.items(): + if dest_pe == my_pe: + continue + for batch_idx, group in enumerate(groups): + key = (my_pe, dest_pe, batch_idx) + local_summaries[key] = { + "total_size": group.total_size, + "task_ids": [t.task_id for t in group.tasks], + "task_sizes": [t.size for t in group.tasks], + } + batch_count += 1 + total_tasks += len(group.tasks) + + PELogger.debug( + f" Local summaries: {batch_count} batches, {total_tasks} tasks" + ) + + # Gather all summaries from all PEs using torch.distributed + all_summaries_list: List[Dict[Tuple[int, int, int], Dict[str, object]] | None] = [ # noqa: E501 + None + ] * n_pes + dist.all_gather_object(all_summaries_list, local_summaries, group=group) + + # Merge into global map + global_map: Dict[Tuple[int, int, int], WorkloadSummary] = {} + for pe_summaries in all_summaries_list: + if pe_summaries is None: + continue + for key, data in pe_summaries.items(): + summary = WorkloadSummary( + total_size=int(data["total_size"]), + task_ids=list(data["task_ids"]), # type: ignore[arg-type] + task_sizes=list(data["task_sizes"]), # type: ignore[arg-type] + ) + global_map[key] = summary + + PELogger.debug(f" Exchanged {len(global_map)} workload summaries") + return global_map + + diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py new file mode 100644 index 00000000000..22428f320e3 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -0,0 +1,248 @@ +""" +GPU execution planning for pack/unpack operations. + +Converts high-level task descriptions into GPU-ready metadata +(pointer arrays, sizes, chunking) for kernel execution. +""" + +from typing import List, Dict, Tuple, Any, Optional + +import cupy as cp +import torch + +from ..logger import PELogger +from ..memory.tensor_pointer_utils import TensorPointerExtractor +from ..nvshmem_types import ( + SendRequest, + ReceiveRequest, + ScheduledBatch, + WorkloadGroup, + WorkloadSummary, +) + + +class GPUExecutionPlanner: + """Plans GPU kernel execution by building pointer arrays and metadata.""" + + def __init__(self): + self.tensor_utils = TensorPointerExtractor() + self.CHUNK_SIZE = 128 * 1024 # 128KB chunks + + def create_gpu_plans( + self, + iter_schedules: List[Dict[str, Optional[ScheduledBatch]]], + send_slots: List, + recv_slots: List, + receive_requests: List[ReceiveRequest], + ) -> None: + """ + Build GPU execution plans for all iterations. + + Modifies iter_schedules in-place by adding gpu_plan to each batch. + + Args: + iter_schedules: List of iteration schedules (dicts with 'send' and 'recv') + send_slots: List of send buffer slots + recv_slots: List of receive buffer slots + receive_requests: List of all receive requests for matching + """ + PELogger.debug("Creating GPU plans for %d iterations", len(iter_schedules)) + for i, sched in enumerate(iter_schedules): + send_batch = sched["send"] + if send_batch: + # Build Pack Metadata + ptrs: List[int] = [] + positions: List[int] = [] + sizes: List[int] = [] + + for t in send_batch.tasks: + # Extract pointer from tensor + ptr = self.tensor_utils.get_pointer(t.src_tensor) + ptrs.append(ptr) + positions.append(t.src_pos) + sizes.append(t.size) + + # Plan kernel args for packing + send_batch.gpu_plan = self._plan_kernel_args( + ptrs, + positions, + sizes, + is_pack=True, + buffer_base=send_slots[i % 2].data_ptr(), + ) + task_ids = [t.task_id for t in send_batch.tasks] + PELogger.debug( + " Iter %d send plan: %d tasks → PE %d, %d bytes", + i, + len(send_batch.tasks), + send_batch.dest_pe, + send_batch.total_size, + ) + PELogger.debug( + " Send task IDs: %s", + task_ids[:10] if len(task_ids) <= 10 else task_ids[:10] + ["..."], + ) + + recv_batch = sched["recv"] + if recv_batch: + # Build Unpack Metadata + summary = recv_batch.tasks_summary + + # Skip if no summary available (shouldn't happen in normal operation) + if summary is None: + PELogger.error( + "Iter %d: recv batch from PE %d has no tasks_summary - " + "UNPACK WILL BE SKIPPED!", + i, + recv_batch.src_pe, + ) + recv_batch.gpu_plan = None + continue + + PELogger.debug( + " Iter %d recv from PE %d: %d tasks in summary", + i, + recv_batch.src_pe, + len(summary.task_ids), + ) + + ptrs = [] + positions = [] + sizes = [] + + # Create fast lookup map for receive requests + relevant_reqs: Dict[int, ReceiveRequest] = { + r.task_id: r + for r in receive_requests + if r.src_pe == recv_batch.src_pe + } + + # Match summary tasks with receive requests + matched_task_ids: List[int] = [] + unmatched_task_ids: List[int] = [] + for t_id, t_size in zip(summary.task_ids, summary.task_sizes): + if t_id in relevant_reqs: + req = relevant_reqs[t_id] + ptr = self.tensor_utils.get_pointer(req.dest_tensor) + ptrs.append(ptr) + positions.append(req.dest_pos) + sizes.append(t_size) # Use sender's size + matched_task_ids.append(t_id) + else: + unmatched_task_ids.append(t_id) + PELogger.error( + "Iter %d: Unexpected task %d from PE %d - " + "no matching recv request!", + i, + t_id, + recv_batch.src_pe, + ) + + if unmatched_task_ids: + PELogger.error( + " Iter %d: %d unmatched tasks from PE %d: %s", + i, + len(unmatched_task_ids), + recv_batch.src_pe, + unmatched_task_ids[:10], + ) + + # Plan kernel args for unpacking + recv_batch.gpu_plan = self._plan_kernel_args( + ptrs, + positions, + sizes, + is_pack=False, + buffer_base=recv_slots[i % 2].data_ptr(), + ) + + if recv_batch.gpu_plan is None: + PELogger.error( + " Iter %d recv plan: FAILED - no gpu_plan created for %d " + "tasks from PE %d", + i, + len(sizes), + recv_batch.src_pe, + ) + else: + PELogger.debug( + " Iter %d recv plan: %d tasks ← PE %d, %d bytes", + i, + len(sizes), + recv_batch.src_pe, + recv_batch.total_size, + ) + PELogger.debug( + " Recv task IDs: %s", + matched_task_ids[:10] + if len(matched_task_ids) <= 10 + else matched_task_ids[:10] + ["..."], + ) + + def _plan_kernel_args( + self, + ptrs: List[int], + positions: List[int], + sizes: List[int], + is_pack: bool, + buffer_base: int, + ) -> Optional[Tuple[cp.ndarray, cp.ndarray, cp.ndarray, int]]: + """ + Generate GPU-ready pointer arrays for kernel execution. + + Applies 128KB chunking to break large transfers into smaller pieces. + + Args: + ptrs: List of tensor data pointers + positions: List of positions within tensors + sizes: List of transfer sizes + is_pack: True for pack (user->buffer), False for unpack (buffer->user) + buffer_base: Base pointer of the buffer + + Returns: + Tuple of (cp_src_addrs, cp_dst_addrs, cp_sizes, num_chunks) as + CuPy arrays, or None if no work. + """ + h_src_addrs: List[int] = [] + h_dst_addrs: List[int] = [] + h_sizes: List[int] = [] + + packed_offset = 0 + + for ptr, pos, size in zip(ptrs, positions, sizes): + num_chunks = (size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE + + for c in range(num_chunks): + chunk_offset = c * self.CHUNK_SIZE + chunk_size = min(self.CHUNK_SIZE, size - chunk_offset) + + if is_pack: + # Pack: user tensor -> buffer + h_src_addrs.append(ptr + pos + chunk_offset) + h_dst_addrs.append(buffer_base + packed_offset + chunk_offset) + else: + # Unpack: buffer -> user tensor + h_src_addrs.append(buffer_base + packed_offset + chunk_offset) + h_dst_addrs.append(ptr + pos + chunk_offset) + + h_sizes.append(chunk_size) + + packed_offset += size + + total_chunks = len(h_sizes) + if total_chunks == 0: + return None + + # Move to GPU using PyTorch, then convert to CuPy for kernel launching + d_src_addrs = torch.tensor(h_src_addrs, dtype=torch.int64, device="cuda") + d_dst_addrs = torch.tensor(h_dst_addrs, dtype=torch.int64, device="cuda") + d_sizes = torch.tensor(h_sizes, dtype=torch.int64, device="cuda") + + # Convert to CuPy arrays (zero-copy) for kernel launching + cp_src_addrs = cp.asarray(d_src_addrs) + cp_dst_addrs = cp.asarray(d_dst_addrs) + cp_sizes = cp.asarray(d_sizes) + + return (cp_src_addrs, cp_dst_addrs, cp_sizes, total_chunks) + + diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py new file mode 100644 index 00000000000..0e98b8a7811 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py @@ -0,0 +1,97 @@ +from typing import List + +from ..nvshmem_types import SendRequest, ReceiveRequest, MAX_SEGMENT_SIZE + +# Constants for ID encoding (from C++ implementation) +REQUEST_ID_BASE = 1000000000 +SEGMENT_ID_MULTIPLIER = 1000 +MAX_REQUESTS = 1000000 +MAX_SEGMENTS_PER_REQUEST = 1000 + + +class TaskSegmenter: + """ + Splits large tasks (>256MB) into smaller segments to fit + into the fixed-size communication slots. + """ + + def _encode_segment_id(self, task_id: int, segment_index: int) -> int: + return REQUEST_ID_BASE + (task_id * SEGMENT_ID_MULTIPLIER) + segment_index + + def _calculate_num_segments(self, size: int) -> int: + return (size + MAX_SEGMENT_SIZE - 1) // MAX_SEGMENT_SIZE + + def _validate_segmentation(self, task_id: int, size: int) -> bool: + num_segments = self._calculate_num_segments(size) + if num_segments > MAX_SEGMENTS_PER_REQUEST: + print( + f"Error: Task {task_id} requires {num_segments} segments, " + f"exceeds max {MAX_SEGMENTS_PER_REQUEST}" + ) + return False + if task_id >= MAX_REQUESTS: + print(f"Error: Task ID {task_id} exceeds max {MAX_REQUESTS}") + return False + return True + + def segment_send_request(self, req: SendRequest) -> List[SendRequest]: + """ + Splits a single send request into multiple segments + if larger than MAX_SEGMENT_SIZE. + """ + if req.size <= MAX_SEGMENT_SIZE: + return [req] + + if not self._validate_segmentation(req.task_id, req.size): + raise ValueError(f"Task {req.task_id} validation failed") + + num_segments = self._calculate_num_segments(req.size) + output_requests: List[SendRequest] = [] + + for i in range(num_segments): + segment_offset = i * MAX_SEGMENT_SIZE + segment_size = min(MAX_SEGMENT_SIZE, req.size - segment_offset) + segment_task_id = self._encode_segment_id(req.task_id, i) + + new_req = SendRequest( + task_id=segment_task_id, + src_tensor=req.src_tensor, + src_pos=req.src_pos + segment_offset, + size=segment_size, + dest_pe=req.dest_pe, + ) + output_requests.append(new_req) + + return output_requests + + def segment_receive_request(self, req: ReceiveRequest) -> List[ReceiveRequest]: + """ + Splits a single receive request into multiple segments + if larger than MAX_SEGMENT_SIZE. + """ + if req.size <= MAX_SEGMENT_SIZE: + return [req] + + if not self._validate_segmentation(req.task_id, req.size): + raise ValueError(f"Task {req.task_id} validation failed") + + num_segments = self._calculate_num_segments(req.size) + output_requests: List[ReceiveRequest] = [] + + for i in range(num_segments): + segment_offset = i * MAX_SEGMENT_SIZE + segment_size = min(MAX_SEGMENT_SIZE, req.size - segment_offset) + segment_task_id = self._encode_segment_id(req.task_id, i) + + new_req = ReceiveRequest( + task_id=segment_task_id, + dest_tensor=req.dest_tensor, + dest_pos=req.dest_pos + segment_offset, + size=segment_size, + src_pe=req.src_pe, + ) + output_requests.append(new_req) + + return output_requests + + diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py new file mode 100644 index 00000000000..d6643220498 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py @@ -0,0 +1,107 @@ +from typing import List, Dict + +from ..logger import PELogger +from ..nvshmem_types import SendRequest, WorkloadGroup, MAX_SEGMENT_SIZE, MAX_TASKS_PER_BATCH + + +class WorkloadPacker: + """ + Packs individual SendRequests into WorkloadGroups (batches) + destined for the same PE, respecting size limits. + """ + + def pack_workloads( + self, + send_requests: List[SendRequest], + n_pes: int, + ) -> Dict[int, List[WorkloadGroup]]: + """ + Groups requests by destination PE and packs them into batches. + Returns a map: dest_pe -> list of batches + """ + PELogger.debug(f"Packing {len(send_requests)} send requests for {n_pes} PEs") + workloads: Dict[int, List[WorkloadGroup]] = {} + + # Group requests by destination PE + tasks_by_dest: Dict[int, List[SendRequest]] = {} + for req in send_requests: + tasks_by_dest.setdefault(req.dest_pe, []).append(req) + + # Pack tasks for each destination + for dest_pe in range(n_pes): + if dest_pe not in tasks_by_dest: + workloads[dest_pe] = [] + PELogger.debug(f" Dest PE {dest_pe}: 0 tasks → 0 batches") + continue + + tasks = tasks_by_dest[dest_pe] + workloads[dest_pe] = self._pack_single_destination(tasks, dest_pe) + + if workloads[dest_pe]: + total_size = sum(b.total_size for b in workloads[dest_pe]) + PELogger.debug( + " Dest PE %d: %d tasks → %d batches, %d bytes total", + dest_pe, + len(tasks), + len(workloads[dest_pe]), + total_size, + ) + else: + PELogger.debug( + " Dest PE %d: %d tasks → 0 batches (empty after packing)", + dest_pe, + len(tasks), + ) + + return workloads + + def _pack_single_destination( + self, + tasks: List[SendRequest], + dest_pe: int, + ) -> List[WorkloadGroup]: + if not tasks: + return [] + + # Sort tasks by size (descending) for better bin packing efficiency + tasks.sort(key=lambda x: x.size, reverse=True) + + batches: List[WorkloadGroup] = [] + current_batch = WorkloadGroup(dest_pe=dest_pe, tasks=[], total_size=0) + + for task in tasks: + # Check if adding this task would exceed batch constraints + would_exceed_size = current_batch.total_size + task.size > MAX_SEGMENT_SIZE + would_exceed_task_cap = len(current_batch.tasks) >= MAX_TASKS_PER_BATCH + + if (would_exceed_size or would_exceed_task_cap) and current_batch.tasks: + # Finalize current batch + batches.append(current_batch) + task_first_10_string = ", ".join( + [str(t.task_id) for t in current_batch.tasks[:10]] + ) + PELogger.debug( + " Packed batch to PE %d idx %d: %s... (total %d tasks)", + dest_pe, + len(batches) - 1, + task_first_10_string, + len(current_batch.tasks), + ) + # Start new batch + current_batch = WorkloadGroup( + dest_pe=dest_pe, + tasks=[], + total_size=0, + ) + + # Add task to current batch + current_batch.tasks.append(task) + current_batch.total_size += task.size + + # Add final batch if not empty + if current_batch.tasks: + batches.append(current_batch) + + return batches + + diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py new file mode 100644 index 00000000000..55e3b212d53 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -0,0 +1,441 @@ +""" +Remote Copy Service - Main orchestrator for NVSHMEM-based GPU-to-GPU transfers. + +This is an in-tree copy of the standalone Python implementation from +`kan/mcore-reshard/nvshmem_copy_service/python/service.py`, with imports +updated to use the Megatron package layout. +""" + +from typing import List, Dict, Tuple, Optional + +import nvshmem.core +import torch.cuda.nvtx as nvtx + +from .core import GPUResourceManager, KernelLauncher, PipelineExecutor +from .memory import DoubleBufferManager +from .nvshmem_types import ( + SendRequest, + ReceiveRequest, + ScheduledBatch, + WorkloadSummary, +) +from .planning import ( + TaskSegmenter, + WorkloadPacker, + CommunicationScheduler, + GPUExecutionPlanner, +) +from .logger import PELogger + + +class RemoteCopyService: + """ + Main service for managing remote GPU-to-GPU data transfers. + + Provides high-level API for registering transfers, scheduling, + and executing pipelined communication with NVSHMEM. + """ + + def __init__(self): + # Core components + self.gpu_resources = GPUResourceManager() + self.buffer_manager = DoubleBufferManager() + self.kernel_launcher = KernelLauncher() + self.pipeline_executor = None # Created after init + + # Planning components + self.task_segmenter = TaskSegmenter() + self.workload_packer = WorkloadPacker() + self.comm_scheduler = CommunicationScheduler() + self.gpu_planner = GPUExecutionPlanner() + + # State + self.send_requests: List[SendRequest] = [] + self.receive_requests: List[ReceiveRequest] = [] + self.iter_schedules: Optional[List[Dict]] = None + self.num_iterations: int = 0 + + # Events for double-buffering + self.pack_events = [] + self.unpack_events = [] + + @property + def my_pe(self) -> int: + """Get this PE's rank.""" + return self.gpu_resources.my_pe + + @property + def n_pes(self) -> int: + """Get total number of PEs.""" + return self.gpu_resources.n_pes + + @property + def device(self): + """Get CUDA device.""" + return self.gpu_resources.device + + @property + def initialized(self) -> bool: + """Check if service is initialized.""" + return self.gpu_resources.initialized + + def init(self, log_level: str = "INFO") -> None: + """ + Initialize the service. + + Sets up NVSHMEM, CUDA device, streams, buffers, and kernels. + Expects to be launched with torchrun. + + Args: + log_level: Logging level (TRACE, DEBUG, INFO, WARN, ERROR) + """ + # Initialize GPU resources (NVSHMEM, device, streams) + self.gpu_resources.init() + + # Initialize logger after PE ID is known + PELogger.init(self.my_pe, level=log_level) + PELogger.info(f"Initializing RemoteCopyService on PE {self.my_pe}/{self.n_pes}") + + # Allocate double-buffered send/recv slots + self.buffer_manager.allocate() + PELogger.debug("Allocated double-buffered send/recv slots") + + # Load CUDA kernels + self.kernel_launcher.load_kernels() + PELogger.debug("Loaded CUDA kernels") + + # Cache CuPy stream wrappers for efficient kernel launching + self.kernel_launcher.set_streams( + self.gpu_resources.pack_stream, + self.gpu_resources.unpack_stream, + ) + PELogger.debug("Cached CuPy stream wrappers") + + # Create pipeline executor with dependencies + self.pipeline_executor = PipelineExecutor( + self.kernel_launcher, + self.buffer_manager, + self.my_pe, + ) + + # Set streams on pipeline executor + self.pipeline_executor.set_streams( + self.gpu_resources.pack_stream, + self.gpu_resources.unpack_stream, + self.gpu_resources.send_stream, + self.gpu_resources.copy_stream, + self.gpu_resources.torch_pack_stream, + self.gpu_resources.torch_unpack_stream, + self.gpu_resources.torch_copy_stream, + ) + PELogger.info("Initialization complete") + + def register_send( + self, + task_id: int, + src_tensor, + src_pos: int, + size: int, + dest_pe: int, + ) -> None: + """ + Register a send operation. + + Args: + task_id: Unique task identifier + src_tensor: Source tensor (PyTorch/CuPy tensor or pointer) + src_pos: Starting position in source tensor + size: Number of bytes to send + dest_pe: Destination PE rank + """ + if dest_pe >= self.n_pes or dest_pe < 0: + print(f"Error: Invalid destination PE {dest_pe}") + return + + req = SendRequest(task_id, src_tensor, src_pos, size, dest_pe) + self.send_requests.append(req) + + def register_receive( + self, + task_id: int, + dest_tensor, + dest_pos: int, + size: int, + src_pe: int, + ) -> None: + """ + Register a receive operation. + + Args: + task_id: Unique task identifier + dest_tensor: Destination tensor (PyTorch/CuPy tensor or pointer) + dest_pos: Starting position in destination tensor + size: Number of bytes to receive + src_pe: Source PE rank + """ + if src_pe >= self.n_pes or src_pe < 0: + print(f"Error: Invalid source PE {src_pe}") + return + + req = ReceiveRequest(task_id, dest_tensor, dest_pos, size, src_pe) + self.receive_requests.append(req) + + def schedule(self) -> None: + """ + Build execution schedule. + + Can be called once and followed by multiple run() calls for + repeated execution with the same communication pattern. + + Steps: + 1. Segment large tasks into manageable chunks + 2. Pack tasks into batches + 3. Schedule batches to iterations (conflict-free) + 4. Build GPU execution plans (pointer arrays, chunking) + 5. Create synchronization events + """ + if not self.initialized: + raise RuntimeError("RemoteCopyService not initialized") + + PELogger.info( + f"Starting schedule: {len(self.send_requests)} send requests, " + f"{len(self.receive_requests)} receive requests" + ) + + # Step 1: Segment tasks (break large tasks into chunks) + PELogger.debug("Step 1: Segmenting tasks...") + orig_send_count = len(self.send_requests) + orig_recv_count = len(self.receive_requests) + self._segment_tasks() + PELogger.info( + f"Segmented: {orig_send_count} sends → {len(self.send_requests)} segments, " + f"{orig_recv_count} recvs → {len(self.receive_requests)} segments" + ) + + # Step 2: Pack tasks into workload groups + PELogger.debug("Step 2: Packing workloads...") + workloads = self.workload_packer.pack_workloads( + self.send_requests, + self.n_pes, + ) + total_batches = sum(len(batches) for batches in workloads.values()) + active_pes = sum(1 for batches in workloads.values() if batches) + PELogger.info( + f"Packed: {total_batches} batches across {active_pes} destination PEs" + ) + + # Step 3: Schedule workloads to iterations + PELogger.debug("Step 3: Building communication schedule...") + schedule, global_summaries = self.comm_scheduler.build_schedule( + workloads, + self.my_pe, + self.n_pes, + group=self.gpu_resources.pg, + ) + + self.num_iterations = self.comm_scheduler.num_iterations + PELogger.info( + f"Scheduled: {total_batches} batches → {self.num_iterations} iterations" + ) + + # Step 4: Prepare iteration schedules + PELogger.debug("Step 4: Preparing iteration schedules...") + self.iter_schedules = self._prepare_iter_schedules( + schedule, + workloads, + global_summaries, + self.num_iterations, + ) + + # Step 5: Build GPU execution plans + PELogger.debug("Step 5: Building GPU execution plans...") + self.gpu_planner.create_gpu_plans( + self.iter_schedules, + self.buffer_manager.send_slots, + self.buffer_manager.recv_slots, + self.receive_requests, + ) + + # Step 6: Create double-buffered events + PELogger.debug("Step 6: Creating synchronization events...") + self.pack_events, self.unpack_events = self.gpu_resources.create_events( + num_events=2 + ) + self.pipeline_executor.set_events(self.pack_events, self.unpack_events) + + PELogger.info( + f"Schedule complete: {self.num_iterations} iterations ready" + ) + + def run(self) -> None: + """ + Execute the scheduled communication. + + Can be called multiple times after a single schedule() call + to repeat the same communication pattern. + """ + + if not self.initialized: + raise RuntimeError("RemoteCopyService not initialized") + if self.iter_schedules is None: + raise RuntimeError("Must call schedule() before run()") + + PELogger.info(f"Starting execution: {self.num_iterations} iterations") + + # Start timing + nvtx.range_push("RemoteCopyService.run_total") + + # Global barrier before execution + PELogger.debug("Barrier: Synchronizing all PEs before execution") + nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) + self.gpu_resources.send_stream.sync() + + # Execute pipelined communication + nvtx.range_push("execute_pipeline") + self.pipeline_executor.execute_pipeline( + self.iter_schedules, + self.num_iterations, + ) + nvtx.range_pop() # execute_pipeline + + # Global barrier after execution + PELogger.debug("Barrier: Synchronizing all PEs after pipeline") + nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) + + # Process same-PE transfers + self.pipeline_executor.process_self_moves( + self.send_requests, + self.receive_requests, + ) + + # End timing range + nvtx.range_pop() # RemoteCopyService.run_total + + def clear_requests(self) -> None: + """ + Clear registered requests and schedule. + + Call this before registering a new set of transfers. + """ + self.send_requests = [] + self.receive_requests = [] + self.iter_schedules = None + self.num_iterations = 0 + self.pack_events = [] + self.unpack_events = [] + + def finalize(self) -> None: + """Cleanup resources.""" + PELogger.info("Finalizing RemoteCopyService") + + # Barrier to ensure all PEs are ready to finalize + try: + PELogger.debug("Barrier: Synchronizing all PEs before finalize") + nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) + self.gpu_resources.send_stream.sync() + except Exception as e: # pragma: no cover - defensive logging + PELogger.error(f"Error in final barrier: {e}") + + # Free buffers + self.buffer_manager.free() + + # Finalize GPU resources (this will call nvshmem.core.finalize internally) + self.gpu_resources.finalize() + + PELogger.info("RemoteCopyService finalized") + PELogger.shutdown() + + def _segment_tasks(self) -> None: + """Segment tasks into manageable chunks.""" + new_sends: List[SendRequest] = [] + for req in self.send_requests: + segments = self.task_segmenter.segment_send_request(req) + new_sends.extend(segments) + if len(segments) > 1: + PELogger.debug( + f" Segmented send task {req.task_id}: " + f"{req.size} bytes → {len(segments)} segments" + ) + self.send_requests = new_sends + + new_recvs: List[ReceiveRequest] = [] + for req in self.receive_requests: + segments = self.task_segmenter.segment_receive_request(req) + new_recvs.extend(segments) + if len(segments) > 1: + PELogger.debug( + f" Segmented recv task {req.task_id}: " + f"{req.size} bytes → {len(segments)} segments" + ) + self.receive_requests = new_recvs + + def _prepare_iter_schedules( + self, + schedule_batches: Dict[int, List[ScheduledBatch]], + workloads: Dict[int, List], + global_summaries: Dict[Tuple[int, int, int], WorkloadSummary], + num_iterations: int, + ) -> List[Dict]: + """ + Organize schedule into iteration-based structure. + + Returns: + List of dicts with 'send' and 'recv' keys for each iteration + """ + iter_schedules: List[Dict[str, Optional[ScheduledBatch]]] = [] + + for i in range(num_iterations): + sched: Dict[str, Optional[ScheduledBatch]] = {"send": None, "recv": None} + + if i in schedule_batches: + batches = schedule_batches[i] + + for b in batches: + # Skip same-PE transfers (handled separately by process_self_moves) + if b.src_pe == b.dest_pe: + PELogger.debug( + f" Iter {i}: Skipping same-PE batch " + f"({b.src_pe} → {b.dest_pe})" + ) + continue + + if b.src_pe == self.my_pe: + # This PE sends in this iteration + b.tasks = workloads[b.dest_pe][b.batch_index].tasks + b.total_size = workloads[b.dest_pe][b.batch_index].total_size + sched["send"] = b + PELogger.debug( + f" Iter {i}: Send to PE {b.dest_pe}, batch " + f"{b.batch_index}, {len(b.tasks)} tasks, " + f"{b.total_size} bytes" + ) + + elif b.dest_pe == self.my_pe: + # This PE receives in this iteration + key = (b.src_pe, b.dest_pe, b.batch_index) + if key in global_summaries: + summary = global_summaries[key] + b.tasks_summary = summary + b.total_size = summary.total_size + else: + PELogger.error( + f" Iter {i}: Missing workload summary for " + f"recv from PE {b.src_pe}, batch {b.batch_index}" + ) + PELogger.error( + " Available keys in global_summaries: " + f"{list(global_summaries.keys())}" + ) + b.tasks_summary = None + b.total_size = 0 + sched["recv"] = b + PELogger.debug( + f" Iter {i}: Recv from PE {b.src_pe}, batch " + f"{b.batch_index}, {b.total_size} bytes" + ) + + iter_schedules.append(sched) + + return iter_schedules + + diff --git a/megatron/core/resharding/nvshmem_copy_service/validation.py b/megatron/core/resharding/nvshmem_copy_service/validation.py new file mode 100644 index 00000000000..02d4ddd792e --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/validation.py @@ -0,0 +1,154 @@ +""" +Validation utilities for GPU-to-GPU communication. + +Copied in-tree from the standalone nvshmem_copy_service implementation. +""" + +from dataclasses import dataclass +from typing import List + +import torch + +from .logger import PELogger + + +@dataclass +class ValidationResult: + """Result of validating a single task.""" + + task_id: int + size: int + passed: bool + src_pe: int = -1 + mismatches: int = 0 + first_mismatch_idx: int = -1 + first_mismatch_expected: int = 0 + first_mismatch_actual: int = 0 + # Scheduling info - which batch/iteration this task was supposed to be handled + batch_index: int = -1 + iteration: int = -1 + + +@dataclass +class ValidationSummary: + """Summary of validation across all tasks.""" + + total_tasks: int + passed_tasks: int + failed_tasks: int + total_bytes: int + results: List[ValidationResult] + + @property + def all_passed(self) -> bool: + return self.failed_tasks == 0 + + +def generate_deterministic_data( + task_id: int, + size: int, + device: str = "cuda", +) -> torch.Tensor: + """ + Generate deterministic data pattern for a task. + + Pattern: Each byte = (task_id * 31 + position) % 256 + This creates a unique pattern per task that varies along the data. + + Args: + task_id: Unique task identifier + size: Number of bytes to generate + device: Device to create tensor on ('cuda' or 'cpu') + + Returns: + torch.Tensor of uint8 with deterministic pattern + """ + positions = torch.arange(size, dtype=torch.int64, device=device) + pattern = ((task_id * 31 + positions) % 256).to(torch.uint8) + return pattern + + +def validate_received_data( + task_id: int, + tensor: torch.Tensor, + size: int, + src_pe: int = -1, +) -> ValidationResult: + """ + Validate received data against expected deterministic pattern. + + Args: + task_id: Task identifier to regenerate expected data + tensor: Received tensor to validate + size: Number of bytes to validate + + Returns: + ValidationResult with pass/fail status and details + """ + # Get the data slice to validate + recv_data = tensor[:size] + + # Generate expected pattern on same device + expected = generate_deterministic_data( + task_id, + size, + device=recv_data.device.type, + ) + + # Compare + mismatches_mask = recv_data != expected + num_mismatches = mismatches_mask.sum().item() + + result = ValidationResult( + task_id=task_id, + size=size, + passed=(num_mismatches == 0), + src_pe=src_pe, + mismatches=num_mismatches, + ) + + if num_mismatches > 0: + # Find first mismatch for debugging + first_idx = mismatches_mask.nonzero(as_tuple=True)[0][0].item() + result.first_mismatch_idx = first_idx + result.first_mismatch_expected = expected[first_idx].item() + result.first_mismatch_actual = recv_data[first_idx].item() + + return result + + +def log_validation_summary(summary: ValidationSummary) -> None: + """Log validation summary.""" + if summary.all_passed: + PELogger.info( + "Validation PASSED: %d/%d tasks, %d bytes validated", + summary.passed_tasks, + summary.total_tasks, + summary.total_bytes, + ) + else: + PELogger.error( + "Validation FAILED: %d/%d tasks passed, %d failed", + summary.passed_tasks, + summary.total_tasks, + summary.failed_tasks, + ) + + # Group failures by source PE + failures_by_src = {} + for r in summary.results: + if not r.passed: + failures_by_src.setdefault(r.src_pe, []).append(r) + + PELogger.error(" Failures by source PE:") + for src_pe in sorted(failures_by_src.keys()): + failed_tasks = failures_by_src[src_pe] + task_ids = [r.task_id for r in failed_tasks] + PELogger.error( + " PE %d: %d failed tasks: %s", + src_pe, + len(failed_tasks), + task_ids[:15] if len(task_ids) <= 15 else task_ids[:15] + ["..."], + ) + + diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index 2deb5f0ec6d..6055a4aa315 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -260,6 +260,11 @@ def build_centralized_reshard_plan( # Build the plan on global rank 0 and broadcast to all ranks if my_global_rank == 0: plans_for_all_ranks = {r: ReshardPlan([], [], []) for r in range(world_size)} + # Global monotonically increasing ID for non-local transfers. + # This is shared between the corresponding send/recv ops so that + # advanced backends (e.g., NVSHMEM) can build richer schedules. + next_task_id = 0 + for dst_rank in range(world_size): dst_rank_params = dst_param_metadata_by_rank.get(dst_rank, {}) for resolved_name, dst_metadata in dst_rank_params.items(): @@ -270,7 +275,9 @@ def build_centralized_reshard_plan( "not found in source model." ) # Choose a representative source metadata with DP round-robin balancing - src_metadata = select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank) + src_metadata = select_src_metadata_balanced( + src_meta_list, dst_metadata, dst_rank + ) sources = _determine_source_ranks_for_dst_param( resolved_name, src_metadata, dst_metadata, dst_rank ) @@ -280,6 +287,9 @@ def build_centralized_reshard_plan( (dst_metadata.name, None, None, src_slice, dst_slice) ) else: + task_id = next_task_id + next_task_id += 1 + plans_for_all_ranks[dst_rank].recv_ops.append( TransferOp( param_name=dst_metadata.name, @@ -287,6 +297,7 @@ def build_centralized_reshard_plan( is_send=False, my_slice=dst_slice, peer_slice=src_slice, + task_id=task_id, ) ) plans_for_all_ranks[src_rank].send_ops.append( @@ -296,6 +307,7 @@ def build_centralized_reshard_plan( is_send=True, my_slice=src_slice, peer_slice=dst_slice, + task_id=task_id, ) ) plans_list = [plans_for_all_ranks[r] for r in range(world_size)] diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index b55d7e2e13e..06a69a1fc86 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -17,9 +17,10 @@ from .copy_services.base import CopyService from .copy_services.gloo_copy_service import GlooCopyService from .copy_services.nccl_copy_service import NCCLCopyService +from .copy_services.nvshmem_copy_service import NVSHMEMCopyService # Supported refit backend names -RefitBackendName = Literal["nccl", "gloo"] +RefitBackendName = Literal["nccl", "gloo", "nvshmem"] def swap_model_weights( @@ -44,6 +45,9 @@ def swap_model_weights( # Debug / fallback backend: run refit over CPU/Gloo instead of NCCL. service = GlooCopyService() reshard_model_weights(src_model, target_model, service=service) + elif refit_method == "nvshmem": + service = NVSHMEMCopyService() + reshard_model_weights(src_model, target_model, service=service) else: raise ValueError(f"Unknown refit_method '{refit_method}'") else: diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index 54b05beee63..4f842f52056 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -25,6 +25,11 @@ class TransferOp: my_slice: tuple[slice, ...] # My tensor slice peer_slice: tuple[slice, ...] # Peer's tensor slice (for reference) + # Optional global task identifier for advanced backends (e.g., NVSHMEM) + # When present, this ID is shared between the matching send/recv ops + # across ranks and can be used to build richer communication schedules. + task_id: int | None = None + @dataclass class ParameterMetadata: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5b70e7e540c..9acd1d5cd27 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2001,9 +2001,11 @@ def _add_rl_args(parser): group.add_argument('--rl-inference-tensor-model-parallel-size', type=int, default=None, help='Degree of tensor model parallelism for inference for RL.') group.add_argument('--refit-method', type=str, default='nccl', - choices=['nccl', 'gloo'], + choices=['nccl', 'gloo', 'nvshmem'], help=('Method to refit the model weights between training and inference models during RL. ' - 'nccl: use NCCLCopyService to refit the model weights between training and inference models during RL.')) + 'nccl: use NCCLCopyService to refit the model weights between training and inference models during RL; ' + 'gloo: use GlooCopyService over CPU; ' + 'nvshmem: use NVSHMEMCopyService to refit using the in-tree NVSHMEM copy service.')) return parser def _add_training_args(parser): diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index 695c44f70b0..cf86d0c3f44 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -115,6 +115,7 @@ def _set_pg_collection(module, tp_group, dp_group): return module +@pytest.mark.parametrize("refit_backend", ["nccl", "nvshmem"]) @pytest.mark.parametrize( "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [ @@ -136,6 +137,7 @@ def _set_pg_collection(module, tp_group, dp_group): ], ) def test_nccl_swap_gpt_parametrized( + refit_backend: str, src_tp: int, src_pp: int, src_ep: int, @@ -248,7 +250,7 @@ def test_nccl_swap_gpt_parametrized( dist.broadcast(ref_logits, src=src_last_pp_rank, group=src_pgs.pp) # Swap weights - swap_model_weights([src_model], [dst_model], refit_method="nccl") + swap_model_weights([src_model], [dst_model], refit_method=refit_backend) # Collect destination logits (parallel_output=False ensures full vocab on last PP stage) dst_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) From 5948930e256cb3db70093f797c1d1929d07e3a15 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 7 Dec 2025 17:53:32 -0800 Subject: [PATCH 26/44] fix tests --- .../core/gpu_resource_manager.py | 19 ++---- .../planning/communication_scheduler.py | 10 +-- .../planning/gpu_execution_planner.py | 61 +++++++------------ .../planning/workload_packer.py | 18 ++---- .../nvshmem_copy_service/service.py | 1 - .../inference/test_nccl_model_swap.py | 26 ++++---- 6 files changed, 46 insertions(+), 89 deletions(-) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index 000ddebdb4f..cf3c9139264 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -22,11 +22,6 @@ def __init__(self): self.n_pes: int = -1 self.initialized: bool = False - # Dedicated torch.distributed process group for NVSHMEM collectives. - # This isolates NVSHMEM's use of collectives from the default WORLD - # group that Megatron and the test harness use for their own ops. - self.pg: Optional[dist.ProcessGroup] = None - # CUDA streams (cuda.core.experimental) self.pack_stream = None self.unpack_stream = None @@ -65,16 +60,10 @@ def init(self) -> None: self.device = Device(local_rank) self.device.set_current() - # Extract rank, nranks from process group + # Extract rank, nranks from the default process group num_ranks = dist.get_world_size() rank_id = dist.get_rank() - # Create a dedicated process group for NVSHMEM collectives. - # Using a private group avoids interfering with Megatron's own - # WORLD-group collectives (e.g., during test setup/teardown), - # which can otherwise trigger "collective mismatch" runtime errors. - self.pg = dist.new_group(ranks=list(range(num_ranks))) - # Create/Broadcast UniqueID using broadcast_object_list uniqueid = nvshmem.core.get_unique_id(empty=True) if rank_id == 0: @@ -83,11 +72,11 @@ def init(self) -> None: else: broadcast_objects = [None] - # Broadcast ID to all ranks - dist.broadcast_object_list(broadcast_objects, src=0, group=self.pg) + # Broadcast ID to all ranks using the default group + dist.broadcast_object_list(broadcast_objects, src=0) # Barrier to ensure everyone has the ID before NVSHMEM init - dist.barrier(group=self.pg) + dist.barrier() # Initialize NVSHMEM with the broadcasted UID nvshmem.core.init( diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py index 6399cc2a393..cc871e2c234 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -20,7 +20,6 @@ def build_schedule( workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int, - group=None, ) -> Tuple[Dict[int, List[ScheduledBatch]], Dict[Tuple[int, int, int], WorkloadSummary]]: """ Main scheduling method. @@ -37,7 +36,7 @@ def build_schedule( # Step 1: Collect all batches across all PE pairs PELogger.debug("Collecting batches from all PEs...") - all_batches = self._collect_all_batches(workloads, my_pe, n_pes, group) + all_batches = self._collect_all_batches(workloads, my_pe, n_pes) PELogger.debug(f"Collected {len(all_batches)} total batches globally") # Step 2: Assign batches to iterations using conflict-free algorithm @@ -52,7 +51,6 @@ def build_schedule( workloads, my_pe, n_pes, - group, ) PELogger.debug(f"Exchanged {len(global_summaries)} workload summaries") @@ -73,7 +71,6 @@ def _collect_all_batches( workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int, - group=None, ) -> List[ScheduledBatch]: """ Exchanges batch counts and details with all PEs to build a global view. @@ -94,7 +91,7 @@ def _collect_all_batches( # Gather all batches from all PEs using torch.distributed all_batches_list: List[List[Tuple[int, int, int]] | None] = [None] * n_pes - dist.all_gather_object(all_batches_list, local_batches, group=group) + dist.all_gather_object(all_batches_list, local_batches) # Flatten into global batch list global_batches: List[ScheduledBatch] = [] @@ -158,7 +155,6 @@ def _exchange_workload_summaries( workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int, - group=None, ) -> Dict[Tuple[int, int, int], WorkloadSummary]: """ Exchange detailed workload content using torch.distributed. @@ -193,7 +189,7 @@ def _exchange_workload_summaries( all_summaries_list: List[Dict[Tuple[int, int, int], Dict[str, object]] | None] = [ # noqa: E501 None ] * n_pes - dist.all_gather_object(all_summaries_list, local_summaries, group=group) + dist.all_gather_object(all_summaries_list, local_summaries) # Merge into global map global_map: Dict[Tuple[int, int, int], WorkloadSummary] = {} diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py index 22428f320e3..a568906f4c3 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -46,7 +46,7 @@ def create_gpu_plans( recv_slots: List of receive buffer slots receive_requests: List of all receive requests for matching """ - PELogger.debug("Creating GPU plans for %d iterations", len(iter_schedules)) + PELogger.debug(f"Creating GPU plans for {len(iter_schedules)} iterations") for i, sched in enumerate(iter_schedules): send_batch = sched["send"] if send_batch: @@ -72,16 +72,13 @@ def create_gpu_plans( ) task_ids = [t.task_id for t in send_batch.tasks] PELogger.debug( - " Iter %d send plan: %d tasks → PE %d, %d bytes", - i, - len(send_batch.tasks), - send_batch.dest_pe, - send_batch.total_size, + f" Iter {i} send plan: {len(send_batch.tasks)} tasks → " + f"PE {send_batch.dest_pe}, {send_batch.total_size} bytes" ) - PELogger.debug( - " Send task IDs: %s", - task_ids[:10] if len(task_ids) <= 10 else task_ids[:10] + ["..."], + displayed_ids = ( + task_ids[:10] if len(task_ids) <= 10 else task_ids[:10] + ["..."] ) + PELogger.debug(f" Send task IDs: {displayed_ids}") recv_batch = sched["recv"] if recv_batch: @@ -91,19 +88,15 @@ def create_gpu_plans( # Skip if no summary available (shouldn't happen in normal operation) if summary is None: PELogger.error( - "Iter %d: recv batch from PE %d has no tasks_summary - " - "UNPACK WILL BE SKIPPED!", - i, - recv_batch.src_pe, + f"Iter {i}: recv batch from PE {recv_batch.src_pe} has no " + "tasks_summary - UNPACK WILL BE SKIPPED!" ) recv_batch.gpu_plan = None continue PELogger.debug( - " Iter %d recv from PE %d: %d tasks in summary", - i, - recv_batch.src_pe, - len(summary.task_ids), + f" Iter {i} recv from PE {recv_batch.src_pe}: " + f"{len(summary.task_ids)} tasks in summary" ) ptrs = [] @@ -131,20 +124,14 @@ def create_gpu_plans( else: unmatched_task_ids.append(t_id) PELogger.error( - "Iter %d: Unexpected task %d from PE %d - " - "no matching recv request!", - i, - t_id, - recv_batch.src_pe, + f"Iter {i}: Unexpected task {t_id} from PE " + f"{recv_batch.src_pe} - no matching recv request!" ) if unmatched_task_ids: PELogger.error( - " Iter %d: %d unmatched tasks from PE %d: %s", - i, - len(unmatched_task_ids), - recv_batch.src_pe, - unmatched_task_ids[:10], + f" Iter {i}: {len(unmatched_task_ids)} unmatched tasks " + f"from PE {recv_batch.src_pe}: {unmatched_task_ids[:10]}" ) # Plan kernel args for unpacking @@ -158,26 +145,20 @@ def create_gpu_plans( if recv_batch.gpu_plan is None: PELogger.error( - " Iter %d recv plan: FAILED - no gpu_plan created for %d " - "tasks from PE %d", - i, - len(sizes), - recv_batch.src_pe, + f" Iter {i} recv plan: FAILED - no gpu_plan created for " + f"{len(sizes)} tasks from PE {recv_batch.src_pe}" ) else: PELogger.debug( - " Iter %d recv plan: %d tasks ← PE %d, %d bytes", - i, - len(sizes), - recv_batch.src_pe, - recv_batch.total_size, + f" Iter {i} recv plan: {len(sizes)} tasks ← " + f"PE {recv_batch.src_pe}, {recv_batch.total_size} bytes" ) - PELogger.debug( - " Recv task IDs: %s", + displayed_recv_ids = ( matched_task_ids[:10] if len(matched_task_ids) <= 10 - else matched_task_ids[:10] + ["..."], + else matched_task_ids[:10] + ["..."] ) + PELogger.debug(f" Recv task IDs: {displayed_recv_ids}") def _plan_kernel_args( self, diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py index d6643220498..b4cdffb7767 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py @@ -40,17 +40,12 @@ def pack_workloads( if workloads[dest_pe]: total_size = sum(b.total_size for b in workloads[dest_pe]) PELogger.debug( - " Dest PE %d: %d tasks → %d batches, %d bytes total", - dest_pe, - len(tasks), - len(workloads[dest_pe]), - total_size, + f" Dest PE {dest_pe}: {len(tasks)} tasks → " + f"{len(workloads[dest_pe])} batches, {total_size} bytes total" ) else: PELogger.debug( - " Dest PE %d: %d tasks → 0 batches (empty after packing)", - dest_pe, - len(tasks), + f" Dest PE {dest_pe}: {len(tasks)} tasks → 0 batches (empty after packing)" ) return workloads @@ -81,11 +76,8 @@ def _pack_single_destination( [str(t.task_id) for t in current_batch.tasks[:10]] ) PELogger.debug( - " Packed batch to PE %d idx %d: %s... (total %d tasks)", - dest_pe, - len(batches) - 1, - task_first_10_string, - len(current_batch.tasks), + f" Packed batch to PE {dest_pe} idx {len(batches) - 1}: " + f"{task_first_10_string}... (total {len(current_batch.tasks)} tasks)" ) # Start new batch current_batch = WorkloadGroup( diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index 55e3b212d53..9d2056901fa 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -230,7 +230,6 @@ def schedule(self) -> None: workloads, self.my_pe, self.n_pes, - group=self.gpu_resources.pg, ) self.num_iterations = self.comm_scheduler.num_iterations diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index cf86d0c3f44..781c1c8fc76 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -114,8 +114,8 @@ def _set_pg_collection(module, tp_group, dp_group): module.pg_collection = types.SimpleNamespace(tp=tp_group, dp=dp_group, ep=None, pp=None) return module - -@pytest.mark.parametrize("refit_backend", ["nccl", "nvshmem"]) +#"nvshmem" +@pytest.mark.parametrize("refit_backend", ["nvshmem"]) @pytest.mark.parametrize( "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [ @@ -123,17 +123,17 @@ def _set_pg_collection(module, tp_group, dp_group): (2, 1, 1, 1, 1, 1, None), # TP2 -> TP1 (1, 1, 1, 2, 1, 1, None), # TP1 -> TP2 # PP only changes - (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 - (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 - # Both TP and PP change - (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 - (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 - (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 - (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 - (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 - (1, 1, 2, 1, 1, 1, 4), - (1, 1, 1, 1, 1, 2, 4), - (1, 1, 2, 1, 2, 2, 4), + # (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 + # (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 + # # Both TP and PP change + # (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 + # (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 + # (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 + # (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 + # (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 + # (1, 1, 2, 1, 1, 1, 4), + # (1, 1, 1, 1, 1, 2, 4), + # (1, 1, 2, 1, 2, 2, 4), ], ) def test_nccl_swap_gpt_parametrized( From 6098e5b22304a9b76b5dbaa721a6124402bdfa51 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Dec 2025 08:02:44 -0800 Subject: [PATCH 27/44] clean up --- .../core/gpu_resource_manager.py | 2 +- .../nvshmem_copy_service/core/kernel_launcher.py | 16 ++++++++-------- .../resharding/nvshmem_copy_service/logger.py | 13 ++++++++++++- .../memory/tensor_pointer_utils.py | 14 +++++++++++++- .../planning/communication_scheduler.py | 6 +++--- .../resharding/nvshmem_copy_service/service.py | 8 ++++---- .../nvshmem_copy_service/validation.py | 5 +++-- 7 files changed, 44 insertions(+), 20 deletions(-) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index cf3c9139264..2e95b7f75a4 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -10,7 +10,7 @@ import nvshmem.core import torch import torch.distributed as dist -from cuda.core.experimental import Device, system # type: ignore[attr-defined] +from cuda.core.experimental import Device, system class GPUResourceManager: diff --git a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py index a67564fc659..042f2c81608 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py @@ -83,15 +83,15 @@ def launch_pack( cp_src, cp_dst, cp_sizes, num_chunks = gpu_plan # Grid/Block configuration - threads_per_block = 1024 - num_blocks = 75 + THREADS_PER_BLOCK = 1024 + NUM_BLOCKS = 75 # Launch kernel using cached CuPy stream assert self.chunked_copy_kernel is not None assert self.cp_pack_stream is not None self.chunked_copy_kernel( - (num_blocks,), - (threads_per_block,), + (NUM_BLOCKS,), + (THREADS_PER_BLOCK,), (cp_src, cp_dst, cp_sizes, num_chunks), stream=self.cp_pack_stream, ) @@ -126,15 +126,15 @@ def launch_unpack( cp_src, cp_dst, cp_sizes, num_chunks = gpu_plan # Grid/Block configuration - threads_per_block = 1024 - num_blocks = 75 + THREADS_PER_BLOCK = 1024 + NUM_BLOCKS = 75 # Launch kernel using cached CuPy stream assert self.chunked_copy_kernel is not None assert self.cp_unpack_stream is not None self.chunked_copy_kernel( - (num_blocks,), - (threads_per_block,), + (NUM_BLOCKS,), + (THREADS_PER_BLOCK,), (cp_src, cp_dst, cp_sizes, num_chunks), stream=self.cp_unpack_stream, ) diff --git a/megatron/core/resharding/nvshmem_copy_service/logger.py b/megatron/core/resharding/nvshmem_copy_service/logger.py index d4516c5761f..3523f3dd5b4 100644 --- a/megatron/core/resharding/nvshmem_copy_service/logger.py +++ b/megatron/core/resharding/nvshmem_copy_service/logger.py @@ -1,8 +1,19 @@ """ + Per-PE Logger with colored console and file output. -Copied in-tree from the standalone nvshmem_copy_service implementation. + + +Similar to the C++ Logger implementation, provides: + +- Per-PE colored console output + +- Per-PE file logging + +- Support for TRACE, DEBUG, INFO, SUMMARY, WARN, ERROR levels + """ +#TODO(Peter): We need to remove this logger and use the regular Megatron logger. import logging import os diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py index 60e651aa998..f39dbb0ae95 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py @@ -22,12 +22,24 @@ def get_pointer(tensor: Any) -> int: Returns: int: Memory address of the tensor data + + Examples: + + >>> import torch + + >>> t = torch.zeros(100, device='cuda') + + >>> ptr = TensorPointerExtractor.get_pointer(t) + + >>> isinstance(ptr, int) + + True """ if isinstance(tensor, torch.Tensor): return tensor.data_ptr() elif hasattr(tensor, "data"): # CuPy array return tensor.data.ptr else: # Assume raw integer pointer - return int(tensor) + return tensor diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py index cc871e2c234..d70eb559ce5 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -186,7 +186,7 @@ def _exchange_workload_summaries( ) # Gather all summaries from all PEs using torch.distributed - all_summaries_list: List[Dict[Tuple[int, int, int], Dict[str, object]] | None] = [ # noqa: E501 + all_summaries_list: List[Dict[Tuple[int, int, int], Dict[str, object]] | None] = [ None ] * n_pes dist.all_gather_object(all_summaries_list, local_summaries) @@ -199,8 +199,8 @@ def _exchange_workload_summaries( for key, data in pe_summaries.items(): summary = WorkloadSummary( total_size=int(data["total_size"]), - task_ids=list(data["task_ids"]), # type: ignore[arg-type] - task_sizes=list(data["task_sizes"]), # type: ignore[arg-type] + task_ids=list(data["task_ids"]), + task_sizes=list(data["task_sizes"]), ) global_map[key] = summary diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index 9d2056901fa..f545b1ca7f1 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -1,9 +1,9 @@ """ Remote Copy Service - Main orchestrator for NVSHMEM-based GPU-to-GPU transfers. -This is an in-tree copy of the standalone Python implementation from -`kan/mcore-reshard/nvshmem_copy_service/python/service.py`, with imports -updated to use the Megatron package layout. +This service coordinates task segmentation, workload packing, scheduling, + +GPU resource management, and pipelined execution. """ from typing import List, Dict, Tuple, Optional @@ -332,7 +332,7 @@ def finalize(self) -> None: PELogger.debug("Barrier: Synchronizing all PEs before finalize") nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) self.gpu_resources.send_stream.sync() - except Exception as e: # pragma: no cover - defensive logging + except Exception as e: PELogger.error(f"Error in final barrier: {e}") # Free buffers diff --git a/megatron/core/resharding/nvshmem_copy_service/validation.py b/megatron/core/resharding/nvshmem_copy_service/validation.py index 02d4ddd792e..f2197b7067f 100644 --- a/megatron/core/resharding/nvshmem_copy_service/validation.py +++ b/megatron/core/resharding/nvshmem_copy_service/validation.py @@ -1,8 +1,9 @@ """ Validation utilities for GPU-to-GPU communication. -Copied in-tree from the standalone nvshmem_copy_service implementation. -""" +Provides deterministic data generation and validation for verifying + +correctness of communication operations.""" from dataclasses import dataclass from typing import List From 368cf3c6fa8927db4a176d0355c663acd13cc584 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 12 Dec 2025 13:24:05 -0800 Subject: [PATCH 28/44] fix nvshmem, all backends working in tests --- .../copy_services/gloo_copy_service.py | 63 +++++++++++++++++-- .../copy_services/nccl_copy_service.py | 58 +++++++++++++++-- .../copy_services/nvshmem_copy_service.py | 58 +++++++++++++++-- megatron/core/resharding/execution.py | 25 ++------ .../nvshmem_copy_service/service.py | 3 + megatron/core/resharding/planner.py | 60 +++++++----------- megatron/core/resharding/utils.py | 12 +--- .../inference/test_nccl_model_swap.py | 25 ++++---- 8 files changed, 210 insertions(+), 94 deletions(-) diff --git a/megatron/core/resharding/copy_services/gloo_copy_service.py b/megatron/core/resharding/copy_services/gloo_copy_service.py index af70c33d5bd..ebdc05e8bde 100644 --- a/megatron/core/resharding/copy_services/gloo_copy_service.py +++ b/megatron/core/resharding/copy_services/gloo_copy_service.py @@ -17,6 +17,7 @@ class SendOp: """Simple container describing a single send operation.""" + task_id: int | None tensor: torch.Tensor dest_rank: int @@ -25,6 +26,7 @@ class SendOp: class RecvOp: """Simple container describing a single receive operation.""" + task_id: int | None tensor: torch.Tensor src_rank: int @@ -41,16 +43,24 @@ def __init__(self): self.gloo_pg = dist.new_group(backend="gloo") self.send_ops: List[SendOp] = [] self.recv_ops: List[Tuple[RecvOp, torch.Tensor]] = [] + self._copy_stream = torch.cuda.Stream() logger.info(f"GlooCopyService initialized on rank {self.rank} with {self.world_size} ranks") def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): - self.send_ops.append(SendOp(tensor=src_tensor, dest_rank=dest_rank)) + self.send_ops.append(SendOp(task_id=None, tensor=src_tensor, dest_rank=dest_rank)) + + def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): + self.send_ops.append(SendOp(task_id=task_id, tensor=src_tensor, dest_rank=dest_rank)) def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): # Allocate a CPU buffer that matches the destination view; we'll # copy into dest_tensor after the Gloo recv completes. cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() - self.recv_ops.append((RecvOp(tensor=cpu_buffer, src_rank=src_rank), dest_tensor)) + self.recv_ops.append((RecvOp(task_id=None, tensor=cpu_buffer, src_rank=src_rank), dest_tensor)) + + def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): + cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() + self.recv_ops.append((RecvOp(task_id=task_id, tensor=cpu_buffer, src_rank=src_rank), dest_tensor)) def run(self): total_ops = len(self.send_ops) + len(self.recv_ops) @@ -61,12 +71,52 @@ def run(self): p2p_ops: List[dist.P2POp] = [] + # Short-circuit self transfers into local device copies. + local_sends = [op for op in self.send_ops if op.dest_rank == self.rank] + remote_sends = [op for op in self.send_ops if op.dest_rank != self.rank] + local_recvs = [(recv, dst) for (recv, dst) in self.recv_ops if recv.src_rank == self.rank] + remote_recvs = [(recv, dst) for (recv, dst) in self.recv_ops if recv.src_rank != self.rank] + + if local_sends or local_recvs: + local_sends_by_id = {op.task_id: op for op in local_sends} + if None in local_sends_by_id: + raise RuntimeError( + "GlooCopyService: local send missing task_id; " + "use submit_send_with_id/submit_recv_with_id for local copies" + ) + local_recvs_by_id = {recv.task_id: (recv, dst) for (recv, dst) in local_recvs} + if None in local_recvs_by_id: + raise RuntimeError( + "GlooCopyService: local recv missing task_id; " + "use submit_send_with_id/submit_recv_with_id for local copies" + ) + if len(local_sends_by_id) != len(local_sends) or len(local_recvs_by_id) != len( + local_recvs + ): + raise RuntimeError( + f"GlooCopyService: unmatched local ops on rank {self.rank}: " + f"{len(local_sends)} local sends vs {len(local_recvs)} local recvs" + ) + for task_id, (recv_op, dst_tensor) in local_recvs_by_id.items(): + send_op = local_sends_by_id.get(task_id) + if send_op is None: + raise RuntimeError( + f"GlooCopyService: missing local send for task_id={task_id} " + f"on rank {self.rank}" + ) + with torch.no_grad(): + src_tensor = send_op.tensor + if dst_tensor.device != src_tensor.device: + dst_tensor.copy_(src_tensor.to(dst_tensor.device)) + else: + dst_tensor.copy_(src_tensor) + # Build Gloo P2P ops over CPU tensors. For sends we clone to CPU; # for recvs we use the preallocated CPU buffers. - for op in self.send_ops: + for op in remote_sends: cpu_tensor = op.tensor.detach().to("cpu").contiguous() p2p_ops.append(dist.P2POp(dist.isend, cpu_tensor, op.dest_rank, group=self.gloo_pg)) - for recv, _dst_tensor in self.recv_ops: + for recv, _dst_tensor in remote_recvs: p2p_ops.append(dist.P2POp(dist.irecv, recv.tensor, recv.src_rank, group=self.gloo_pg)) if p2p_ops: @@ -75,12 +125,15 @@ def run(self): req.wait() # Copy received CPU buffers back into the original destination tensors. - for recv, dst_tensor in self.recv_ops: + for recv, dst_tensor in remote_recvs: if dst_tensor.is_cuda: dst_tensor.copy_(recv.tensor.to(dst_tensor.device)) else: dst_tensor.copy_(recv.tensor) + if self._copy_stream is not None: + torch.cuda.current_stream().wait_stream(self._copy_stream) + logger.info("GlooCopyService: batched communication completed") self.send_ops.clear() self.recv_ops.clear() diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index fe02d108550..678a03cbf1b 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -17,6 +17,7 @@ class SendOp: """Simple container describing a single NCCL send operation.""" + task_id: int | None tensor: torch.Tensor dest_rank: int @@ -25,6 +26,7 @@ class SendOp: class RecvOp: """Simple container describing a single NCCL receive operation.""" + task_id: int | None tensor: torch.Tensor src_rank: int @@ -40,13 +42,22 @@ def __init__(self): self.world_size = dist.get_world_size() self.send_ops: List[SendOp] = [] self.recv_ops: List[RecvOp] = [] + # Dedicated stream for local (same-rank) copies to avoid unnecessary + # serialization with work on the default stream. + self._copy_stream = torch.cuda.Stream() logger.info(f"NCCLCopyService initialized with {self.world_size} ranks") def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): - self.send_ops.append(SendOp(tensor=src_tensor, dest_rank=dest_rank)) + self.send_ops.append(SendOp(task_id=None, tensor=src_tensor, dest_rank=dest_rank)) + + def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): + self.send_ops.append(SendOp(task_id=task_id, tensor=src_tensor, dest_rank=dest_rank)) def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): - self.recv_ops.append(RecvOp(tensor=dest_tensor, src_rank=src_rank)) + self.recv_ops.append(RecvOp(task_id=None, tensor=dest_tensor, src_rank=src_rank)) + + def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): + self.recv_ops.append(RecvOp(task_id=task_id, tensor=dest_tensor, src_rank=src_rank)) def run(self): total_ops = len(self.send_ops) + len(self.recv_ops) @@ -57,10 +68,46 @@ def run(self): total_ops, ) + local_sends = [op for op in self.send_ops if op.dest_rank == self.rank] + remote_sends = [op for op in self.send_ops if op.dest_rank != self.rank] + local_recvs = [op for op in self.recv_ops if op.src_rank == self.rank] + remote_recvs = [op for op in self.recv_ops if op.src_rank != self.rank] + + if local_sends or local_recvs: + local_sends_by_id = {op.task_id: op for op in local_sends} + if None in local_sends_by_id: + raise RuntimeError( + "NCCLCopyService: local send missing task_id; " + "use submit_send_with_id/submit_recv_with_id for local copies" + ) + local_recvs_by_id = {op.task_id: op for op in local_recvs} + if None in local_recvs_by_id: + raise RuntimeError( + "NCCLCopyService: local recv missing task_id; " + "use submit_send_with_id/submit_recv_with_id for local copies" + ) + if len(local_sends_by_id) != len(local_sends) or len(local_recvs_by_id) != len( + local_recvs + ): + raise RuntimeError( + f"NCCLCopyService: unmatched local ops on rank {self.rank}: " + f"{len(local_sends)} local sends vs {len(local_recvs)} local recvs" + ) + for task_id, recv_op in local_recvs_by_id.items(): + send_op = local_sends_by_id.get(task_id) + if send_op is None: + raise RuntimeError( + f"NCCLCopyService: missing local send for task_id={task_id} " + f"on rank {self.rank}" + ) + with torch.no_grad(): + with torch.cuda.stream(self._copy_stream): + recv_op.tensor.copy_(send_op.tensor) + p2p_ops = [] - for op in self.send_ops: + for op in remote_sends: p2p_ops.append(dist.P2POp(dist.isend, op.tensor, op.dest_rank)) - for op in self.recv_ops: + for op in remote_recvs: p2p_ops.append(dist.P2POp(dist.irecv, op.tensor, op.src_rank)) if p2p_ops: @@ -68,6 +115,9 @@ def run(self): for req in reqs: req.wait() + # Make sure the copy stream is finished + torch.cuda.current_stream().wait_stream(self._copy_stream) + logger.info("Batched communication completed") self.send_ops.clear() self.recv_ops.clear() diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py index e5abb31dbd5..b3e46deef6b 100644 --- a/megatron/core/resharding/copy_services/nvshmem_copy_service.py +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -32,13 +32,16 @@ def __init__(self): "torch.distributed must be initialized before NVSHMEMCopyService()" ) + self.rank = dist.get_rank() self._remote = RemoteCopyService() # Lazily initialized on first use to avoid side effects at import time self._initialized = False - # Internal bookkeeping of registration calls before schedule/run - self._next_task_id: int = 0 - self._registered_pairs: List[Tuple[int, torch.Tensor, torch.Tensor, int]] = [] + # NOTE: keep the original typed tensors here (not uint8 views) so local copies + # preserve shape/strides semantics and avoid byte-offset pitfalls. + self._local_send_ops: Dict[int, torch.Tensor] = {} + self._local_recv_ops: Dict[int, torch.Tensor] = {} + self._local_copy_stream = torch.cuda.Stream() logger.info("NVSHMEMCopyService constructed") @@ -90,6 +93,11 @@ def submit_send_with_id( if not src_tensor.is_contiguous(): src_tensor = src_tensor.contiguous() + # Local transfers: keep them out of RemoteCopyService entirely. + if dest_rank == self.rank: + self._local_send_ops[task_id] = src_tensor + return + num_bytes = src_tensor.numel() * src_tensor.element_size() src_bytes = src_tensor.view(torch.uint8) @@ -97,7 +105,7 @@ def submit_send_with_id( "NVSHMEMCopyService: register_send task_id=%d, %d bytes (%d → %d)", task_id, num_bytes, - dist.get_rank(), + self.rank, dest_rank, ) @@ -122,6 +130,11 @@ def submit_recv_with_id( if not dest_tensor.is_contiguous(): dest_tensor = dest_tensor.contiguous() + # Local transfers: keep them out of RemoteCopyService entirely. + if src_rank == self.rank: + self._local_recv_ops[task_id] = dest_tensor + return + num_bytes = dest_tensor.numel() * dest_tensor.element_size() dst_bytes = dest_tensor.view(torch.uint8) @@ -129,7 +142,7 @@ def submit_recv_with_id( "NVSHMEMCopyService: register_recv task_id=%d, %d bytes (%d ← %d)", task_id, num_bytes, - dist.get_rank(), + self.rank, src_rank, ) @@ -149,8 +162,41 @@ def run(self): requests, builds a schedule, runs the pipelined NVSHMEM transfer, and then clears internal state. """ - # Execute schedule built from submit_send_with_id/submit_recv_with_id self._ensure_initialized() + + # 1) Run same-rank copies (match by task_id), like NCCL backend. + if self._local_send_ops or self._local_recv_ops: + missing_sends = set(self._local_recv_ops.keys()) - set(self._local_send_ops.keys()) + missing_recvs = set(self._local_send_ops.keys()) - set(self._local_recv_ops.keys()) + if missing_sends or missing_recvs: + raise RuntimeError( + "NVSHMEMCopyService: unmatched local ops on rank " + f"{self.rank}: missing_sends={sorted(list(missing_sends))[:10]} " + f"missing_recvs={sorted(list(missing_recvs))[:10]}" + ) + + with torch.no_grad(): + with torch.cuda.stream(self._local_copy_stream): + for task_id, dst in self._local_recv_ops.items(): + src = self._local_send_ops[task_id] + if src.numel() != dst.numel() or src.element_size() != dst.element_size(): + raise RuntimeError( + "NVSHMEMCopyService: local copy size mismatch on rank " + f"{self.rank} task_id={task_id}: " + f"src=({tuple(src.shape)}, {src.dtype}) " + f"dst=({tuple(dst.shape)}, {dst.dtype})" + ) + dst.copy_(src, non_blocking=True) + + torch.cuda.current_stream().wait_stream(self._local_copy_stream) + self._local_send_ops.clear() + self._local_recv_ops.clear() + + # 2) Execute remote schedule (if any remote sends/recvs were registered). + if not self._remote.send_requests and not self._remote.receive_requests: + logger.info("NVSHMEMCopyService: no remote requests; local copies complete") + return + logger.info("NVSHMEMCopyService: building NVSHMEM schedule and executing") self._remote.schedule() self._remote.run() diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index eb7e7dfef83..f911bbbff8a 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -28,29 +28,16 @@ def execute_reshard_plan( src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} - - # TODO(Peter) do this on like a separate stream? - # Execute local copies - for param_name, src_param, dst_param, src_slice, dst_slice in plan.local_copy_ops: - if src_param is None: - src_param = src_params.get(param_name) - if dst_param is None: - dst_param = dst_params.get(param_name) - if src_param is not None and dst_param is not None: - with torch.no_grad(): - src_view = src_param.data[src_slice] - dst_view = dst_param.data[dst_slice] - dst_view.copy_(src_view) - - is_nvshmem = isinstance(service, NVSHMEMCopyService) + submit_send_with_id = getattr(service, "submit_send_with_id", None) + submit_recv_with_id = getattr(service, "submit_recv_with_id", None) # Submit sends for op in plan.send_ops: src_param = src_params.get(op.param_name) if src_param is not None: src_view = src_param.data[op.my_slice].contiguous() - if is_nvshmem and op.task_id is not None: - service.submit_send_with_id(op.task_id, src_view, op.peer_rank) + if submit_send_with_id is not None and op.task_id is not None: + submit_send_with_id(op.task_id, src_view, op.peer_rank) else: service.submit_send(src_view, op.peer_rank) @@ -61,8 +48,8 @@ def execute_reshard_plan( if dst_param is not None: dst_slice_view = dst_param.data[op.my_slice] recv_buffer = torch.empty_like(dst_slice_view.contiguous()) - if is_nvshmem and op.task_id is not None: - service.submit_recv_with_id(op.task_id, recv_buffer, op.peer_rank) + if submit_recv_with_id is not None and op.task_id is not None: + submit_recv_with_id(op.task_id, recv_buffer, op.peer_rank) else: service.submit_recv(recv_buffer, op.peer_rank) recv_writebacks.append((recv_buffer, dst_param, op.my_slice)) diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index f545b1ca7f1..fff5cdd092e 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -273,6 +273,9 @@ def run(self) -> None: Can be called multiple times after a single schedule() call to repeat the same communication pattern. """ + #import torch + #torch.save(self.send_requests, f"send_requests_{torch.distributed.get_rank()}.pt") + #torch.save(self.receive_requests, f"receive_requests_{torch.distributed.get_rank()}.pt") if not self.initialized: raise RuntimeError("RemoteCopyService not initialized") diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index 6055a4aa315..f446df1481d 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -259,10 +259,10 @@ def build_centralized_reshard_plan( # Build the plan on global rank 0 and broadcast to all ranks if my_global_rank == 0: - plans_for_all_ranks = {r: ReshardPlan([], [], []) for r in range(world_size)} + plans_for_all_ranks = {r: ReshardPlan([], []) for r in range(world_size)} # Global monotonically increasing ID for non-local transfers. # This is shared between the corresponding send/recv ops so that - # advanced backends (e.g., NVSHMEM) can build richer schedules. + # NVSHMEM can build schedule. next_task_id = 0 for dst_rank in range(world_size): @@ -282,50 +282,38 @@ def build_centralized_reshard_plan( resolved_name, src_metadata, dst_metadata, dst_rank ) for src_rank, src_slice, dst_slice in sources: - if src_rank == dst_rank and src_metadata.name == dst_metadata.name: - plans_for_all_ranks[dst_rank].local_copy_ops.append( - (dst_metadata.name, None, None, src_slice, dst_slice) + task_id = next_task_id + next_task_id += 1 + + plans_for_all_ranks[dst_rank].recv_ops.append( + TransferOp( + param_name=dst_metadata.name, + peer_rank=src_rank, + is_send=False, + my_slice=dst_slice, + peer_slice=src_slice, + task_id=task_id, ) - else: - task_id = next_task_id - next_task_id += 1 - - plans_for_all_ranks[dst_rank].recv_ops.append( - TransferOp( - param_name=dst_metadata.name, - peer_rank=src_rank, - is_send=False, - my_slice=dst_slice, - peer_slice=src_slice, - task_id=task_id, - ) - ) - plans_for_all_ranks[src_rank].send_ops.append( - TransferOp( - param_name=src_metadata.name, - peer_rank=dst_rank, - is_send=True, - my_slice=src_slice, - peer_slice=dst_slice, - task_id=task_id, - ) + ) + plans_for_all_ranks[src_rank].send_ops.append( + TransferOp( + param_name=src_metadata.name, + peer_rank=dst_rank, + is_send=True, + my_slice=src_slice, + peer_slice=dst_slice, + task_id=task_id, ) + ) plans_list = [plans_for_all_ranks[r] for r in range(world_size)] else: plans_list = [None] * world_size torch.distributed.broadcast_object_list(plans_list, src=0) my_plan = plans_list[my_global_rank] - # Fill in actual parameter references for local copies - for i, (param_name, _, _, src_slice, dst_slice) in enumerate(my_plan.local_copy_ops): - src_param = my_src_params.get(param_name) - dst_param = my_dst_params.get(param_name) - if src_param is not None and dst_param is not None: - my_plan.local_copy_ops[i] = (param_name, src_param, dst_param, src_slice, dst_slice) - logger.info( f"Rank {my_global_rank}: Received plan - {len(my_plan.recv_ops)} recvs, " - f"{len(my_plan.send_ops)} sends, {len(my_plan.local_copy_ops)} local copies" + f"{len(my_plan.send_ops)} sends" ) return my_plan diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index 4f842f52056..0a5658f5eef 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -80,20 +80,10 @@ class ReshardPlan: send_ops: list[TransferOp] recv_ops: list[TransferOp] - local_copy_ops: list[ - tuple[ - str, - torch.nn.Parameter | None, - torch.nn.Parameter | None, - tuple[slice, ...], - tuple[slice, ...], - ] - ] # (name, src_param, dst_param, src_slice, dst_slice) def __str__(self): return ( - f"ReshardPlan(sends={len(self.send_ops)}, recvs={len(self.recv_ops)}, " - f"local_copies={len(self.local_copy_ops)})" + f"ReshardPlan(sends={len(self.send_ops)}, recvs={len(self.recv_ops)})" ) diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/inference/test_nccl_model_swap.py index 781c1c8fc76..a3aaf38d95f 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/inference/test_nccl_model_swap.py @@ -114,8 +114,7 @@ def _set_pg_collection(module, tp_group, dp_group): module.pg_collection = types.SimpleNamespace(tp=tp_group, dp=dp_group, ep=None, pp=None) return module -#"nvshmem" -@pytest.mark.parametrize("refit_backend", ["nvshmem"]) +@pytest.mark.parametrize("refit_backend", ["nvshmem","nccl","gloo"]) @pytest.mark.parametrize( "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [ @@ -123,17 +122,17 @@ def _set_pg_collection(module, tp_group, dp_group): (2, 1, 1, 1, 1, 1, None), # TP2 -> TP1 (1, 1, 1, 2, 1, 1, None), # TP1 -> TP2 # PP only changes - # (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 - # (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 - # # Both TP and PP change - # (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 - # (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 - # (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 - # (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 - # (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 - # (1, 1, 2, 1, 1, 1, 4), - # (1, 1, 1, 1, 1, 2, 4), - # (1, 1, 2, 1, 2, 2, 4), + (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 + (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 + # Both TP and PP change + (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 + (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 + (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 + (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 + (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 + (1, 1, 2, 1, 1, 1, 4), # EP2 -> EP1 + (1, 1, 1, 1, 1, 2, 4), + (1, 1, 2, 1, 2, 2, 4), ], ) def test_nccl_swap_gpt_parametrized( From ed449fef3d06dab0653008cabe1e2cdf947e907d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 09:57:48 -0800 Subject: [PATCH 29/44] fix test --- megatron/core/resharding/execution.py | 33 ++++++++++++++++- .../test_model_swap.py} | 37 ++++++++++++++++--- 2 files changed, 63 insertions(+), 7 deletions(-) rename tests/unit_tests/{inference/test_nccl_model_swap.py => resharding/test_model_swap.py} (92%) diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index f911bbbff8a..cd24bdd3ea6 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -25,12 +25,43 @@ def execute_reshard_plan( A communication service must be provided to abstract transport. Expected service API: submit_send(tensor, dest_rank), submit_recv(tensor, src_rank), run(). """ - + my_rank = dist.get_rank() + src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} submit_send_with_id = getattr(service, "submit_send_with_id", None) submit_recv_with_id = getattr(service, "submit_recv_with_id", None) + # DEBUG: Print plan summary - focus on QKV weight for GQA debugging + if my_rank == 0: + print(f"\n[Rank {my_rank}] ========== RESHARD PLAN DEBUG ==========") + print(f"[Rank {my_rank}] Total send_ops: {len(plan.send_ops)}, recv_ops: {len(plan.recv_ops)}") + + # Show QKV weight ops specifically (these are critical for GQA) + print(f"[Rank {my_rank}] QKV WEIGHT RECV ops:") + for op in plan.recv_ops: + if "linear_qkv.weight" in op.param_name: + dst_param = dst_params.get(op.param_name) + dst_shape = list(dst_param.shape) if dst_param is not None else "NOT FOUND" + print(f"[Rank {my_rank}] {op.param_name} slice={op.my_slice} <- rank {op.peer_rank}, dst_shape={dst_shape}") + + # Show QKV bias ops + print(f"[Rank {my_rank}] QKV BIAS RECV ops:") + for op in plan.recv_ops: + if "linear_qkv.bias" in op.param_name: + dst_param = dst_params.get(op.param_name) + dst_shape = list(dst_param.shape) if dst_param is not None else "NOT FOUND" + print(f"[Rank {my_rank}] {op.param_name} slice={op.my_slice} <- rank {op.peer_rank}, dst_shape={dst_shape}") + + # Show linear_proj weight ops (row parallel - different sharding pattern) + print(f"[Rank {my_rank}] LINEAR_PROJ WEIGHT RECV ops:") + for op in plan.recv_ops: + if "linear_proj.weight" in op.param_name: + dst_param = dst_params.get(op.param_name) + dst_shape = list(dst_param.shape) if dst_param is not None else "NOT FOUND" + print(f"[Rank {my_rank}] {op.param_name} slice={op.my_slice} <- rank {op.peer_rank}, dst_shape={dst_shape}") + print(f"[Rank {my_rank}] ========== END RESHARD PLAN DEBUG ==========\n") + # Submit sends for op in plan.send_ops: src_param = src_params.get(op.param_name) diff --git a/tests/unit_tests/inference/test_nccl_model_swap.py b/tests/unit_tests/resharding/test_model_swap.py similarity index 92% rename from tests/unit_tests/inference/test_nccl_model_swap.py rename to tests/unit_tests/resharding/test_model_swap.py index a3aaf38d95f..30d05e87eed 100644 --- a/tests/unit_tests/inference/test_nccl_model_swap.py +++ b/tests/unit_tests/resharding/test_model_swap.py @@ -2,7 +2,7 @@ import copy import os import types -from typing import Optional, Tuple +from typing import Optional, Tuple, List import pytest import torch @@ -21,9 +21,18 @@ from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord from tests.unit_tests.test_utilities import Utils +try: + import nvshmem.core + has_nvshmem = True +except Exception: + has_nvshmem = False + + + def _build_pg_collection( tp_size: int, pp_size: int = None, ep_size: int = 1 ) -> ProcessGroupCollection: @@ -114,17 +123,31 @@ def _set_pg_collection(module, tp_group, dp_group): module.pg_collection = types.SimpleNamespace(tp=tp_group, dp=dp_group, ep=None, pp=None) return module -@pytest.mark.parametrize("refit_backend", ["nvshmem","nccl","gloo"]) +@pytest.mark.parametrize( + "refit_backend", + [ + pytest.param( + "nvshmem", + marks=pytest.mark.skipif( + not has_nvshmem, + reason="nvshmem.core is not available (NVSHMEM Python bindings not installed)", + ), + ), + "nccl", + "gloo", + ], +) @pytest.mark.parametrize( "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [ # TP only changes (2, 1, 1, 1, 1, 1, None), # TP2 -> TP1 (1, 1, 1, 2, 1, 1, None), # TP1 -> TP2 - # PP only changes + (2, 1, 1, 4, 1, 1, None), # TP2 -> TP4 + # # PP only changes (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 - # Both TP and PP change + # # Both TP and PP change (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 @@ -135,7 +158,7 @@ def _set_pg_collection(module, tp_group, dp_group): (1, 1, 2, 1, 2, 2, 4), ], ) -def test_nccl_swap_gpt_parametrized( +def test_swap_gpt_parametrized( refit_backend: str, src_tp: int, src_pp: int, @@ -164,16 +187,18 @@ def test_nccl_swap_gpt_parametrized( # Small GPT config seq_len = 8 vocab_size = 128 + # --group-query-attention --num-query-groups 8 cfg = TransformerConfig( num_layers=4 if (src_pp > 1 or dst_pp > 1) else 2, hidden_size=32, - num_attention_heads=4, + num_attention_heads=8, use_cpu_initialization=True, pipeline_dtype=torch.float32, hidden_dropout=0.0, attention_dropout=0.0, moe_router_dtype="fp64", moe_token_dispatcher_type="alltoall", + num_query_groups=4, ) # Build PGs and models (always use unified PG builder so we can set EP) From 1bbe010d379086fe4f426d0a908b612a238dd733 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 09:58:52 -0800 Subject: [PATCH 30/44] fix execution mistake --- megatron/core/resharding/execution.py | 35 ++------------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index cd24bdd3ea6..99cbf1f0c0f 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -25,43 +25,12 @@ def execute_reshard_plan( A communication service must be provided to abstract transport. Expected service API: submit_send(tensor, dest_rank), submit_recv(tensor, src_rank), run(). """ - my_rank = dist.get_rank() - + src_params = {name: p for name, p in src_module.named_parameters(recurse=True)} dst_params = {name: p for name, p in dst_module.named_parameters(recurse=True)} submit_send_with_id = getattr(service, "submit_send_with_id", None) submit_recv_with_id = getattr(service, "submit_recv_with_id", None) - # DEBUG: Print plan summary - focus on QKV weight for GQA debugging - if my_rank == 0: - print(f"\n[Rank {my_rank}] ========== RESHARD PLAN DEBUG ==========") - print(f"[Rank {my_rank}] Total send_ops: {len(plan.send_ops)}, recv_ops: {len(plan.recv_ops)}") - - # Show QKV weight ops specifically (these are critical for GQA) - print(f"[Rank {my_rank}] QKV WEIGHT RECV ops:") - for op in plan.recv_ops: - if "linear_qkv.weight" in op.param_name: - dst_param = dst_params.get(op.param_name) - dst_shape = list(dst_param.shape) if dst_param is not None else "NOT FOUND" - print(f"[Rank {my_rank}] {op.param_name} slice={op.my_slice} <- rank {op.peer_rank}, dst_shape={dst_shape}") - - # Show QKV bias ops - print(f"[Rank {my_rank}] QKV BIAS RECV ops:") - for op in plan.recv_ops: - if "linear_qkv.bias" in op.param_name: - dst_param = dst_params.get(op.param_name) - dst_shape = list(dst_param.shape) if dst_param is not None else "NOT FOUND" - print(f"[Rank {my_rank}] {op.param_name} slice={op.my_slice} <- rank {op.peer_rank}, dst_shape={dst_shape}") - - # Show linear_proj weight ops (row parallel - different sharding pattern) - print(f"[Rank {my_rank}] LINEAR_PROJ WEIGHT RECV ops:") - for op in plan.recv_ops: - if "linear_proj.weight" in op.param_name: - dst_param = dst_params.get(op.param_name) - dst_shape = list(dst_param.shape) if dst_param is not None else "NOT FOUND" - print(f"[Rank {my_rank}] {op.param_name} slice={op.my_slice} <- rank {op.peer_rank}, dst_shape={dst_shape}") - print(f"[Rank {my_rank}] ========== END RESHARD PLAN DEBUG ==========\n") - # Submit sends for op in plan.send_ops: src_param = src_params.get(op.param_name) @@ -96,4 +65,4 @@ def execute_reshard_plan( with torch.no_grad(): dst_param.data[dst_slice].copy_(recv_buffer) - logger.info("Reshard complete") + logger.info("Reshard complete") \ No newline at end of file From 83d8f4aad7d84191c8ebc3fd73ffd2e59362a5af Mon Sep 17 00:00:00 2001 From: William Dykas Date: Wed, 17 Dec 2025 11:42:11 -0800 Subject: [PATCH 31/44] verified with runs --- .../core/extensions/transformer_engine.py | 23 +++- .../abstract_model_inference_wrapper.py | 7 +- .../gpt/gpt_inference_wrapper.py | 6 +- megatron/core/resharding/utils.py | 10 ++ megatron/core/transformer/mlp.py | 6 + megatron/rl/inference/megatron.py | 6 +- megatron/rl/rl_utils.py | 111 ++++++++++++++++++ megatron/training/arguments.py | 10 +- megatron/training/training.py | 10 +- 9 files changed, 170 insertions(+), 19 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index e95409e08e9..f9f0cd0456f 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -469,6 +469,7 @@ def __init__( skip_weight_param_allocation: bool = False, tp_comm_buffer_name: Optional[str] = None, tp_group: Optional[torch.distributed.ProcessGroup] = None, + stride: int = 1, ): if not HAVE_TE: raise ImportError( @@ -559,6 +560,8 @@ def __init__( ), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce" extra_kwargs["symmetric_ar_type"] = self.config.symmetric_ar_type + self.stride = stride + super().__init__( in_features=input_size, out_features=output_size, @@ -583,6 +586,11 @@ def __init__( **extra_kwargs, ) + # Set proper partition_stride + setattr(self.weight, 'partition_stride', stride) + if bias and hasattr(self, 'bias') and self.bias is not None: + setattr(self.bias, 'partition_stride', stride) + if config.use_cpu_initialization: output_size_per_partition = divide(output_size, self.tp_size) _ = _initialize_affine_weight_cpu( @@ -592,7 +600,7 @@ def __init__( output_size_per_partition, 0, init_method=condition_init_method(config, init_method), - stride=1, + stride=stride, return_master_weight=False, rank=self.tp_rank, world_size=self.tp_size, @@ -602,7 +610,7 @@ def __init__( self.bias = Parameter( torch.empty(output_size_per_partition, dtype=config.params_dtype) ) - set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) with torch.no_grad(): self.bias.zero_() setattr(self.bias, "allreduce", True) @@ -659,6 +667,7 @@ def __init__( skip_weight_param_allocation: bool = False, tp_comm_buffer_name: Optional[str] = None, tp_group: Optional[torch.distributed.ProcessGroup] = None, + stride: int = 1, ): if not HAVE_TE: raise ImportError( @@ -671,6 +680,7 @@ def __init__( tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) world_size = get_pg_size(tp_group) rank = get_pg_rank(tp_group) + self.stride = stride super().__init__( input_size=input_size, @@ -691,6 +701,11 @@ def __init__( tp_group=tp_group, ) + # Set proper partition_stride + setattr(self.weight, 'partition_stride', stride) + if bias and hasattr(self, 'bias') and self.bias is not None: + setattr(self.bias, 'partition_stride', stride) + if config.use_cpu_initialization: output_size_per_partition = divide(output_size, world_size) _ = _initialize_affine_weight_cpu( @@ -700,7 +715,7 @@ def __init__( output_size_per_partition, 0, init_method=condition_init_method(config, init_method), - stride=1, + stride=stride, return_master_weight=False, rank=rank, world_size=world_size, @@ -710,7 +725,7 @@ def __init__( self.bias = Parameter( torch.empty(output_size_per_partition, dtype=config.params_dtype) ) - set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) with torch.no_grad(): self.bias.zero_() setattr(self.bias, "allreduce", True) diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py index 95d476a9f83..694de4ffcc5 100644 --- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -73,10 +73,7 @@ def __init__( self.inference_context = inference_context if pg_collection is None: - pg_collection = ProcessGroupCollection( - tp=parallel_state.get_tensor_model_parallel_group(), - pp=parallel_state.get_pipeline_model_parallel_group(), - ) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.tp_group = pg_collection.tp self.pp_group = pg_collection.pp @@ -365,7 +362,7 @@ def run_one_forward_step( """ # Check if we are in a PP model if not ( - parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + is_pipeline_first_stage(self.pp_group) and is_pipeline_last_stage(self.pp_group) ): tokens = inference_input["tokens"] current_batch_size, seq_len = self._get_batch_size_and_seq_len( diff --git a/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py index 430126816a7..ba89fbc2f6c 100644 --- a/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py @@ -12,6 +12,7 @@ ) from megatron.core.inference.utils import get_attention_mask from megatron.core.models.gpt import GPTModel +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.enums import AttnBackend from megatron.core.utils import get_model_config @@ -28,6 +29,8 @@ class GPTInferenceWrapper(AbstractModelInferenceWrapper): size, etc. inference_context (BaseInferenceContext): Manages KV cache, and tracks sequence/token/batch offsets. + pg_collection (ProcessGroupCollection): Process groups for model communication. + If not provided, defaults to global parallel state groups. """ def __init__( @@ -35,8 +38,9 @@ def __init__( model: GPTModel, inference_wrapper_config: InferenceWrapperConfig, inference_context: Optional[BaseInferenceContext] = None, + pg_collection: Optional[ProcessGroupCollection] = None, ): - super().__init__(model, inference_wrapper_config, inference_context) + super().__init__(model, inference_wrapper_config, inference_context, pg_collection) def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: """Prepares the inference input data. diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index 0a5658f5eef..b188063fa65 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -162,6 +162,16 @@ def extract_param_metadata( is_tp = bool(getattr(param, 'tensor_model_parallel', False)) partition_dim = int(getattr(param, 'partition_dim', 0)) partition_stride = int(getattr(param, 'partition_stride', 1)) + + # SwiGLU/GLU compatibility: For gated linear units, fc1 stores interleaved [gate, up] portions + # and requires partition_stride=2 for correct resharding. New models set this at construction + # time (MLP sets partition_stride=2 on weight when gated_linear_unit=True). For legacy models + # where stride=1 was left as default, we apply stride=2 as a fallback for fc1 parameters. + # This is safe because: (1) gated models need it, and (2) non-gated models have smaller fc1 + # and stride doesn't affect single-block transfers. + # if 'mlp.linear_fc1' in param_name and is_tp and partition_stride == 1: + # partition_stride = 2 + # EP detection: Megatron convention - expert params are not allreduced is_ep = not bool(getattr(param, 'allreduce', True)) diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 9602beb2f71..bd4e0447e59 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -101,8 +101,13 @@ def __init__( # If this is a gated linear unit we double the output width # see https://arxiv.org/pdf/2002.05202.pdf + # For GLU/SwiGLU, use stride=2 because each TP rank stores interleaved [gate, up] portions. + # This is critical for correct weight resharding across different TP sizes. if self.config.gated_linear_unit: ffn_hidden_size *= 2 + fc1_stride = 2 + else: + fc1_stride = 1 self.linear_fc1 = build_module( submodules.linear_fc1, @@ -116,6 +121,7 @@ def __init__( is_expert=is_expert, tp_comm_buffer_name="fc1", tp_group=tp_group, + stride=fc1_stride, ) if self.config.use_te_activation_func and not (submodules.activation_func is None): diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index 54acc112dd9..44471ece33d 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -155,10 +155,8 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen metrics_writer=metrics_writer, ) - inference_wrapped_model = GPTInferenceWrapper(model, args, inference_context) - - inference_wrapped_model.model_is_pipeline_parallel = not ( - is_pp_first_stage(pg_collection.pp) and is_pp_last_stage(pg_collection.pp) + inference_wrapped_model = GPTInferenceWrapper( + model, args, inference_context, pg_collection=pg_collection ) text_generation_controller = SimpleTextGenerationController( diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 6d49433ab30..0c3f44cf050 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -64,6 +64,98 @@ # Global variable to store packing context for forward_step _GLOBAL_PACKING_CONTEXT = None + +def verify_model_weights_swap( + train_model: LanguageModule, + inference_model: LanguageModule, + seq_len: int = 8, + batch_size: int = 2, + atol: float = 1e-4, + rtol: float = 1e-4, +) -> None: + """Verify that the inference model produces the same forward pass outputs + as the training model after the weights have been swapped. + + This function should be called after swap_model_weights to ensure the weight + transfer was successful. It runs a forward pass on both models and asserts + the outputs match. This is meant for debugging purposes only. + + Args: + train_model: The training model (source of weights). + inference_model: The inference model (target of weights). + seq_len: Sequence length for test input. + batch_size: Batch size for test input. + atol: Absolute tolerance for comparing outputs. + rtol: Relative tolerance for comparing outputs. + + Raises: + AssertionError: If forward pass outputs do not match within tolerance. + """ + args = get_args() + + # Unwrap models to get the core module + train_lm = train_model[0] if isinstance(train_model, (list, tuple)) else train_model + inf_lm = inference_model[0] if isinstance(inference_model, (list, tuple)) else inference_model + + train_core = unwrap_model(train_lm) + inf_core = unwrap_model(inf_lm) + + actual_vocab_size = getattr(args, 'padded_vocab_size', 128256) + actual_seq_len = min(seq_len, getattr(args, 'seq_length', seq_len)) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + # Generate deterministic test input - same across ALL ranks + torch.manual_seed(1234) + test_tokens = torch.randint( + low=0, high=actual_vocab_size, size=(batch_size, actual_seq_len), + device=device, dtype=torch.long + ) + test_position_ids = ( + torch.arange(actual_seq_len, device=device, dtype=torch.long) + .unsqueeze(0) + .expand(batch_size, -1) + ) + test_attention_mask = torch.ones( + (batch_size, 1, actual_seq_len, actual_seq_len), device=device, dtype=torch.bool + ) + + # Save and restore training state + train_was_training = train_core.training + inf_was_training = inf_core.training + + train_core.eval() + inf_core.eval() + + try: + with torch.no_grad(): + train_output = train_lm( + test_tokens, test_position_ids, test_attention_mask, + runtime_gather_output=True + ) + + inf_output = inf_lm( + test_tokens, test_position_ids, test_attention_mask, + runtime_gather_output=True + ) + + # Only check on ranks that have output (last PP stage) + if train_output is not None and inf_output is not None: + assert train_output.shape == inf_output.shape, ( + f"Output shape mismatch: train={train_output.shape}, infer={inf_output.shape}" + ) + + max_diff = (train_output - inf_output).abs().max().item() + assert torch.allclose(train_output, inf_output, atol=atol, rtol=rtol), ( + f"Forward pass outputs do not match: max_diff={max_diff:.6e}, atol={atol}, rtol={rtol}" + ) + + finally: + # Restore training state + if train_was_training: + train_core.train() + if inf_was_training: + inf_core.train() + GroupedRollouts = list[list[TokenRollout | Rollout]] @@ -660,6 +752,13 @@ def get_environment_rollouts( # If we have seperate training and inference models we to refit weights from the training model to the inference model. if inference_model is not None: swap_model_weights(model, inference_model, args.refit_method) + if args.rl_verify_model_weights_swap and args.curr_iteration == 0: + verify_model_weights_swap( + train_model=model, + inference_model=inference_model, + atol=.1, + rtol=5e-4, + ) else: inference_model = model @@ -1282,6 +1381,18 @@ def prepare_data_for_update( wandb_writer = get_wandb_writer() tb_writer = get_tensorboard_writer() nvtx_range = get_nvtx_range() + + # RL policy updates + logprob computations should run eagerly; only rollout generation + # (inference engine) should use CUDA graphs until training cuda-graphs MR goes in. + # In the single-model case this is naturally handled by `megatron_rl_inference_mode` + # toggling graphs on/off around inference. In the refit case (separate inference_model), + # we must explicitly keep the training model (this `model`) with CUDA graphs disabled, + # otherwise training/logprobs can get cudagraphed. + if args.cuda_graph_impl != "none": + lang_module = ( + model[0].module.module if hasattr(model[0].module, "module") else model[0].module + ) + toggle_cuda_graphs(lang_module, "none", reset_cuda_graphs=False) model = model[0] with nvtx_range("prepare-data-for-update"): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9acd1d5cd27..d505701f3c9 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2000,12 +2000,16 @@ def _add_rl_args(parser): 'round-robin: distribute bins cyclically across ranks for better load balancing') group.add_argument('--rl-inference-tensor-model-parallel-size', type=int, default=None, help='Degree of tensor model parallelism for inference for RL.') - group.add_argument('--refit-method', type=str, default='nccl', + group.add_argument('--refit-method', type=str, default='nvshmem', choices=['nccl', 'gloo', 'nvshmem'], help=('Method to refit the model weights between training and inference models during RL. ' - 'nccl: use NCCLCopyService to refit the model weights between training and inference models during RL; ' + 'nccl: use NCCLCopyService to refit using NCCL; ' 'gloo: use GlooCopyService over CPU; ' - 'nvshmem: use NVSHMEMCopyService to refit using the in-tree NVSHMEM copy service.')) + 'nvshmem: use NVSHMEMCopyService to refit using the NVSHMEM.')) + group.add_argument('--rl-verify-model-weights-swap', action=argparse.BooleanOptionalAction, default=False, + help='If set, verify that the model weights were correctly transferred by comparing forward pass outputs on' + 'the first swap of model weights.') + return parser def _add_training_args(parser): diff --git a/megatron/training/training.py b/megatron/training/training.py index 7f665023376..4b04d505a0f 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -718,8 +718,14 @@ def pretrain( dp_size = args.world_size // (tp_size * cp_size * pp_size) assert dp_size >= 1 and (tp_size * cp_size * pp_size * dp_size) == args.world_size, \ "World size must be divisible by tp*cp*pp for inference PG layout" - - grid = HyperCommGrid([tp_size, cp_size, ep_size, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + # Default mpu order is 'tp-cp-ep-dp-pp', unless use_tp_pp_dp_mapping is set. + grid_order = 'tp-cp-ep-pp-dp' if args.use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp' + if args.use_tp_pp_dp_mapping: + # Order: tp-cp-ep-pp-dp (pp before dp) + grid = HyperCommGrid([tp_size, cp_size, ep_size, pp_size, dp_size], ["tp", "cp", "ep", "pp", "dp"]) + else: + # Order: tp-cp-ep-dp-pp (dp before pp) - this is the default + grid = HyperCommGrid([tp_size, cp_size, ep_size, dp_size, pp_size], ["tp", "cp", "ep", "dp", "pp"]) tp_group = grid.create_pg("tp") cp_group = grid.create_pg("cp") pp_group = grid.create_pg("pp") From 6aa223e0856718eb636f2dcd8988ca5bc0a1b929 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Mon, 22 Dec 2025 07:05:02 -0800 Subject: [PATCH 32/44] fix merge --- megatron/core/extensions/transformer_engine.py | 5 ----- megatron/training/training.py | 1 - 2 files changed, 6 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 1a201d645ff..27a5ac6965e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -839,11 +839,6 @@ def __init__( ) self.te_quant_params: Optional[TEQuantizationParams] = None - # Set proper partition_stride - setattr(self.weight, 'partition_stride', stride) - if bias and hasattr(self, 'bias') and self.bias is not None: - setattr(self.bias, 'partition_stride', stride) - # Set proper partition_stride setattr(self.weight, 'partition_stride', stride) if bias and hasattr(self, 'bias') and self.bias is not None: diff --git a/megatron/training/training.py b/megatron/training/training.py index 6d72a3a746c..905367b8310 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1054,7 +1054,6 @@ def build_model(): # are set for all params so the optimizer can use them. for model_module in model: for param in model_module.parameters(): - #TODO(Peter) We need to use the proper models MPU here. tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) # Print number of parameters. From a4e7d4f02a227974c909964ec4abc75e0daea068 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 15:30:03 -0800 Subject: [PATCH 33/44] add offload --- megatron/core/inference/unified_memory.py | 288 +++++++++++++++++++++- megatron/rl/rl_utils.py | 49 +++- megatron/training/arguments.py | 31 +++ megatron/training/training.py | 43 +++- 4 files changed, 398 insertions(+), 13 deletions(-) diff --git a/megatron/core/inference/unified_memory.py b/megatron/core/inference/unified_memory.py index 56073df063f..bc0a9c72bde 100644 --- a/megatron/core/inference/unified_memory.py +++ b/megatron/core/inference/unified_memory.py @@ -1,12 +1,15 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import ctypes import os import signal +import threading import warnings from contextlib import contextmanager from enum import Enum, auto from pathlib import Path +import torch from torch.cuda.memory import CUDAPluggableAllocator from torch.utils.cpp_extension import CUDA_HOME, load_inline @@ -42,6 +45,10 @@ class UnifiedMemoryCompileTimeoutError(UnifiedMemoryUnsupportedError): _compilation_state = CompilationState.UNATTEMPTED _alloc = None # must remain global until process exit. _mod = None # must remain global until process exit. +_so_path = None # path to compiled extension .so (must remain global until exit). +_ctypes_lib = None # ctypes handle to compiled extension +_ctypes_lock = threading.Lock() +_compilation_error: str | None = None # store last failure reason for better error messages @contextmanager @@ -74,11 +81,19 @@ def _handler(signum, frame): def compile_allocator(): """Attempt to compile UVM allocator.""" - global _compilation_state, _alloc, _mod + global _compilation_state, _alloc, _mod, _so_path, _ctypes_lib, _compilation_error if _compilation_state != CompilationState.UNATTEMPTED: return + if not _has_mem_pool: + _compilation_state = CompilationState.FAILURE + _compilation_error = ( + "PyTorch does not expose CUDA MemPool on this build/version. " + "UVM mempool requires torch.cuda.MemPool or torch.cuda.memory.MemPool." + ) + return + _mempool_c_src = r""" #include #include @@ -134,6 +149,46 @@ def compile_allocator(): (void)size; (void)device; (void)stream; if (ptr) cudaFree(ptr); } + + // Prefetch managed memory to a device (or to CPU with cudaCpuDeviceId == -1). + EXPORT int managed_prefetch(void* ptr, size_t size, int device, void* stream) { + cudaStream_t s = (cudaStream_t)stream; + cudaError_t err = cudaMemPrefetchAsync(ptr, (size_t)size, device, s); + return (int)err; + } + + // Update preferred location advice for managed memory (GPU device id, or CPU with cudaCpuDeviceId == -1). + EXPORT int managed_advise_preferred_location(void* ptr, size_t size, int device) { + cudaError_t err; + #if CUDART_VERSION >= 13000 + cudaMemLocation location; + if (device == (int)-1) { + location.type = cudaMemLocationTypeHost; + location.id = 0; + } else { + location.type = cudaMemLocationTypeDevice; + location.id = device; + } + err = cudaMemAdvise(ptr, (size_t)size, cudaMemAdviseSetPreferredLocation, location); + #else + err = cudaMemAdvise(ptr, (size_t)size, cudaMemAdviseSetPreferredLocation, device); + #endif + return (int)err; + } + + // Ensure a device is in the page table for this managed region. + EXPORT int managed_advise_accessed_by(void* ptr, size_t size, int device) { + cudaError_t err; + #if CUDART_VERSION >= 13000 + cudaMemLocation location; + location.type = cudaMemLocationTypeDevice; + location.id = device; + err = cudaMemAdvise(ptr, (size_t)size, cudaMemAdviseSetAccessedBy, location); + #else + err = cudaMemAdvise(ptr, (size_t)size, cudaMemAdviseSetAccessedBy, device); + #endif + return (int)err; + } """ # Define a timeout of 30s for how long the build is allowed to run. @@ -160,14 +215,16 @@ def compile_allocator(): _cpa = CUDAPluggableAllocator(_so_path, "managed_malloc", "managed_free") _alloc = _cpa.allocator() _compilation_state = CompilationState.SUCCESS + _compilation_error = None except (RuntimeError, ImportError, OSError, UnifiedMemoryCompileTimeoutError) as e: + _compilation_error = str(e) warnings.warn(f"Failed to create unified memory mempool: '{e}'.") _compilation_state = CompilationState.FAILURE + _so_path = None + _ctypes_lib = None # Synchronize failure state across ranks. (For currently unknown reasons, # one rank can show as FAILURE while the remaining ranks show as SUCCESS.) - import torch - local_state = torch.tensor( [_compilation_state.value], dtype=torch.uint8, device=torch.cuda.current_device() ) @@ -193,6 +250,229 @@ def create_unified_mempool() -> "MemPool": # Return mempool. if _compilation_state != CompilationState.SUCCESS: - raise UnifiedMemoryUnsupportedError() + details = _compilation_error + if details is None: + details = "Unknown reason (allocator compilation did not succeed)." + raise UnifiedMemoryUnsupportedError( + "Unified virtual memory (UVM) mempool is unsupported or failed to initialize: " + + details + ) else: return MemPool(allocator=_alloc) + + +def _get_ctypes_lib() -> "ctypes.CDLL": + """Return a ctypes handle to the compiled UVM extension (.so).""" + global _ctypes_lib + compile_allocator() + if _compilation_state != CompilationState.SUCCESS or _so_path is None: + raise UnifiedMemoryUnsupportedError() + if _ctypes_lib is not None: + return _ctypes_lib + with _ctypes_lock: + if _ctypes_lib is None: + _ctypes_lib = ctypes.CDLL(_so_path) + # Configure argtypes/restype for exported helpers. + _ctypes_lib.managed_prefetch.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_void_p, + ] + _ctypes_lib.managed_prefetch.restype = ctypes.c_int + _ctypes_lib.managed_advise_preferred_location.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ] + _ctypes_lib.managed_advise_preferred_location.restype = ctypes.c_int + _ctypes_lib.managed_advise_accessed_by.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ] + _ctypes_lib.managed_advise_accessed_by.restype = ctypes.c_int + return _ctypes_lib + + +def prefetch_managed_tensor(tensor, *, device: int, stream=None) -> None: + """Prefetch a CUDA tensor allocated from the UVM mempool to a specific device. + + This uses `cudaMemPrefetchAsync` to physically migrate the pages backing the tensor. + The virtual address (pointer) remains unchanged, making this safe for use with + recorded CUDA graphs. + + Args: + tensor (torch.Tensor): CUDA tensor allocated from the UVM mempool. + device (int): Target device ID. Use -1 (cudaCpuDeviceId) to prefetch to CPU. + stream (torch.cuda.Stream, optional): Stream to use for the asynchronous prefetch. + Defaults to the current stream. + """ + if tensor is None: + return + if not isinstance(tensor, torch.Tensor): + raise TypeError("prefetch_managed_tensor expects a torch.Tensor") + if tensor.numel() == 0: + return + if not tensor.is_cuda: + raise ValueError("prefetch_managed_tensor expects a CUDA tensor") + + lib = _get_ctypes_lib() + nbytes = tensor.nbytes + if stream is None: + stream = torch.cuda.current_stream() + # torch.cuda.Stream exposes a cuda_stream integer handle. + stream_ptr = ctypes.c_void_p(int(stream.cuda_stream)) + err = lib.managed_prefetch( + ctypes.c_void_p(int(tensor.data_ptr())), ctypes.c_size_t(nbytes), int(device), stream_ptr + ) + if err != 0: + raise RuntimeError(f"cudaMemPrefetchAsync failed with cudaError={err}") + + +def advise_managed_tensor_preferred_location(tensor, *, device: int) -> None: + """Set the preferred physical location hint for a managed tensor. + + This uses `cudaMemAdviseSetPreferredLocation`. It tells the CUDA driver where the + pages should ideally reside. Unlike prefetch, this is a hint and does not + immediately trigger migration unless the driver decides it is necessary. + + Args: + tensor (torch.Tensor): CUDA tensor allocated from the UVM mempool. + device (int): Preferred device ID. Use -1 (cudaCpuDeviceId) for CPU. + """ + if tensor is None: + return + if not isinstance(tensor, torch.Tensor): + raise TypeError("advise_managed_tensor_preferred_location expects a torch.Tensor") + if tensor.numel() == 0: + return + if not tensor.is_cuda: + raise ValueError("advise_managed_tensor_preferred_location expects a CUDA tensor") + + lib = _get_ctypes_lib() + nbytes = tensor.nbytes + err = lib.managed_advise_preferred_location( + ctypes.c_void_p(int(tensor.data_ptr())), ctypes.c_size_t(nbytes), int(device) + ) + if err != 0: + raise RuntimeError(f"cudaMemAdviseSetPreferredLocation failed with cudaError={err}") + + +def advise_managed_tensor_accessed_by(tensor, *, device: int) -> None: + """Hint that a specific device will access the managed tensor. + + This uses `cudaMemAdviseSetAccessedBy`. It ensures that the mapping for this + memory region is established in the page tables of the specified device, + reducing page fault latency when the device first touches the data. + + Args: + tensor (torch.Tensor): CUDA tensor allocated from the UVM mempool. + device (int): Device ID that will access the tensor. Must be a GPU ID. + """ + if tensor is None: + return + if not isinstance(tensor, torch.Tensor): + raise TypeError("advise_managed_tensor_accessed_by expects a torch.Tensor") + if tensor.numel() == 0: + return + if not tensor.is_cuda: + raise ValueError("advise_managed_tensor_accessed_by expects a CUDA tensor") + + lib = _get_ctypes_lib() + nbytes = tensor.nbytes + err = lib.managed_advise_accessed_by( + ctypes.c_void_p(int(tensor.data_ptr())), ctypes.c_size_t(nbytes), int(device) + ) + if err != 0: + raise RuntimeError(f"cudaMemAdviseSetAccessedBy failed with cudaError={err}") + + +def prefetch_managed_module_parameters( + module, *, device: int, include_buffers: bool = False +) -> int: + """Prefetch all UVM-allocated parameters (and optionally buffers) of a module. + + Iterates through all parameters of the module and initiates an asynchronous + migration to the target device. This is typically used to offload weights to + CPU during training or prefetch them to GPU before inference. + + Args: + module (torch.nn.Module): The module containing UVM parameters. + device (int): Target device ID (-1 for CPU). + include_buffers (bool, optional): Whether to also prefetch module buffers. + Defaults to False. + + Returns: + int: The total number of bytes for which prefetch was initiated. + """ + if module is None: + return 0 + + # Avoid duplicate prefetch on shared tensors. + seen_ptrs: set[int] = set() + total_nbytes = 0 + + def _iter_tensors(): + for p in module.parameters(recurse=True): + if p is None: + continue + yield p.data + if include_buffers: + for b in module.buffers(recurse=True): + if b is None: + continue + yield b + + stream = torch.cuda.current_stream() + for t in _iter_tensors(): + if not isinstance(t, torch.Tensor) or not t.is_cuda or t.numel() == 0: + continue + ptr = int(t.data_ptr()) + if ptr in seen_ptrs: + continue + seen_ptrs.add(ptr) + nbytes = t.nbytes + prefetch_managed_tensor(t, device=device, stream=stream) + total_nbytes += nbytes + return total_nbytes + + +def advise_managed_module_parameters_preferred_location( + module, *, device: int, include_buffers: bool = False +) -> None: + """Set the preferred physical location hint for all UVM parameters in a module. + + Args: + module (torch.nn.Module): The module containing UVM parameters. + device (int): Preferred device ID (-1 for CPU). + include_buffers (bool, optional): Whether to also advise on module buffers. + Defaults to False. + """ + if module is None: + return + + seen_ptrs: set[int] = set() + for p in module.parameters(recurse=True): + if p is None: + continue + t = p.data + if not isinstance(t, torch.Tensor) or not t.is_cuda or t.numel() == 0: + continue + ptr = int(t.data_ptr()) + if ptr in seen_ptrs: + continue + seen_ptrs.add(ptr) + advise_managed_tensor_preferred_location(t, device=device) + + if include_buffers: + for b in module.buffers(recurse=True): + if b is None: + continue + if not isinstance(b, torch.Tensor) or not b.is_cuda or b.numel() == 0: + continue + ptr = int(b.data_ptr()) + if ptr in seen_ptrs: + continue + seen_ptrs.add(ptr) + advise_managed_tensor_preferred_location(b, device=device) diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 94e12f7945d..3039d045aca 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -38,6 +38,10 @@ from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord from megatron.core.transformer.utils import toggle_cuda_graphs from megatron.core.resharding.refit import swap_model_weights +from megatron.core.inference.unified_memory import ( + advise_managed_module_parameters_preferred_location, + prefetch_managed_module_parameters, +) from megatron.core.utils import get_asyncio_loop, log_single_rank from megatron.rl.sequence_packing_utils import ( get_microbatch_dataloader, @@ -71,7 +75,12 @@ get_wandb_writer, ) from megatron.training.tokenizer.tokenizer import CustomTikTokenizer, _HuggingFaceTokenizer -from megatron.training.utils import get_ltor_masks_and_position_ids, get_nvtx_range, print_rank_0, unwrap_model +from megatron.training.utils import ( + get_ltor_masks_and_position_ids, + get_nvtx_range, + print_rank_0, + unwrap_model, +) from megatron.core.utils import get_pg_size, get_attr_wrapped_model from megatron.core.process_groups_config import ProcessGroupCollection from wandb import wandb_run @@ -85,6 +94,29 @@ _GLOBAL_PACKING_CONTEXT = None +def _maybe_prefetch_separate_inference_model_weights(model_core, *, to_cpu: bool) -> None: + """Prefetch RL *separate inference model* weights to CPU/GPU (UVM-only path). + + Gated only by user args; this assumes the separate inference model was allocated with UVM when enabled. + """ + args = get_args() + if not args.rl_offload_inference_model_weights_when_idle: + return + if args.rl_inference_model_unified_memory_level != 1: + return + + device = -1 if to_cpu else int(torch.cuda.current_device()) + advise_managed_module_parameters_preferred_location(model_core, device=device, include_buffers=True) + nbytes = prefetch_managed_module_parameters(model_core, device=device, include_buffers=True) + # Ensure pages are resident before we enter CUDA-graph capture / inference, or before training continues. + torch.cuda.synchronize() + + if to_cpu: + print_rank_0(f"[Rank 0] offloaded {nbytes / 1024**2:.2f} MB of separate RL inference model weights to CPU (other ranks may vary)") + else: + print_rank_0(f"[Rank 0] prefetched {nbytes / 1024**2:.2f} MB of separate RL inference model weights to GPU (other ranks may vary)") + + def verify_model_weights_swap( train_model: LanguageModule, inference_model: LanguageModule, @@ -428,6 +460,11 @@ def get_environment_rollouts( # If we have seperate training and inference models we to refit weights from the training model to the inference model. if inference_model is not None: + # If the separate inference model weights were prefetched to CPU while idle, bring them + # back to GPU before refit/copy and before any CUDA-graph'd inference. + with nvtx_range("prefetch-inference-model-weights-to-gpu"): + inf_core = unwrap_model(inference_model[0]) + _maybe_prefetch_separate_inference_model_weights(inf_core, to_cpu=False) swap_model_weights(model, inference_model, args.refit_method) if args.rl_verify_model_weights_swap and args.curr_iteration == 0: verify_model_weights_swap( @@ -1535,6 +1572,11 @@ def megatron_rl_inference_mode( lang_module = model[0].module.module if hasattr(model[0].module, "module") else model[0].module lang_module.eval() + # If this is a separate RL inference model allocated with UVM, ensure weights are resident on GPU + # before any CUDA-graph capture/replay or inference. + with nvtx_range("prefetch-inference-model-weights-to-gpu"): + model_core = unwrap_model(model[0]) + _maybe_prefetch_separate_inference_model_weights(model_core, to_cpu=False) rotary_module = getattr(lang_module, "rotary_pos_emb", None) # Vanilla RotaryEmbedding module has lru_cache decorator which breaks RL training @@ -1602,6 +1644,11 @@ def megatron_rl_inference_mode( if cuda_graph_impl != "none": toggle_cuda_graphs(lang_module, 'none', reset_cuda_graphs=reset_cuda_graphs) + # If this is a separate RL inference model, prefetch weights back to CPU so they don't consume + # GPU memory during training. + with nvtx_range("prefetch-inference-model-weights-to-cpu"): + _maybe_prefetch_separate_inference_model_weights(model_core, to_cpu=True) + if offload_optimizer_during_inference: with nvtx_range("onload-optimizer-after-inference"): optimizer.restore_from_cpu() diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 012cf4062a6..cff2d64716b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -393,6 +393,14 @@ def validate_args(args, defaults={}): assert not (args.rl_partial_rollouts and args.rl_remove_kv_cache_during_training), \ "Cannot use both partial-rollouts and remove-kv-cache-during-training" + assert not ( + args.rl_offload_inference_model_weights_when_idle + and args.rl_inference_model_unified_memory_level != 1 + ), ( + "--rl-offload-inference-model-weights-when-idle requires " + "--rl-inference-model-unified-memory-level=1." + ) + args.grpo_samples_per_iteration = args.grpo_prompts_per_step * args.grpo_group_size num_generated_samples_per_inference_iteration = ( args.grpo_samples_per_iteration * args.grpo_iterations) @@ -2081,6 +2089,29 @@ def _add_rl_args(parser): 'round-robin: distribute bins cyclically across ranks for better load balancing') group.add_argument('--rl-inference-tensor-model-parallel-size', type=int, default=None, help='Degree of tensor model parallelism for inference for RL.') + group.add_argument( + '--rl-inference-model-unified-memory-level', + type=int, + default=0, + choices=[0, 1], + help=( + 'Allocate the separate RL inference model parameters from a unified virtual memory (UVM) ' + 'CUDA mempool. Level 0 disables UVM (default). Level 1 enables UVM allocation so the ' + 'inference model weights can be prefetched to CPU when idle while keeping CUDA-graph-safe ' + 'device pointers.' + ), + ) + group.add_argument( + '--rl-offload-inference-model-weights-when-idle', + action=argparse.BooleanOptionalAction, + required=False, + default=False, + help=( + 'When using a separate RL inference model with UVM-enabled parameters, prefetch its weights ' + 'to CPU when not doing rollout inference, and prefetch back to GPU right before inference. ' + 'Requires --rl-inference-model-unified-memory-level=1.' + ), + ) group.add_argument('--refit-method', type=str, default='nvshmem', choices=['nccl', 'gloo', 'nvshmem'], help=('Method to refit the model weights between training and inference models during RL. ' diff --git a/megatron/training/training.py b/megatron/training/training.py index 905367b8310..de51a204003 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -13,6 +13,7 @@ import os import sys from typing import Any, Optional +from contextlib import nullcontext import torch.distributed @@ -78,7 +79,10 @@ from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP from megatron.core.optimizer.optimizer import param_group_identifier_keys from megatron.core.transformer.custom_layers.batch_invariant_kernels import enable_batch_invariant_mode - +from megatron.core.inference.unified_memory import ( + advise_managed_module_parameters_preferred_location, + prefetch_managed_module_parameters, + ) from megatron.core.optimizer.qk_clip import clip_qk try: @@ -112,6 +116,7 @@ destroy_model_parallel, update_pg_timeout ) +from megatron.core.inference.unified_memory import create_unified_mempool from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.num_microbatches_calculator import ( @@ -788,16 +793,38 @@ def pretrain( # Build an isolated inference config so training config remains unchanged inference_config = copy.deepcopy(config) inference_config.tensor_model_parallel_size = args.rl_inference_tensor_model_parallel_size - - inference_model = get_model( - model_provider, - model_type, - wrap_with_ddp=False, - pg_collection=inference_pg_collection, - config=inference_config, + + # Optionally allocate the RL inference model weights from a unified virtual memory (UVM) + # mempool so we can prefetch weights to CPU when idle while keeping CUDA-graph-safe pointers. + uvm_mempool = None + uvm_level = args.rl_inference_model_unified_memory_level + if uvm_level and uvm_level > 0: + uvm_mempool = create_unified_mempool() + + mempool_ctx = ( + torch.cuda.use_mem_pool(uvm_mempool) if uvm_mempool is not None else nullcontext() ) + with mempool_ctx: + inference_model = get_model( + model_provider, + model_type, + wrap_with_ddp=False, + pg_collection=inference_pg_collection, + config=inference_config, + ) inference_model[0].eval() + # If requested, immediately prefetch weights to CPU to keep them off GPU when idle. + if ( + uvm_mempool is not None + and args.rl_offload_inference_model_weights_when_idle + ): + inference_core = unwrap_model(inference_model[0]) + advise_managed_module_parameters_preferred_location(inference_core, device=-1, include_buffers=True) + nbytes = prefetch_managed_module_parameters(inference_core, device=-1, include_buffers=True) + torch.cuda.synchronize() + print_rank_0(f"[Rank 0] initially offloaded {nbytes / 1024**2:.2f} MB of separate RL inference model weights to CPU (other ranks may vary)") + # Data stuff. app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms() From 23ec73dbd33d107e0a66a65d28c9b551fce4597c Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 15:47:09 -0800 Subject: [PATCH 34/44] lint --- .../copy_services/gloo_copy_service.py | 8 +- .../copy_services/nvshmem_copy_service.py | 36 ++------ megatron/core/resharding/execution.py | 2 +- .../nvshmem_copy_service/__init__.py | 11 +-- .../nvshmem_copy_service/core/__init__.py | 2 - .../core/gpu_resource_manager.py | 15 +--- .../core/kernel_launcher.py | 15 +--- .../core/pipeline_executor.py | 46 +++------- .../resharding/nvshmem_copy_service/logger.py | 5 +- .../nvshmem_copy_service/memory/__init__.py | 2 - .../memory/double_buffer_manager.py | 8 +- .../memory/tensor_pointer_utils.py | 4 +- .../nvshmem_copy_service/nvshmem_types.py | 4 +- .../nvshmem_copy_service/planning/__init__.py | 9 +- .../planning/communication_scheduler.py | 49 +++------- .../planning/gpu_execution_planner.py | 26 ++---- .../planning/task_segmenter.py | 4 +- .../planning/workload_packer.py | 24 ++--- .../nvshmem_copy_service/service.py | 90 +++++-------------- .../nvshmem_copy_service/validation.py | 19 +--- megatron/core/resharding/planner.py | 4 +- megatron/core/resharding/utils.py | 8 +- .../unit_tests/resharding/test_model_swap.py | 12 +-- 23 files changed, 98 insertions(+), 305 deletions(-) diff --git a/megatron/core/resharding/copy_services/gloo_copy_service.py b/megatron/core/resharding/copy_services/gloo_copy_service.py index ebdc05e8bde..ee27531caf6 100644 --- a/megatron/core/resharding/copy_services/gloo_copy_service.py +++ b/megatron/core/resharding/copy_services/gloo_copy_service.py @@ -56,11 +56,15 @@ def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): # Allocate a CPU buffer that matches the destination view; we'll # copy into dest_tensor after the Gloo recv completes. cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() - self.recv_ops.append((RecvOp(task_id=None, tensor=cpu_buffer, src_rank=src_rank), dest_tensor)) + self.recv_ops.append( + (RecvOp(task_id=None, tensor=cpu_buffer, src_rank=src_rank), dest_tensor) + ) def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() - self.recv_ops.append((RecvOp(task_id=task_id, tensor=cpu_buffer, src_rank=src_rank), dest_tensor)) + self.recv_ops.append( + (RecvOp(task_id=task_id, tensor=cpu_buffer, src_rank=src_rank), dest_tensor) + ) def run(self): total_ops = len(self.send_ops) + len(self.recv_ops) diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py index b3e46deef6b..cc9a65e15b9 100644 --- a/megatron/core/resharding/copy_services/nvshmem_copy_service.py +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -28,9 +28,7 @@ class NVSHMEMCopyService(CopyService): def __init__(self): if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed must be initialized before NVSHMEMCopyService()" - ) + raise RuntimeError("torch.distributed must be initialized before NVSHMEMCopyService()") self.rank = dist.get_rank() self._remote = RemoteCopyService() @@ -50,9 +48,7 @@ def _ensure_initialized(self): self._remote.init(log_level="INFO") self._initialized = True logger.info( - "NVSHMEMCopyService initialized: PE %d / %d", - self._remote.my_pe, - self._remote.n_pes, + "NVSHMEMCopyService initialized: PE %d / %d", self._remote.my_pe, self._remote.n_pes ) def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): @@ -81,12 +77,7 @@ def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): # a small adapter layer that batches up matched send/recv slices. # - def submit_send_with_id( - self, - task_id: int, - src_tensor: torch.Tensor, - dest_rank: int, - ): + def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): """Register a send with an explicit, globally shared task_id.""" self._ensure_initialized() @@ -111,19 +102,10 @@ def submit_send_with_id( # Use public API on RemoteCopyService self._remote.register_send( - task_id=task_id, - src_tensor=src_bytes, - src_pos=0, - size=num_bytes, - dest_pe=dest_rank, + task_id=task_id, src_tensor=src_bytes, src_pos=0, size=num_bytes, dest_pe=dest_rank ) - def submit_recv_with_id( - self, - task_id: int, - dest_tensor: torch.Tensor, - src_rank: int, - ): + def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): """Register a recv with an explicit, globally shared task_id.""" self._ensure_initialized() @@ -147,11 +129,7 @@ def submit_recv_with_id( ) self._remote.register_receive( - task_id=task_id, - dest_tensor=dst_bytes, - dest_pos=0, - size=num_bytes, - src_pe=src_rank, + task_id=task_id, dest_tensor=dst_bytes, dest_pos=0, size=num_bytes, src_pe=src_rank ) def run(self): @@ -202,5 +180,3 @@ def run(self): self._remote.run() self._remote.clear_requests() logger.info("NVSHMEMCopyService: NVSHMEM transfers complete") - - diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index 99cbf1f0c0f..f911bbbff8a 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -65,4 +65,4 @@ def execute_reshard_plan( with torch.no_grad(): dst_param.data[dst_slice].copy_(recv_buffer) - logger.info("Reshard complete") \ No newline at end of file + logger.info("Reshard complete") diff --git a/megatron/core/resharding/nvshmem_copy_service/__init__.py b/megatron/core/resharding/nvshmem_copy_service/__init__.py index 2019c518039..462797f0998 100644 --- a/megatron/core/resharding/nvshmem_copy_service/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/__init__.py @@ -6,16 +6,11 @@ can use it without relying on an external library. """ -from .service import RemoteCopyService from . import nvshmem_types from .core import GPUResourceManager, KernelLauncher, PipelineExecutor from .memory import DoubleBufferManager, TensorPointerExtractor -from .planning import ( - CommunicationScheduler, - GPUExecutionPlanner, - TaskSegmenter, - WorkloadPacker, -) +from .planning import CommunicationScheduler, GPUExecutionPlanner, TaskSegmenter, WorkloadPacker +from .service import RemoteCopyService __all__ = [ "RemoteCopyService", @@ -30,5 +25,3 @@ "TaskSegmenter", "WorkloadPacker", ] - - diff --git a/megatron/core/resharding/nvshmem_copy_service/core/__init__.py b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py index 41ca4bad9b6..7eb931f45a9 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py @@ -5,5 +5,3 @@ from .pipeline_executor import PipelineExecutor __all__ = ["GPUResourceManager", "KernelLauncher", "PipelineExecutor"] - - diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index 2e95b7f75a4..34fb7ff1178 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -5,7 +5,7 @@ and event lifecycle. """ -from typing import Optional, Dict +from typing import Dict, Optional import nvshmem.core import torch @@ -49,8 +49,7 @@ def init(self) -> None: # torch.distributed must be initialized before calling this if not dist.is_initialized(): raise RuntimeError( - "torch.distributed must be initialized before " - "GPUResourceManager.init()" + "torch.distributed must be initialized before " "GPUResourceManager.init()" ) # Get current CUDA device (already set by caller based on LOCAL_RANK) @@ -165,12 +164,8 @@ def create_events(self, num_events: int = 2): Returns: tuple: (pack_events, unpack_events) lists of torch.cuda.Event """ - pack_events = [ - torch.cuda.Event(enable_timing=False) for _ in range(num_events) - ] - unpack_events = [ - torch.cuda.Event(enable_timing=False) for _ in range(num_events) - ] + pack_events = [torch.cuda.Event(enable_timing=False) for _ in range(num_events)] + unpack_events = [torch.cuda.Event(enable_timing=False) for _ in range(num_events)] return pack_events, unpack_events def finalize(self) -> None: @@ -179,5 +174,3 @@ def finalize(self) -> None: self.my_pe = -1 self.n_pes = -1 # Streams are automatically cleaned up when objects are deleted - - diff --git a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py index 042f2c81608..c5db32b5010 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py @@ -5,7 +5,7 @@ """ import os -from typing import Tuple, Optional, Any +from typing import Any, Optional, Tuple import cupy as cp import torch @@ -24,20 +24,13 @@ def __init__(self): def load_kernels(self) -> None: """Load and compile CUDA kernels from source.""" current_dir = os.path.dirname(os.path.abspath(__file__)) - kernel_path = os.path.join( - current_dir, - "..", - "kernels", - "chunked_kernel.cu", - ) + kernel_path = os.path.join(current_dir, "..", "kernels", "chunked_kernel.cu") with open(kernel_path, "r") as f: kernel_source = f.read() self.chunked_copy_kernel = cp.RawKernel( - kernel_source, - "chunked_batched_copy_kernel", - options=("-std=c++11",), + kernel_source, "chunked_batched_copy_kernel", options=("-std=c++11",) ) def set_streams(self, pack_stream, unpack_stream) -> None: @@ -141,5 +134,3 @@ def launch_unpack( nvtx.range_pop() # Record event on PyTorch stream unpack_event.record(stream=torch_unpack_stream) - - diff --git a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py index bcd43ea1da2..4728e598b47 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py @@ -5,25 +5,22 @@ and proper stream synchronization. """ -from typing import List, Dict, Optional +from typing import Dict, List, Optional import nvshmem.core import torch from ..logger import PELogger -from ..nvshmem_types import SendRequest, ReceiveRequest, ScheduledBatch -from .kernel_launcher import KernelLauncher from ..memory.double_buffer_manager import DoubleBufferManager +from ..nvshmem_types import ReceiveRequest, ScheduledBatch, SendRequest +from .kernel_launcher import KernelLauncher class PipelineExecutor: """Executes pipelined NVSHMEM communication with pack/send/unpack overlap.""" def __init__( - self, - kernel_launcher: KernelLauncher, - buffer_manager: DoubleBufferManager, - my_pe: int, + self, kernel_launcher: KernelLauncher, buffer_manager: DoubleBufferManager, my_pe: int ): """ Initialize pipeline executor. @@ -77,9 +74,7 @@ def set_events(self, pack_events: List, unpack_events: List): self.unpack_events = unpack_events def execute_pipeline( - self, - iter_schedules: List[Dict[str, Optional[ScheduledBatch]]], - num_iterations: int, + self, iter_schedules: List[Dict[str, Optional[ScheduledBatch]]], num_iterations: int ) -> None: """ Execute pipelined communication. @@ -109,10 +104,7 @@ def execute_pipeline( torch.cuda.nvtx.range_push(f"Iteration {i}") has_send = iter_schedules[i]["send"] is not None has_recv = iter_schedules[i]["recv"] is not None - has_next_send = ( - i + 1 < num_iterations - and iter_schedules[i + 1]["send"] is not None - ) + has_next_send = i + 1 < num_iterations and iter_schedules[i + 1]["send"] is not None has_prior_recv = i > 0 and iter_schedules[i - 1]["recv"] is not None slot = i % 2 @@ -130,9 +122,7 @@ def execute_pipeline( if has_recv else "" ) - PELogger.debug( - f"Iteration {i}/{num_iterations}: slot={slot}{send_info}{recv_info}" - ) + PELogger.debug(f"Iteration {i}/{num_iterations}: slot={slot}{send_info}{recv_info}") # Step 1: Pack NEXT iteration (async) if has_next_send: @@ -164,9 +154,7 @@ def execute_pipeline( batch = iter_schedules[i]["send"] assert batch is not None transfer_size = batch.total_size - PELogger.debug( - f" Send current: {transfer_size} bytes → PE {batch.dest_pe}" - ) + PELogger.debug(f" Send current: {transfer_size} bytes → PE {batch.dest_pe}") nvshmem.core.put( self.buffer_manager.recv_slots[slot][0:transfer_size], @@ -233,9 +221,7 @@ def _launch_unpack(self, iteration: int, batch: ScheduledBatch) -> None: ) def process_self_moves( - self, - send_requests: List[SendRequest], - receive_requests: List[ReceiveRequest], + self, send_requests: List[SendRequest], receive_requests: List[ReceiveRequest] ) -> None: """ Handle same-PE transfers (where src_pe == dest_pe == my_pe). @@ -247,12 +233,8 @@ def process_self_moves( receive_requests: List of receive requests """ # Match send/recv requests where src_pe == dest_pe == my_pe - local_sends = { - r.task_id: r for r in send_requests if r.dest_pe == self.my_pe - } - local_recvs = [ - r for r in receive_requests if r.src_pe == self.my_pe - ] + local_sends = {r.task_id: r for r in send_requests if r.dest_pe == self.my_pe} + local_recvs = [r for r in receive_requests if r.src_pe == self.my_pe] if local_recvs: PELogger.debug(f"Processing {len(local_recvs)} self-moves") @@ -263,9 +245,7 @@ def process_self_moves( if recv_req.task_id in local_sends: send_req = local_sends[recv_req.task_id] PELogger.debug( - " Self-move: task_id=%d, size=%d bytes", - recv_req.task_id, - send_req.size, + " Self-move: task_id=%d, size=%d bytes", recv_req.task_id, send_req.size ) # Create views of the tensors with offsets @@ -285,5 +265,3 @@ def process_self_moves( if num_processed > 0: PELogger.info("Self-moves complete: %d transfers", num_processed) - - diff --git a/megatron/core/resharding/nvshmem_copy_service/logger.py b/megatron/core/resharding/nvshmem_copy_service/logger.py index 3523f3dd5b4..d81139e039e 100644 --- a/megatron/core/resharding/nvshmem_copy_service/logger.py +++ b/megatron/core/resharding/nvshmem_copy_service/logger.py @@ -13,7 +13,8 @@ - Support for TRACE, DEBUG, INFO, SUMMARY, WARN, ERROR levels """ -#TODO(Peter): We need to remove this logger and use the regular Megatron logger. + +# TODO(Peter): We need to remove this logger and use the regular Megatron logger. import logging import os @@ -206,5 +207,3 @@ def shutdown(cls): handler.close() cls._logger.handlers.clear() cls._logger = None - - diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py index 5c9f8b573f4..b33fcca198c 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py @@ -4,5 +4,3 @@ from .tensor_pointer_utils import TensorPointerExtractor __all__ = ["DoubleBufferManager", "TensorPointerExtractor"] - - diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py index 314db889385..254b3f49c65 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py @@ -28,12 +28,10 @@ def allocate(self) -> None: """Allocate NVSHMEM symmetric buffers for double-buffering.""" for i in range(2): self.send_slots[i] = nvshmem.core.interop.torch.bytetensor( - (self.slot_size,), - dtype=torch.uint8, + (self.slot_size,), dtype=torch.uint8 ) self.recv_slots[i] = nvshmem.core.interop.torch.bytetensor( - (self.slot_size,), - dtype=torch.uint8, + (self.slot_size,), dtype=torch.uint8 ) # Zero out buffers self.send_slots[i].zero_() @@ -72,5 +70,3 @@ def free(self) -> None: if self.recv_slots[i] is not None: nvshmem.core.interop.torch.free_tensor(self.recv_slots[i]) self.recv_slots[i] = None - - diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py index f39dbb0ae95..70cc54edd3f 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py @@ -22,7 +22,7 @@ def get_pointer(tensor: Any) -> int: Returns: int: Memory address of the tensor data - + Examples: >>> import torch @@ -41,5 +41,3 @@ def get_pointer(tensor: Any) -> int: return tensor.data.ptr else: # Assume raw integer pointer return tensor - - diff --git a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py index e83dbc51d60..c471e8d63c4 100644 --- a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py +++ b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Any +from typing import Any, List # Constants MAX_SEGMENT_SIZE = 256 * 1024 * 1024 # 256MB @@ -57,5 +57,3 @@ class TransferMetadata: sizes: Any # cupy array of uint64 (sizes) num_tasks: int total_size: int - - diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py index d00914b6ef0..0f858b61edf 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py @@ -5,11 +5,4 @@ from .task_segmenter import TaskSegmenter from .workload_packer import WorkloadPacker -__all__ = [ - "CommunicationScheduler", - "GPUExecutionPlanner", - "TaskSegmenter", - "WorkloadPacker", -] - - +__all__ = ["CommunicationScheduler", "GPUExecutionPlanner", "TaskSegmenter", "WorkloadPacker"] diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py index d70eb559ce5..05450f5767e 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -3,7 +3,7 @@ import torch from ..logger import PELogger -from ..nvshmem_types import WorkloadGroup, ScheduledBatch, WorkloadSummary +from ..nvshmem_types import ScheduledBatch, WorkloadGroup, WorkloadSummary class CommunicationScheduler: @@ -16,10 +16,7 @@ def __init__(self): self.num_iterations = 0 def build_schedule( - self, - workloads: Dict[int, List[WorkloadGroup]], - my_pe: int, - n_pes: int, + self, workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int ) -> Tuple[Dict[int, List[ScheduledBatch]], Dict[Tuple[int, int, int], WorkloadSummary]]: """ Main scheduling method. @@ -30,9 +27,7 @@ def build_schedule( - global workload summaries (key: (src, dest, batch_idx) -> summary) """ total_local_batches = sum(len(groups) for groups in workloads.values()) - PELogger.info( - f"Building schedule: {total_local_batches} local batches, {n_pes} PEs" - ) + PELogger.info(f"Building schedule: {total_local_batches} local batches, {n_pes} PEs") # Step 1: Collect all batches across all PE pairs PELogger.debug("Collecting batches from all PEs...") @@ -47,17 +42,11 @@ def build_schedule( # Step 3: Exchange detailed workload summaries (Task IDs/Sizes) # This is needed for receivers to know what tasks are in each batch PELogger.debug("Exchanging workload summaries...") - global_summaries = self._exchange_workload_summaries( - workloads, - my_pe, - n_pes, - ) + global_summaries = self._exchange_workload_summaries(workloads, my_pe, n_pes) PELogger.debug(f"Exchanged {len(global_summaries)} workload summaries") # Step 4: Build schedule map for this PE - my_batches = [ - b for b in all_batches if b.src_pe == my_pe or b.dest_pe == my_pe - ] + my_batches = [b for b in all_batches if b.src_pe == my_pe or b.dest_pe == my_pe] my_batches.sort(key=lambda x: x.iteration) final_schedule: Dict[int, List[ScheduledBatch]] = {} @@ -67,10 +56,7 @@ def build_schedule( return final_schedule, global_summaries def _collect_all_batches( - self, - workloads: Dict[int, List[WorkloadGroup]], - my_pe: int, - n_pes: int, + self, workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int ) -> List[ScheduledBatch]: """ Exchanges batch counts and details with all PEs to build a global view. @@ -100,12 +86,7 @@ def _collect_all_batches( continue for src, dest, idx in pe_batches: global_batches.append( - ScheduledBatch( - src_pe=src, - dest_pe=dest, - batch_index=idx, - iteration=-1, - ) + ScheduledBatch(src_pe=src, dest_pe=dest, batch_index=idx, iteration=-1) ) PELogger.debug(f" Global batches collected: {len(global_batches)} total") @@ -139,10 +120,7 @@ def _assign_iterations(self, batches: List[ScheduledBatch]): iteration += 1 def _has_conflict( - self, - batch: ScheduledBatch, - iteration: int, - all_batches: List[ScheduledBatch], + self, batch: ScheduledBatch, iteration: int, all_batches: List[ScheduledBatch] ) -> bool: for other in all_batches: if other.iteration == iteration and other is not batch: @@ -151,10 +129,7 @@ def _has_conflict( return False def _exchange_workload_summaries( - self, - workloads: Dict[int, List[WorkloadGroup]], - my_pe: int, - n_pes: int, + self, workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int ) -> Dict[Tuple[int, int, int], WorkloadSummary]: """ Exchange detailed workload content using torch.distributed. @@ -181,9 +156,7 @@ def _exchange_workload_summaries( batch_count += 1 total_tasks += len(group.tasks) - PELogger.debug( - f" Local summaries: {batch_count} batches, {total_tasks} tasks" - ) + PELogger.debug(f" Local summaries: {batch_count} batches, {total_tasks} tasks") # Gather all summaries from all PEs using torch.distributed all_summaries_list: List[Dict[Tuple[int, int, int], Dict[str, object]] | None] = [ @@ -206,5 +179,3 @@ def _exchange_workload_summaries( PELogger.debug(f" Exchanged {len(global_map)} workload summaries") return global_map - - diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py index a568906f4c3..5eee6ed8bd9 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -5,7 +5,7 @@ (pointer arrays, sizes, chunking) for kernel execution. """ -from typing import List, Dict, Tuple, Any, Optional +from typing import Any, Dict, List, Optional, Tuple import cupy as cp import torch @@ -13,9 +13,9 @@ from ..logger import PELogger from ..memory.tensor_pointer_utils import TensorPointerExtractor from ..nvshmem_types import ( - SendRequest, ReceiveRequest, ScheduledBatch, + SendRequest, WorkloadGroup, WorkloadSummary, ) @@ -64,20 +64,14 @@ def create_gpu_plans( # Plan kernel args for packing send_batch.gpu_plan = self._plan_kernel_args( - ptrs, - positions, - sizes, - is_pack=True, - buffer_base=send_slots[i % 2].data_ptr(), + ptrs, positions, sizes, is_pack=True, buffer_base=send_slots[i % 2].data_ptr() ) task_ids = [t.task_id for t in send_batch.tasks] PELogger.debug( f" Iter {i} send plan: {len(send_batch.tasks)} tasks → " f"PE {send_batch.dest_pe}, {send_batch.total_size} bytes" ) - displayed_ids = ( - task_ids[:10] if len(task_ids) <= 10 else task_ids[:10] + ["..."] - ) + displayed_ids = task_ids[:10] if len(task_ids) <= 10 else task_ids[:10] + ["..."] PELogger.debug(f" Send task IDs: {displayed_ids}") recv_batch = sched["recv"] @@ -105,9 +99,7 @@ def create_gpu_plans( # Create fast lookup map for receive requests relevant_reqs: Dict[int, ReceiveRequest] = { - r.task_id: r - for r in receive_requests - if r.src_pe == recv_batch.src_pe + r.task_id: r for r in receive_requests if r.src_pe == recv_batch.src_pe } # Match summary tasks with receive requests @@ -136,11 +128,7 @@ def create_gpu_plans( # Plan kernel args for unpacking recv_batch.gpu_plan = self._plan_kernel_args( - ptrs, - positions, - sizes, - is_pack=False, - buffer_base=recv_slots[i % 2].data_ptr(), + ptrs, positions, sizes, is_pack=False, buffer_base=recv_slots[i % 2].data_ptr() ) if recv_batch.gpu_plan is None: @@ -225,5 +213,3 @@ def _plan_kernel_args( cp_sizes = cp.asarray(d_sizes) return (cp_src_addrs, cp_dst_addrs, cp_sizes, total_chunks) - - diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py index 0e98b8a7811..e9fa1724004 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py @@ -1,6 +1,6 @@ from typing import List -from ..nvshmem_types import SendRequest, ReceiveRequest, MAX_SEGMENT_SIZE +from ..nvshmem_types import MAX_SEGMENT_SIZE, ReceiveRequest, SendRequest # Constants for ID encoding (from C++ implementation) REQUEST_ID_BASE = 1000000000 @@ -93,5 +93,3 @@ def segment_receive_request(self, req: ReceiveRequest) -> List[ReceiveRequest]: output_requests.append(new_req) return output_requests - - diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py index b4cdffb7767..d0e8595a4c6 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py @@ -1,7 +1,7 @@ -from typing import List, Dict +from typing import Dict, List from ..logger import PELogger -from ..nvshmem_types import SendRequest, WorkloadGroup, MAX_SEGMENT_SIZE, MAX_TASKS_PER_BATCH +from ..nvshmem_types import MAX_SEGMENT_SIZE, MAX_TASKS_PER_BATCH, SendRequest, WorkloadGroup class WorkloadPacker: @@ -11,9 +11,7 @@ class WorkloadPacker: """ def pack_workloads( - self, - send_requests: List[SendRequest], - n_pes: int, + self, send_requests: List[SendRequest], n_pes: int ) -> Dict[int, List[WorkloadGroup]]: """ Groups requests by destination PE and packs them into batches. @@ -51,9 +49,7 @@ def pack_workloads( return workloads def _pack_single_destination( - self, - tasks: List[SendRequest], - dest_pe: int, + self, tasks: List[SendRequest], dest_pe: int ) -> List[WorkloadGroup]: if not tasks: return [] @@ -72,19 +68,13 @@ def _pack_single_destination( if (would_exceed_size or would_exceed_task_cap) and current_batch.tasks: # Finalize current batch batches.append(current_batch) - task_first_10_string = ", ".join( - [str(t.task_id) for t in current_batch.tasks[:10]] - ) + task_first_10_string = ", ".join([str(t.task_id) for t in current_batch.tasks[:10]]) PELogger.debug( f" Packed batch to PE {dest_pe} idx {len(batches) - 1}: " f"{task_first_10_string}... (total {len(current_batch.tasks)} tasks)" ) # Start new batch - current_batch = WorkloadGroup( - dest_pe=dest_pe, - tasks=[], - total_size=0, - ) + current_batch = WorkloadGroup(dest_pe=dest_pe, tasks=[], total_size=0) # Add task to current batch current_batch.tasks.append(task) @@ -95,5 +85,3 @@ def _pack_single_destination( batches.append(current_batch) return batches - - diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index fff5cdd092e..06496a7c2bb 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -6,26 +6,16 @@ GPU resource management, and pipelined execution. """ -from typing import List, Dict, Tuple, Optional +from typing import Dict, List, Optional, Tuple import nvshmem.core import torch.cuda.nvtx as nvtx from .core import GPUResourceManager, KernelLauncher, PipelineExecutor -from .memory import DoubleBufferManager -from .nvshmem_types import ( - SendRequest, - ReceiveRequest, - ScheduledBatch, - WorkloadSummary, -) -from .planning import ( - TaskSegmenter, - WorkloadPacker, - CommunicationScheduler, - GPUExecutionPlanner, -) from .logger import PELogger +from .memory import DoubleBufferManager +from .nvshmem_types import ReceiveRequest, ScheduledBatch, SendRequest, WorkloadSummary +from .planning import CommunicationScheduler, GPUExecutionPlanner, TaskSegmenter, WorkloadPacker class RemoteCopyService: @@ -106,16 +96,13 @@ def init(self, log_level: str = "INFO") -> None: # Cache CuPy stream wrappers for efficient kernel launching self.kernel_launcher.set_streams( - self.gpu_resources.pack_stream, - self.gpu_resources.unpack_stream, + self.gpu_resources.pack_stream, self.gpu_resources.unpack_stream ) PELogger.debug("Cached CuPy stream wrappers") # Create pipeline executor with dependencies self.pipeline_executor = PipelineExecutor( - self.kernel_launcher, - self.buffer_manager, - self.my_pe, + self.kernel_launcher, self.buffer_manager, self.my_pe ) # Set streams on pipeline executor @@ -131,12 +118,7 @@ def init(self, log_level: str = "INFO") -> None: PELogger.info("Initialization complete") def register_send( - self, - task_id: int, - src_tensor, - src_pos: int, - size: int, - dest_pe: int, + self, task_id: int, src_tensor, src_pos: int, size: int, dest_pe: int ) -> None: """ Register a send operation. @@ -156,12 +138,7 @@ def register_send( self.send_requests.append(req) def register_receive( - self, - task_id: int, - dest_tensor, - dest_pos: int, - size: int, - src_pe: int, + self, task_id: int, dest_tensor, dest_pos: int, size: int, src_pe: int ) -> None: """ Register a receive operation. @@ -214,36 +191,24 @@ def schedule(self) -> None: # Step 2: Pack tasks into workload groups PELogger.debug("Step 2: Packing workloads...") - workloads = self.workload_packer.pack_workloads( - self.send_requests, - self.n_pes, - ) + workloads = self.workload_packer.pack_workloads(self.send_requests, self.n_pes) total_batches = sum(len(batches) for batches in workloads.values()) active_pes = sum(1 for batches in workloads.values() if batches) - PELogger.info( - f"Packed: {total_batches} batches across {active_pes} destination PEs" - ) + PELogger.info(f"Packed: {total_batches} batches across {active_pes} destination PEs") # Step 3: Schedule workloads to iterations PELogger.debug("Step 3: Building communication schedule...") schedule, global_summaries = self.comm_scheduler.build_schedule( - workloads, - self.my_pe, - self.n_pes, + workloads, self.my_pe, self.n_pes ) self.num_iterations = self.comm_scheduler.num_iterations - PELogger.info( - f"Scheduled: {total_batches} batches → {self.num_iterations} iterations" - ) + PELogger.info(f"Scheduled: {total_batches} batches → {self.num_iterations} iterations") # Step 4: Prepare iteration schedules PELogger.debug("Step 4: Preparing iteration schedules...") self.iter_schedules = self._prepare_iter_schedules( - schedule, - workloads, - global_summaries, - self.num_iterations, + schedule, workloads, global_summaries, self.num_iterations ) # Step 5: Build GPU execution plans @@ -257,14 +222,10 @@ def schedule(self) -> None: # Step 6: Create double-buffered events PELogger.debug("Step 6: Creating synchronization events...") - self.pack_events, self.unpack_events = self.gpu_resources.create_events( - num_events=2 - ) + self.pack_events, self.unpack_events = self.gpu_resources.create_events(num_events=2) self.pipeline_executor.set_events(self.pack_events, self.unpack_events) - PELogger.info( - f"Schedule complete: {self.num_iterations} iterations ready" - ) + PELogger.info(f"Schedule complete: {self.num_iterations} iterations ready") def run(self) -> None: """ @@ -273,9 +234,9 @@ def run(self) -> None: Can be called multiple times after a single schedule() call to repeat the same communication pattern. """ - #import torch - #torch.save(self.send_requests, f"send_requests_{torch.distributed.get_rank()}.pt") - #torch.save(self.receive_requests, f"receive_requests_{torch.distributed.get_rank()}.pt") + # import torch + # torch.save(self.send_requests, f"send_requests_{torch.distributed.get_rank()}.pt") + # torch.save(self.receive_requests, f"receive_requests_{torch.distributed.get_rank()}.pt") if not self.initialized: raise RuntimeError("RemoteCopyService not initialized") @@ -294,10 +255,7 @@ def run(self) -> None: # Execute pipelined communication nvtx.range_push("execute_pipeline") - self.pipeline_executor.execute_pipeline( - self.iter_schedules, - self.num_iterations, - ) + self.pipeline_executor.execute_pipeline(self.iter_schedules, self.num_iterations) nvtx.range_pop() # execute_pipeline # Global barrier after execution @@ -305,10 +263,7 @@ def run(self) -> None: nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) # Process same-PE transfers - self.pipeline_executor.process_self_moves( - self.send_requests, - self.receive_requests, - ) + self.pipeline_executor.process_self_moves(self.send_requests, self.receive_requests) # End timing range nvtx.range_pop() # RemoteCopyService.run_total @@ -396,8 +351,7 @@ def _prepare_iter_schedules( # Skip same-PE transfers (handled separately by process_self_moves) if b.src_pe == b.dest_pe: PELogger.debug( - f" Iter {i}: Skipping same-PE batch " - f"({b.src_pe} → {b.dest_pe})" + f" Iter {i}: Skipping same-PE batch " f"({b.src_pe} → {b.dest_pe})" ) continue @@ -439,5 +393,3 @@ def _prepare_iter_schedules( iter_schedules.append(sched) return iter_schedules - - diff --git a/megatron/core/resharding/nvshmem_copy_service/validation.py b/megatron/core/resharding/nvshmem_copy_service/validation.py index f2197b7067f..b69589e2423 100644 --- a/megatron/core/resharding/nvshmem_copy_service/validation.py +++ b/megatron/core/resharding/nvshmem_copy_service/validation.py @@ -45,11 +45,7 @@ def all_passed(self) -> bool: return self.failed_tasks == 0 -def generate_deterministic_data( - task_id: int, - size: int, - device: str = "cuda", -) -> torch.Tensor: +def generate_deterministic_data(task_id: int, size: int, device: str = "cuda") -> torch.Tensor: """ Generate deterministic data pattern for a task. @@ -70,10 +66,7 @@ def generate_deterministic_data( def validate_received_data( - task_id: int, - tensor: torch.Tensor, - size: int, - src_pe: int = -1, + task_id: int, tensor: torch.Tensor, size: int, src_pe: int = -1 ) -> ValidationResult: """ Validate received data against expected deterministic pattern. @@ -90,11 +83,7 @@ def validate_received_data( recv_data = tensor[:size] # Generate expected pattern on same device - expected = generate_deterministic_data( - task_id, - size, - device=recv_data.device.type, - ) + expected = generate_deterministic_data(task_id, size, device=recv_data.device.type) # Compare mismatches_mask = recv_data != expected @@ -151,5 +140,3 @@ def log_validation_summary(summary: ValidationSummary) -> None: len(failed_tasks), task_ids[:15] if len(task_ids) <= 15 else task_ids[:15] + ["..."], ) - - diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index f446df1481d..2d8c1cddcee 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -275,9 +275,7 @@ def build_centralized_reshard_plan( "not found in source model." ) # Choose a representative source metadata with DP round-robin balancing - src_metadata = select_src_metadata_balanced( - src_meta_list, dst_metadata, dst_rank - ) + src_metadata = select_src_metadata_balanced(src_meta_list, dst_metadata, dst_rank) sources = _determine_source_ranks_for_dst_param( resolved_name, src_metadata, dst_metadata, dst_rank ) diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index b188063fa65..c2dc8937a07 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -82,9 +82,7 @@ class ReshardPlan: recv_ops: list[TransferOp] def __str__(self): - return ( - f"ReshardPlan(sends={len(self.send_ops)}, recvs={len(self.recv_ops)})" - ) + return f"ReshardPlan(sends={len(self.send_ops)}, recvs={len(self.recv_ops)})" # ----------------------------------------------------------------------------- @@ -162,7 +160,7 @@ def extract_param_metadata( is_tp = bool(getattr(param, 'tensor_model_parallel', False)) partition_dim = int(getattr(param, 'partition_dim', 0)) partition_stride = int(getattr(param, 'partition_stride', 1)) - + # SwiGLU/GLU compatibility: For gated linear units, fc1 stores interleaved [gate, up] portions # and requires partition_stride=2 for correct resharding. New models set this at construction # time (MLP sets partition_stride=2 on weight when gated_linear_unit=True). For legacy models @@ -171,7 +169,7 @@ def extract_param_metadata( # and stride doesn't affect single-block transfers. # if 'mlp.linear_fc1' in param_name and is_tp and partition_stride == 1: # partition_stride = 2 - + # EP detection: Megatron convention - expert params are not allreduced is_ep = not bool(getattr(param, 'allreduce', True)) diff --git a/tests/unit_tests/resharding/test_model_swap.py b/tests/unit_tests/resharding/test_model_swap.py index 30d05e87eed..f0cab745320 100644 --- a/tests/unit_tests/resharding/test_model_swap.py +++ b/tests/unit_tests/resharding/test_model_swap.py @@ -2,7 +2,7 @@ import copy import os import types -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import pytest import torch @@ -20,19 +20,18 @@ from megatron.core.resharding.refit import swap_model_weights from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord +from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils - try: import nvshmem.core + has_nvshmem = True except Exception: has_nvshmem = False - def _build_pg_collection( tp_size: int, pp_size: int = None, ep_size: int = 1 ) -> ProcessGroupCollection: @@ -123,6 +122,7 @@ def _set_pg_collection(module, tp_group, dp_group): module.pg_collection = types.SimpleNamespace(tp=tp_group, dp=dp_group, ep=None, pp=None) return module + @pytest.mark.parametrize( "refit_backend", [ @@ -153,7 +153,7 @@ def _set_pg_collection(module, tp_group, dp_group): (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 - (1, 1, 2, 1, 1, 1, 4), # EP2 -> EP1 + (1, 1, 2, 1, 1, 1, 4), # EP2 -> EP1 (1, 1, 1, 1, 1, 2, 4), (1, 1, 2, 1, 2, 2, 4), ], @@ -187,7 +187,7 @@ def test_swap_gpt_parametrized( # Small GPT config seq_len = 8 vocab_size = 128 - # --group-query-attention --num-query-groups 8 + # --group-query-attention --num-query-groups 8 cfg = TransformerConfig( num_layers=4 if (src_pp > 1 or dst_pp > 1) else 2, hidden_size=32, From 34990da9de91be0956c967505aeeecf7ceba3d44 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 16:23:43 -0800 Subject: [PATCH 35/44] lint --- .../resharding/copy_services/gloo_copy_service.py | 3 +++ .../resharding/copy_services/nccl_copy_service.py | 3 +++ .../resharding/copy_services/nvshmem_copy_service.py | 2 +- megatron/core/resharding/execution.py | 1 - .../core/gpu_resource_manager.py | 9 ++++++--- .../resharding/nvshmem_copy_service/nvshmem_types.py | 12 ++++++++++++ .../planning/communication_scheduler.py | 2 -- .../planning/gpu_execution_planner.py | 10 ++-------- .../nvshmem_copy_service/planning/task_segmenter.py | 7 +++++-- .../core/resharding/nvshmem_copy_service/service.py | 4 ++-- .../resharding/nvshmem_copy_service/validation.py | 1 + 11 files changed, 35 insertions(+), 19 deletions(-) diff --git a/megatron/core/resharding/copy_services/gloo_copy_service.py b/megatron/core/resharding/copy_services/gloo_copy_service.py index ee27531caf6..95f9d454682 100644 --- a/megatron/core/resharding/copy_services/gloo_copy_service.py +++ b/megatron/core/resharding/copy_services/gloo_copy_service.py @@ -50,9 +50,11 @@ def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): self.send_ops.append(SendOp(task_id=None, tensor=src_tensor, dest_rank=dest_rank)) def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): + """Submit a send operation with a unique task identifier.""" self.send_ops.append(SendOp(task_id=task_id, tensor=src_tensor, dest_rank=dest_rank)) def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + """Submit a receive operation.""" # Allocate a CPU buffer that matches the destination view; we'll # copy into dest_tensor after the Gloo recv completes. cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() @@ -61,6 +63,7 @@ def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): ) def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): + """Submit a receive operation with a unique task identifier.""" cpu_buffer = torch.empty_like(dest_tensor, device="cpu").contiguous() self.recv_ops.append( (RecvOp(task_id=task_id, tensor=cpu_buffer, src_rank=src_rank), dest_tensor) diff --git a/megatron/core/resharding/copy_services/nccl_copy_service.py b/megatron/core/resharding/copy_services/nccl_copy_service.py index 678a03cbf1b..43556f02986 100644 --- a/megatron/core/resharding/copy_services/nccl_copy_service.py +++ b/megatron/core/resharding/copy_services/nccl_copy_service.py @@ -51,12 +51,15 @@ def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): self.send_ops.append(SendOp(task_id=None, tensor=src_tensor, dest_rank=dest_rank)) def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): + """Submit a send operation with a unique task identifier.""" self.send_ops.append(SendOp(task_id=task_id, tensor=src_tensor, dest_rank=dest_rank)) def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + """Submit a receive operation.""" self.recv_ops.append(RecvOp(task_id=None, tensor=dest_tensor, src_rank=src_rank)) def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): + """Submit a receive operation with a unique task identifier.""" self.recv_ops.append(RecvOp(task_id=task_id, tensor=dest_tensor, src_rank=src_rank)) def run(self): diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py index cc9a65e15b9..f38a98aa6fe 100644 --- a/megatron/core/resharding/copy_services/nvshmem_copy_service.py +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -12,7 +12,7 @@ """ import logging -from typing import Dict, List, Tuple +from typing import Dict import torch import torch.distributed as dist diff --git a/megatron/core/resharding/execution.py b/megatron/core/resharding/execution.py index f911bbbff8a..7f706fd94a0 100644 --- a/megatron/core/resharding/execution.py +++ b/megatron/core/resharding/execution.py @@ -8,7 +8,6 @@ import torch.distributed as dist from .copy_services.base import CopyService -from .copy_services.nvshmem_copy_service import NVSHMEMCopyService from .utils import ReshardPlan logger = logging.getLogger(__name__) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index 34fb7ff1178..259cb60db0b 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -5,12 +5,15 @@ and event lifecycle. """ +import logging from typing import Dict, Optional import nvshmem.core import torch import torch.distributed as dist -from cuda.core.experimental import Device, system +from cuda.core.experimental import Device + +logger = logging.getLogger(__name__) class GPUResourceManager: @@ -86,7 +89,7 @@ def init(self) -> None: initializer_method="uid", ) - print("NVSHMEM initialized") + logger.info("NVSHMEM initialized") self.my_pe = nvshmem.core.my_pe() self.n_pes = nvshmem.core.n_pes() @@ -116,7 +119,7 @@ def init(self) -> None: "copy": self.torch_copy_stream, } - print("Stream mapping built") + logger.info("Stream mapping built") self.initialized = True diff --git a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py index c471e8d63c4..090f8a774cc 100644 --- a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py +++ b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py @@ -8,6 +8,8 @@ @dataclass class SendRequest: + """Container for a send operation request.""" + task_id: int src_tensor: Any # cupy.ndarray or pointer src_pos: int @@ -17,6 +19,8 @@ class SendRequest: @dataclass class ReceiveRequest: + """Container for a receive operation request.""" + task_id: int dest_tensor: Any # cupy.ndarray or pointer dest_pos: int @@ -26,6 +30,8 @@ class ReceiveRequest: @dataclass class WorkloadGroup: + """Container for a group of send requests to a specific destination PE.""" + dest_pe: int tasks: List[SendRequest] = field(default_factory=list) total_size: int = 0 @@ -33,6 +39,8 @@ class WorkloadGroup: @dataclass class ScheduledBatch: + """Metadata for a scheduled communication batch.""" + src_pe: int dest_pe: int batch_index: int @@ -46,6 +54,8 @@ class ScheduledBatch: @dataclass class WorkloadSummary: + """Summary of a workload group for communication with other PEs.""" + total_size: int task_ids: List[int] task_sizes: List[int] @@ -53,6 +63,8 @@ class WorkloadSummary: @dataclass class TransferMetadata: + """GPU-resident metadata for communication tasks.""" + ptrs: Any # cupy array of uint64 (pointers) sizes: Any # cupy array of uint64 (sizes) num_tasks: int diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py index 05450f5767e..18c9ea81b53 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -1,7 +1,5 @@ from typing import Dict, List, Tuple -import torch - from ..logger import PELogger from ..nvshmem_types import ScheduledBatch, WorkloadGroup, WorkloadSummary diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py index 5eee6ed8bd9..6b912080cc0 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -5,20 +5,14 @@ (pointer arrays, sizes, chunking) for kernel execution. """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import cupy as cp import torch from ..logger import PELogger from ..memory.tensor_pointer_utils import TensorPointerExtractor -from ..nvshmem_types import ( - ReceiveRequest, - ScheduledBatch, - SendRequest, - WorkloadGroup, - WorkloadSummary, -) +from ..nvshmem_types import ReceiveRequest, ScheduledBatch class GPUExecutionPlanner: diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py index e9fa1724004..61a551540b7 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py @@ -1,7 +1,10 @@ +import logging from typing import List from ..nvshmem_types import MAX_SEGMENT_SIZE, ReceiveRequest, SendRequest +logger = logging.getLogger(__name__) + # Constants for ID encoding (from C++ implementation) REQUEST_ID_BASE = 1000000000 SEGMENT_ID_MULTIPLIER = 1000 @@ -24,13 +27,13 @@ def _calculate_num_segments(self, size: int) -> int: def _validate_segmentation(self, task_id: int, size: int) -> bool: num_segments = self._calculate_num_segments(size) if num_segments > MAX_SEGMENTS_PER_REQUEST: - print( + logger.error( f"Error: Task {task_id} requires {num_segments} segments, " f"exceeds max {MAX_SEGMENTS_PER_REQUEST}" ) return False if task_id >= MAX_REQUESTS: - print(f"Error: Task ID {task_id} exceeds max {MAX_REQUESTS}") + logger.error(f"Error: Task ID {task_id} exceeds max {MAX_REQUESTS}") return False return True diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index 06496a7c2bb..d28dfd83494 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -131,7 +131,7 @@ def register_send( dest_pe: Destination PE rank """ if dest_pe >= self.n_pes or dest_pe < 0: - print(f"Error: Invalid destination PE {dest_pe}") + PELogger.error(f"Error: Invalid destination PE {dest_pe}") return req = SendRequest(task_id, src_tensor, src_pos, size, dest_pe) @@ -151,7 +151,7 @@ def register_receive( src_pe: Source PE rank """ if src_pe >= self.n_pes or src_pe < 0: - print(f"Error: Invalid source PE {src_pe}") + PELogger.error(f"Error: Invalid source PE {src_pe}") return req = ReceiveRequest(task_id, dest_tensor, dest_pos, size, src_pe) diff --git a/megatron/core/resharding/nvshmem_copy_service/validation.py b/megatron/core/resharding/nvshmem_copy_service/validation.py index b69589e2423..be8682ac12e 100644 --- a/megatron/core/resharding/nvshmem_copy_service/validation.py +++ b/megatron/core/resharding/nvshmem_copy_service/validation.py @@ -42,6 +42,7 @@ class ValidationSummary: @property def all_passed(self) -> bool: + """Check if all validated tasks passed.""" return self.failed_tasks == 0 From ae71cecca2620ee0b95d36bfa337ac200ab962a1 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 16:56:46 -0800 Subject: [PATCH 36/44] fix copywrite --- .../copy_services/nvshmem_copy_service.py | 38 ++++++++++++++++--- .../nvshmem_copy_service/__init__.py | 2 + .../nvshmem_copy_service/core/__init__.py | 2 + .../core/gpu_resource_manager.py | 2 + .../core/kernel_launcher.py | 2 + .../core/pipeline_executor.py | 2 + .../resharding/nvshmem_copy_service/logger.py | 2 + .../nvshmem_copy_service/memory/__init__.py | 2 + .../memory/double_buffer_manager.py | 2 + .../memory/tensor_pointer_utils.py | 2 + .../nvshmem_copy_service/nvshmem_types.py | 2 + .../nvshmem_copy_service/planning/__init__.py | 2 + .../planning/communication_scheduler.py | 2 + .../planning/gpu_execution_planner.py | 2 + .../planning/task_segmenter.py | 2 + .../planning/workload_packer.py | 2 + .../nvshmem_copy_service/service.py | 2 + .../nvshmem_copy_service/validation.py | 2 + megatron/core/resharding/refit.py | 2 - 19 files changed, 66 insertions(+), 8 deletions(-) diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py index f38a98aa6fe..7ff1923fb08 100644 --- a/megatron/core/resharding/copy_services/nvshmem_copy_service.py +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + from __future__ import annotations """ @@ -28,7 +30,9 @@ class NVSHMEMCopyService(CopyService): def __init__(self): if not dist.is_initialized(): - raise RuntimeError("torch.distributed must be initialized before NVSHMEMCopyService()") + raise RuntimeError( + "torch.distributed must be initialized before NVSHMEMCopyService()" + ) self.rank = dist.get_rank() self._remote = RemoteCopyService() @@ -48,7 +52,9 @@ def _ensure_initialized(self): self._remote.init(log_level="INFO") self._initialized = True logger.info( - "NVSHMEMCopyService initialized: PE %d / %d", self._remote.my_pe, self._remote.n_pes + "NVSHMEMCopyService initialized: PE %d / %d", + self._remote.my_pe, + self._remote.n_pes, ) def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): @@ -77,7 +83,12 @@ def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): # a small adapter layer that batches up matched send/recv slices. # - def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): + def submit_send_with_id( + self, + task_id: int, + src_tensor: torch.Tensor, + dest_rank: int, + ): """Register a send with an explicit, globally shared task_id.""" self._ensure_initialized() @@ -102,10 +113,19 @@ def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: # Use public API on RemoteCopyService self._remote.register_send( - task_id=task_id, src_tensor=src_bytes, src_pos=0, size=num_bytes, dest_pe=dest_rank + task_id=task_id, + src_tensor=src_bytes, + src_pos=0, + size=num_bytes, + dest_pe=dest_rank, ) - def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): + def submit_recv_with_id( + self, + task_id: int, + dest_tensor: torch.Tensor, + src_rank: int, + ): """Register a recv with an explicit, globally shared task_id.""" self._ensure_initialized() @@ -129,7 +149,11 @@ def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: ) self._remote.register_receive( - task_id=task_id, dest_tensor=dst_bytes, dest_pos=0, size=num_bytes, src_pe=src_rank + task_id=task_id, + dest_tensor=dst_bytes, + dest_pos=0, + size=num_bytes, + src_pe=src_rank, ) def run(self): @@ -180,3 +204,5 @@ def run(self): self._remote.run() self._remote.clear_requests() logger.info("NVSHMEMCopyService: NVSHMEM transfers complete") + + diff --git a/megatron/core/resharding/nvshmem_copy_service/__init__.py b/megatron/core/resharding/nvshmem_copy_service/__init__.py index 462797f0998..2ab8cde81fe 100644 --- a/megatron/core/resharding/nvshmem_copy_service/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ NVSHMEM-based remote copy service and supporting components. diff --git a/megatron/core/resharding/nvshmem_copy_service/core/__init__.py b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py index 7eb931f45a9..f466e925899 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """Core execution components for NVSHMEM operations.""" from .gpu_resource_manager import GPUResourceManager diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index 259cb60db0b..c178e180c4a 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ GPU resource management for NVSHMEM operations. diff --git a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py index c5db32b5010..81c0dcaebc4 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ CUDA kernel management and launching for pack/unpack operations. diff --git a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py index 4728e598b47..8f3315d83f2 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ Pipelined communication execution engine. diff --git a/megatron/core/resharding/nvshmem_copy_service/logger.py b/megatron/core/resharding/nvshmem_copy_service/logger.py index d81139e039e..62b7c85a172 100644 --- a/megatron/core/resharding/nvshmem_copy_service/logger.py +++ b/megatron/core/resharding/nvshmem_copy_service/logger.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ Per-PE Logger with colored console and file output. diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py index b33fcca198c..5cd8aac704b 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """Memory management utilities for NVSHMEM operations.""" from .double_buffer_manager import DoubleBufferManager diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py index 254b3f49c65..02d88e2b39e 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ Double buffer management for NVSHMEM symmetric memory. diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py index 70cc54edd3f..ee250618ee7 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ Utilities for extracting data pointers from different tensor types. diff --git a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py index 090f8a774cc..731cace0502 100644 --- a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py +++ b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + from dataclasses import dataclass, field from typing import Any, List diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py index 0f858b61edf..9df0b3ac318 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """Planning components for task segmentation, workload packing, and scheduling.""" from .communication_scheduler import CommunicationScheduler diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py index 18c9ea81b53..0f299a84e40 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + from typing import Dict, List, Tuple from ..logger import PELogger diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py index 6b912080cc0..7bf7b8fd0a7 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ GPU execution planning for pack/unpack operations. diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py index 61a551540b7..fdeaea33ae5 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + import logging from typing import List diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py index d0e8595a4c6..1f2374bc187 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + from typing import Dict, List from ..logger import PELogger diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index d28dfd83494..36868785515 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ Remote Copy Service - Main orchestrator for NVSHMEM-based GPU-to-GPU transfers. diff --git a/megatron/core/resharding/nvshmem_copy_service/validation.py b/megatron/core/resharding/nvshmem_copy_service/validation.py index be8682ac12e..fafb1321024 100644 --- a/megatron/core/resharding/nvshmem_copy_service/validation.py +++ b/megatron/core/resharding/nvshmem_copy_service/validation.py @@ -1,3 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + """ Validation utilities for GPU-to-GPU communication. diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 06a69a1fc86..30df5cb3a58 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -74,13 +74,11 @@ def reshard_model_weights( if not hasattr(tgt_core, "pg_collection") or tgt_core.pg_collection is None: raise RuntimeError("Target model missing pg_collection required for NCCL reshard") - # TODO(Peter): We should figure out why this happens. Seems like a bug in Orthotope. # Fill missing DP group on the source using Megatron's parallel state if not provided if getattr(src_core.pg_collection, "dp", None) is None: src_core.pg_collection.dp = parallel_state.get_data_parallel_group() # caching plan for reuse - # TODO(Peter): Find a better place to cache this. cached_plan: Optional[Any] = getattr(tgt_core, "_cached_reshard_plan", None) if cached_plan is None: plan = build_centralized_reshard_plan(src_core, tgt_core, num_experts=num_experts) From 3bf9528c2315f3fc206ce95abd1073a805833334 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 16:59:42 -0800 Subject: [PATCH 37/44] lint --- .../copy_services/nvshmem_copy_service.py | 36 ++++--------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py index 7ff1923fb08..d99add8a6f9 100644 --- a/megatron/core/resharding/copy_services/nvshmem_copy_service.py +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -30,9 +30,7 @@ class NVSHMEMCopyService(CopyService): def __init__(self): if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed must be initialized before NVSHMEMCopyService()" - ) + raise RuntimeError("torch.distributed must be initialized before NVSHMEMCopyService()") self.rank = dist.get_rank() self._remote = RemoteCopyService() @@ -52,9 +50,7 @@ def _ensure_initialized(self): self._remote.init(log_level="INFO") self._initialized = True logger.info( - "NVSHMEMCopyService initialized: PE %d / %d", - self._remote.my_pe, - self._remote.n_pes, + "NVSHMEMCopyService initialized: PE %d / %d", self._remote.my_pe, self._remote.n_pes ) def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): @@ -83,12 +79,7 @@ def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): # a small adapter layer that batches up matched send/recv slices. # - def submit_send_with_id( - self, - task_id: int, - src_tensor: torch.Tensor, - dest_rank: int, - ): + def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): """Register a send with an explicit, globally shared task_id.""" self._ensure_initialized() @@ -113,19 +104,10 @@ def submit_send_with_id( # Use public API on RemoteCopyService self._remote.register_send( - task_id=task_id, - src_tensor=src_bytes, - src_pos=0, - size=num_bytes, - dest_pe=dest_rank, + task_id=task_id, src_tensor=src_bytes, src_pos=0, size=num_bytes, dest_pe=dest_rank ) - def submit_recv_with_id( - self, - task_id: int, - dest_tensor: torch.Tensor, - src_rank: int, - ): + def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): """Register a recv with an explicit, globally shared task_id.""" self._ensure_initialized() @@ -149,11 +131,7 @@ def submit_recv_with_id( ) self._remote.register_receive( - task_id=task_id, - dest_tensor=dst_bytes, - dest_pos=0, - size=num_bytes, - src_pe=src_rank, + task_id=task_id, dest_tensor=dst_bytes, dest_pos=0, size=num_bytes, src_pe=src_rank ) def run(self): @@ -204,5 +182,3 @@ def run(self): self._remote.run() self._remote.clear_requests() logger.info("NVSHMEMCopyService: NVSHMEM transfers complete") - - From 3817e28dcd2b92a17559c21cad668db8a2310989 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 17:09:37 -0800 Subject: [PATCH 38/44] fix comment --- .../resharding/copy_services/nvshmem_copy_service.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py index d99add8a6f9..8d231de5339 100644 --- a/megatron/core/resharding/copy_services/nvshmem_copy_service.py +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -2,17 +2,6 @@ from __future__ import annotations -""" -NVSHMEM-based implementation of the CopyService interface. - -This wraps the higher-level RemoteCopyService so it can be used anywhere a -CopyService is expected (e.g., refit/reshard execution). - -NOTE: This is a first, minimal wiring. It currently mirrors the point-to-point -semantics of execute_reshard_plan by treating each send/recv pair as an -independent NVSHMEM "task" defined over contiguous slices. -""" - import logging from typing import Dict From 5c1c58fe0fe4271de3e68c477f9f7e1644066810 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 17:23:30 -0800 Subject: [PATCH 39/44] fix import guards --- .../core/gpu_resource_manager.py | 13 ++++++++++++- .../nvshmem_copy_service/core/pipeline_executor.py | 8 +++++++- .../memory/double_buffer_manager.py | 14 +++++++++++++- .../resharding/nvshmem_copy_service/service.py | 13 ++++++++++++- 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index c178e180c4a..6fa0fc1a12f 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -10,7 +10,13 @@ import logging from typing import Dict, Optional -import nvshmem.core +try: + import nvshmem.core + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + import torch import torch.distributed as dist from cuda.core.experimental import Device @@ -51,6 +57,11 @@ def init(self) -> None: if self.initialized: return + if not HAVE_NVSHMEM: + raise RuntimeError( + "nvshmem.core is not available. Please install nvshmem to use GPUResourceManager." + ) + # torch.distributed must be initialized before calling this if not dist.is_initialized(): raise RuntimeError( diff --git a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py index 8f3315d83f2..5ba07f9956a 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py @@ -9,7 +9,13 @@ from typing import Dict, List, Optional -import nvshmem.core +try: + import nvshmem.core + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + import torch from ..logger import PELogger diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py index 02d88e2b39e..079b2c17610 100644 --- a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py @@ -6,7 +6,13 @@ Manages send and receive buffers with double-buffering for pipelined communication. """ -import nvshmem.core.interop.torch +try: + import nvshmem.core.interop.torch + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + import torch from ..nvshmem_types import MAX_SEGMENT_SIZE @@ -28,6 +34,12 @@ def __init__(self, slot_size: int = MAX_SEGMENT_SIZE): def allocate(self) -> None: """Allocate NVSHMEM symmetric buffers for double-buffering.""" + if not HAVE_NVSHMEM: + raise RuntimeError( + "nvshmem.core.interop.torch is not available. " + "Please install nvshmem to use DoubleBufferManager." + ) + for i in range(2): self.send_slots[i] = nvshmem.core.interop.torch.bytetensor( (self.slot_size,), dtype=torch.uint8 diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py index 36868785515..631e63ae41b 100644 --- a/megatron/core/resharding/nvshmem_copy_service/service.py +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -10,7 +10,13 @@ from typing import Dict, List, Optional, Tuple -import nvshmem.core +try: + import nvshmem.core + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + import torch.cuda.nvtx as nvtx from .core import GPUResourceManager, KernelLauncher, PipelineExecutor @@ -81,6 +87,11 @@ def init(self, log_level: str = "INFO") -> None: Args: log_level: Logging level (TRACE, DEBUG, INFO, WARN, ERROR) """ + if not HAVE_NVSHMEM: + raise RuntimeError( + "nvshmem.core is not available. Please install nvshmem to use NVSHMEMCopyService." + ) + # Initialize GPU resources (NVSHMEM, device, streams) self.gpu_resources.init() From 88959e8c5a3360bd4557a2f64d36db561f9d9c27 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 18:01:36 -0800 Subject: [PATCH 40/44] fix import errors --- .../nvshmem_copy_service/core/gpu_resource_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py index 6fa0fc1a12f..6e03b914b26 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -12,6 +12,7 @@ try: import nvshmem.core + from cuda.core.experimental import Device HAVE_NVSHMEM = True except ImportError: @@ -19,7 +20,6 @@ import torch import torch.distributed as dist -from cuda.core.experimental import Device logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class GPUResourceManager: """Manages GPU resources including NVSHMEM, streams, and events.""" def __init__(self): - self.device: Optional[Device] = None + self.device = None self.my_pe: int = -1 self.n_pes: int = -1 self.initialized: bool = False From 4306ccd18672434851ce83d1d31be21193d0ad58 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 18:18:04 -0800 Subject: [PATCH 41/44] fix import errors --- .../core/kernel_launcher.py | 19 ++++++++++++++----- .../resharding/nvshmem_copy_service/logger.py | 2 -- .../planning/gpu_execution_planner.py | 15 +++++++++++++-- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py index 81c0dcaebc4..4e86d6a9505 100644 --- a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py +++ b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py @@ -7,9 +7,15 @@ """ import os -from typing import Any, Optional, Tuple +from typing import Any, Tuple + +try: + import cupy as cp + + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False -import cupy as cp import torch import torch.cuda.nvtx as nvtx @@ -18,13 +24,16 @@ class KernelLauncher: """Manages CUDA kernel loading and launching for data pack/unpack operations.""" def __init__(self): - self.chunked_copy_kernel: Optional[cp.RawKernel] = None + self.chunked_copy_kernel = None # Cached CuPy stream wrappers for efficient kernel launching - self.cp_pack_stream: Optional[cp.cuda.ExternalStream] = None - self.cp_unpack_stream: Optional[cp.cuda.ExternalStream] = None + self.cp_pack_stream = None + self.cp_unpack_stream = None def load_kernels(self) -> None: """Load and compile CUDA kernels from source.""" + if not HAVE_CUPY: + raise RuntimeError("cupy is not available. Please install cupy to use KernelLauncher.") + current_dir = os.path.dirname(os.path.abspath(__file__)) kernel_path = os.path.join(current_dir, "..", "kernels", "chunked_kernel.cu") diff --git a/megatron/core/resharding/nvshmem_copy_service/logger.py b/megatron/core/resharding/nvshmem_copy_service/logger.py index 62b7c85a172..a3c7c1699ad 100644 --- a/megatron/core/resharding/nvshmem_copy_service/logger.py +++ b/megatron/core/resharding/nvshmem_copy_service/logger.py @@ -16,8 +16,6 @@ """ -# TODO(Peter): We need to remove this logger and use the regular Megatron logger. - import logging import os from datetime import datetime diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py index 7bf7b8fd0a7..68c4d11d7e5 100644 --- a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -9,7 +9,13 @@ from typing import Dict, List, Optional, Tuple -import cupy as cp +try: + import cupy as cp + + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False + import torch from ..logger import PELogger @@ -42,6 +48,11 @@ def create_gpu_plans( recv_slots: List of receive buffer slots receive_requests: List of all receive requests for matching """ + if not HAVE_CUPY: + raise RuntimeError( + "cupy is not available. Please install cupy to use GPUExecutionPlanner." + ) + PELogger.debug(f"Creating GPU plans for {len(iter_schedules)} iterations") for i, sched in enumerate(iter_schedules): send_batch = sched["send"] @@ -151,7 +162,7 @@ def _plan_kernel_args( sizes: List[int], is_pack: bool, buffer_base: int, - ) -> Optional[Tuple[cp.ndarray, cp.ndarray, cp.ndarray, int]]: + ) -> Optional[Tuple[object, object, object, int]]: """ Generate GPU-ready pointer arrays for kernel execution. From 53091cb11629f1a169754437382cebbe0d8b7090 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 4 Jan 2026 18:50:18 -0800 Subject: [PATCH 42/44] edit --- tests/unit_tests/resharding/test_model_swap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/resharding/test_model_swap.py b/tests/unit_tests/resharding/test_model_swap.py index f0cab745320..38c9ebafeec 100644 --- a/tests/unit_tests/resharding/test_model_swap.py +++ b/tests/unit_tests/resharding/test_model_swap.py @@ -222,7 +222,7 @@ def test_swap_gpt_parametrized( dst_cfg.add_bias_linear = False # Require Transformer Engine for TEGroupedMLP; skip if unavailable try: - import transformer_engine # noqa: F401 + import transformer_engine except Exception: Utils.destroy_model_parallel() pytest.skip("Transformer Engine not available; skipping TE-grouped MoE test") From e0b3fc4da3c01a6b2ae86d691fe1bb0e6d19a1c7 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 8 Jan 2026 10:51:01 -0800 Subject: [PATCH 43/44] fix formatting --- megatron/training/training.py | 4 +--- .../model_config.yaml | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 6284ff69fc1..95ca9abf7b0 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -58,9 +58,6 @@ get_pg_rank, StragglerDetector, ) -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.process_groups_config import ProcessGroupCollection - from megatron.core.fp8_utils import correct_amax_history_if_needed from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.pipeline_parallel.utils import ( @@ -81,6 +78,7 @@ from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP from megatron.core.optimizer.optimizer import param_group_identifier_keys + from megatron.core.optimizer.qk_clip import clip_qk try: diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml index 3b57d8355bd..051bcf7ddd4 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml @@ -77,4 +77,4 @@ MODEL_ARGS: --finetune: true --inference-logging-step-interval: 1 --rl-inference-tensor-model-parallel-size: 2 - --refit-method: gloo + --refit-method: gloo \ No newline at end of file From 1329aa2d4e3261ff3d840a64583e31c801eef728 Mon Sep 17 00:00:00 2001 From: Peter Dykas Date: Wed, 28 Jan 2026 12:21:44 -0800 Subject: [PATCH 44/44] fix cache --- megatron/core/resharding/__init__.py | 9 ++++- megatron/core/resharding/refit.py | 49 +++++++++++++++++++++------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/megatron/core/resharding/__init__.py b/megatron/core/resharding/__init__.py index d06484eef37..083c4518c0e 100644 --- a/megatron/core/resharding/__init__.py +++ b/megatron/core/resharding/__init__.py @@ -1,7 +1,12 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from .execution import execute_reshard_plan from .planner import build_centralized_reshard_plan -from .refit import reshard_model_weights, swap_model_weights +from .refit import ( + clear_service_cache, + get_or_create_service, + reshard_model_weights, + swap_model_weights, +) from .utils import ParameterMetadata, ReshardPlan, ShardingDescriptor, TransferOp __all__ = [ @@ -9,6 +14,8 @@ "execute_reshard_plan", "swap_model_weights", "reshard_model_weights", + "get_or_create_service", + "clear_service_cache", "ParameterMetadata", "ShardingDescriptor", "TransferOp", diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 30df5cb3a58..5461b8d3900 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -22,6 +22,41 @@ # Supported refit backend names RefitBackendName = Literal["nccl", "gloo", "nvshmem"] +# Module-level cache for refit services to avoid repeated allocations +_service_cache: dict[str, CopyService] = {} + + +def get_or_create_service(backend: RefitBackendName) -> CopyService: + """Get or create a cached CopyService instance for the given backend. + + This avoids expensive repeated allocations (especially for NVSHMEM buffers) + when swap_model_weights is called multiple times with the same backend. + """ + if backend in _service_cache: + return _service_cache[backend] + + if backend == "nccl": + service = NCCLCopyService() + elif backend == "gloo": + service = GlooCopyService() + elif backend == "nvshmem": + service = NVSHMEMCopyService() + else: + raise ValueError(f"Unknown backend '{backend}'") + + _service_cache[backend] = service + return service + + +def clear_service_cache(): + """Clear the cached refit services. + + Call this if you need to invalidate the cache, for example when + reinitializing distributed state. + """ + global _service_cache + _service_cache.clear() + def swap_model_weights( src_model: LanguageModule, @@ -38,18 +73,8 @@ def swap_model_weights( service = refit_method reshard_model_weights(src_model, target_model, service=service) elif isinstance(refit_method, str): - if refit_method == "nccl": - service = NCCLCopyService() - reshard_model_weights(src_model, target_model, service=service) - elif refit_method == "gloo": - # Debug / fallback backend: run refit over CPU/Gloo instead of NCCL. - service = GlooCopyService() - reshard_model_weights(src_model, target_model, service=service) - elif refit_method == "nvshmem": - service = NVSHMEMCopyService() - reshard_model_weights(src_model, target_model, service=service) - else: - raise ValueError(f"Unknown refit_method '{refit_method}'") + service = get_or_create_service(refit_method) + reshard_model_weights(src_model, target_model, service=service) else: raise TypeError("refit_method must be a str backend name or a CopyService instance")