diff --git a/megatron/core/resharding/__init__.py b/megatron/core/resharding/__init__.py index d06484eef37..083c4518c0e 100644 --- a/megatron/core/resharding/__init__.py +++ b/megatron/core/resharding/__init__.py @@ -1,7 +1,12 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from .execution import execute_reshard_plan from .planner import build_centralized_reshard_plan -from .refit import reshard_model_weights, swap_model_weights +from .refit import ( + clear_service_cache, + get_or_create_service, + reshard_model_weights, + swap_model_weights, +) from .utils import ParameterMetadata, ReshardPlan, ShardingDescriptor, TransferOp __all__ = [ @@ -9,6 +14,8 @@ "execute_reshard_plan", "swap_model_weights", "reshard_model_weights", + "get_or_create_service", + "clear_service_cache", "ParameterMetadata", "ShardingDescriptor", "TransferOp", diff --git a/megatron/core/resharding/copy_services/__init__.py b/megatron/core/resharding/copy_services/__init__.py index 15986e4d28e..447588f7b3a 100644 --- a/megatron/core/resharding/copy_services/__init__.py +++ b/megatron/core/resharding/copy_services/__init__.py @@ -3,5 +3,6 @@ from .base import CopyService from .nccl_copy_service import NCCLCopyService +from .nvshmem_copy_service import NVSHMEMCopyService -__all__ = ["CopyService", "NCCLCopyService"] +__all__ = ["CopyService", "NCCLCopyService", "NVSHMEMCopyService"] diff --git a/megatron/core/resharding/copy_services/nvshmem_copy_service.py b/megatron/core/resharding/copy_services/nvshmem_copy_service.py new file mode 100644 index 00000000000..8d231de5339 --- /dev/null +++ b/megatron/core/resharding/copy_services/nvshmem_copy_service.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from __future__ import annotations + +import logging +from typing import Dict + +import torch +import torch.distributed as dist + +from ..nvshmem_copy_service import RemoteCopyService +from .base import CopyService + +logger = logging.getLogger(__name__) + + +class NVSHMEMCopyService(CopyService): + """CopyService implementation backed by NVSHMEM RemoteCopyService.""" + + def __init__(self): + if not dist.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before NVSHMEMCopyService()") + + self.rank = dist.get_rank() + self._remote = RemoteCopyService() + # Lazily initialized on first use to avoid side effects at import time + self._initialized = False + + # NOTE: keep the original typed tensors here (not uint8 views) so local copies + # preserve shape/strides semantics and avoid byte-offset pitfalls. + self._local_send_ops: Dict[int, torch.Tensor] = {} + self._local_recv_ops: Dict[int, torch.Tensor] = {} + self._local_copy_stream = torch.cuda.Stream() + + logger.info("NVSHMEMCopyService constructed") + + def _ensure_initialized(self): + if not self._initialized: + self._remote.init(log_level="INFO") + self._initialized = True + logger.info( + "NVSHMEMCopyService initialized: PE %d / %d", self._remote.my_pe, self._remote.n_pes + ) + + def submit_send(self, src_tensor: torch.Tensor, dest_rank: int): + """ + Basic CopyService API is not rich enough to drive the NVSHMEM planner + (it lacks a globally shared task identifier), so this method is kept + only for interface compatibility and should not be used directly. + + The resharding path calls into NVSHMEMCopyService via the + submit_send_with_id/submit_recv_with_id helpers instead. + """ + raise RuntimeError( + "NVSHMEMCopyService.submit_send() is not supported; " + "use submit_send_with_id(...) from execute_reshard_plan." + ) + + def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int): + raise RuntimeError( + "NVSHMEMCopyService.submit_recv() is not supported; " + "use submit_recv_with_id(...) from execute_reshard_plan." + ) + + # + # New helper API used from execute_reshard_plan via monkey-patching: + # we avoid changing the existing execute_reshard_plan signature by adding + # a small adapter layer that batches up matched send/recv slices. + # + + def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int): + """Register a send with an explicit, globally shared task_id.""" + self._ensure_initialized() + + if not src_tensor.is_contiguous(): + src_tensor = src_tensor.contiguous() + + # Local transfers: keep them out of RemoteCopyService entirely. + if dest_rank == self.rank: + self._local_send_ops[task_id] = src_tensor + return + + num_bytes = src_tensor.numel() * src_tensor.element_size() + src_bytes = src_tensor.view(torch.uint8) + + logger.debug( + "NVSHMEMCopyService: register_send task_id=%d, %d bytes (%d → %d)", + task_id, + num_bytes, + self.rank, + dest_rank, + ) + + # Use public API on RemoteCopyService + self._remote.register_send( + task_id=task_id, src_tensor=src_bytes, src_pos=0, size=num_bytes, dest_pe=dest_rank + ) + + def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int): + """Register a recv with an explicit, globally shared task_id.""" + self._ensure_initialized() + + if not dest_tensor.is_contiguous(): + dest_tensor = dest_tensor.contiguous() + + # Local transfers: keep them out of RemoteCopyService entirely. + if src_rank == self.rank: + self._local_recv_ops[task_id] = dest_tensor + return + + num_bytes = dest_tensor.numel() * dest_tensor.element_size() + dst_bytes = dest_tensor.view(torch.uint8) + + logger.debug( + "NVSHMEMCopyService: register_recv task_id=%d, %d bytes (%d ← %d)", + task_id, + num_bytes, + self.rank, + src_rank, + ) + + self._remote.register_receive( + task_id=task_id, dest_tensor=dst_bytes, dest_pos=0, size=num_bytes, src_pe=src_rank + ) + + def run(self): + """ + Execute all registered transfer pairs via NVSHMEM. + + This converts the registered pairs into RemoteCopyService send/receive + requests, builds a schedule, runs the pipelined NVSHMEM transfer, and + then clears internal state. + """ + self._ensure_initialized() + + # 1) Run same-rank copies (match by task_id), like NCCL backend. + if self._local_send_ops or self._local_recv_ops: + missing_sends = set(self._local_recv_ops.keys()) - set(self._local_send_ops.keys()) + missing_recvs = set(self._local_send_ops.keys()) - set(self._local_recv_ops.keys()) + if missing_sends or missing_recvs: + raise RuntimeError( + "NVSHMEMCopyService: unmatched local ops on rank " + f"{self.rank}: missing_sends={sorted(list(missing_sends))[:10]} " + f"missing_recvs={sorted(list(missing_recvs))[:10]}" + ) + + with torch.no_grad(): + with torch.cuda.stream(self._local_copy_stream): + for task_id, dst in self._local_recv_ops.items(): + src = self._local_send_ops[task_id] + if src.numel() != dst.numel() or src.element_size() != dst.element_size(): + raise RuntimeError( + "NVSHMEMCopyService: local copy size mismatch on rank " + f"{self.rank} task_id={task_id}: " + f"src=({tuple(src.shape)}, {src.dtype}) " + f"dst=({tuple(dst.shape)}, {dst.dtype})" + ) + dst.copy_(src, non_blocking=True) + + torch.cuda.current_stream().wait_stream(self._local_copy_stream) + self._local_send_ops.clear() + self._local_recv_ops.clear() + + # 2) Execute remote schedule (if any remote sends/recvs were registered). + if not self._remote.send_requests and not self._remote.receive_requests: + logger.info("NVSHMEMCopyService: no remote requests; local copies complete") + return + + logger.info("NVSHMEMCopyService: building NVSHMEM schedule and executing") + self._remote.schedule() + self._remote.run() + self._remote.clear_requests() + logger.info("NVSHMEMCopyService: NVSHMEM transfers complete") diff --git a/megatron/core/resharding/nvshmem_copy_service/__init__.py b/megatron/core/resharding/nvshmem_copy_service/__init__.py new file mode 100644 index 00000000000..2ab8cde81fe --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +NVSHMEM-based remote copy service and supporting components. + +This package is an in-tree integration of the standalone +`nvshmem_copy_service/python` implementation so that Megatron +can use it without relying on an external library. +""" + +from . import nvshmem_types +from .core import GPUResourceManager, KernelLauncher, PipelineExecutor +from .memory import DoubleBufferManager, TensorPointerExtractor +from .planning import CommunicationScheduler, GPUExecutionPlanner, TaskSegmenter, WorkloadPacker +from .service import RemoteCopyService + +__all__ = [ + "RemoteCopyService", + "nvshmem_types", + "GPUResourceManager", + "KernelLauncher", + "PipelineExecutor", + "DoubleBufferManager", + "TensorPointerExtractor", + "CommunicationScheduler", + "GPUExecutionPlanner", + "TaskSegmenter", + "WorkloadPacker", +] diff --git a/megatron/core/resharding/nvshmem_copy_service/core/__init__.py b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py new file mode 100644 index 00000000000..f466e925899 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Core execution components for NVSHMEM operations.""" + +from .gpu_resource_manager import GPUResourceManager +from .kernel_launcher import KernelLauncher +from .pipeline_executor import PipelineExecutor + +__all__ = ["GPUResourceManager", "KernelLauncher", "PipelineExecutor"] diff --git a/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py new file mode 100644 index 00000000000..6e03b914b26 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/gpu_resource_manager.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +GPU resource management for NVSHMEM operations. + +Handles NVSHMEM initialization, CUDA device setup, stream management, +and event lifecycle. +""" + +import logging +from typing import Dict, Optional + +try: + import nvshmem.core + from cuda.core.experimental import Device + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + +import torch +import torch.distributed as dist + +logger = logging.getLogger(__name__) + + +class GPUResourceManager: + """Manages GPU resources including NVSHMEM, streams, and events.""" + + def __init__(self): + self.device = None + self.my_pe: int = -1 + self.n_pes: int = -1 + self.initialized: bool = False + + # CUDA streams (cuda.core.experimental) + self.pack_stream = None + self.unpack_stream = None + self.send_stream = None + self.copy_stream = None + + # PyTorch stream wrappers + self.torch_pack_stream = None + self.torch_unpack_stream = None + self.torch_send_stream = None + self.torch_copy_stream = None + + # Stream name to PyTorch stream mapping + self._torch_streams: Dict[str, torch.cuda.ExternalStream] = {} + + def init(self) -> None: + """ + Initialize NVSHMEM, CUDA device, and streams. + + Expects torch.distributed to be already initialized. + """ + if self.initialized: + return + + if not HAVE_NVSHMEM: + raise RuntimeError( + "nvshmem.core is not available. Please install nvshmem to use GPUResourceManager." + ) + + # torch.distributed must be initialized before calling this + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before " "GPUResourceManager.init()" + ) + + # Get current CUDA device (already set by caller based on LOCAL_RANK) + local_rank = torch.cuda.current_device() + + # nvshmem4py requires a cuda.core Device at init time + self.device = Device(local_rank) + self.device.set_current() + + # Extract rank, nranks from the default process group + num_ranks = dist.get_world_size() + rank_id = dist.get_rank() + + # Create/Broadcast UniqueID using broadcast_object_list + uniqueid = nvshmem.core.get_unique_id(empty=True) + if rank_id == 0: + uniqueid = nvshmem.core.get_unique_id() + broadcast_objects = [uniqueid] + else: + broadcast_objects = [None] + + # Broadcast ID to all ranks using the default group + dist.broadcast_object_list(broadcast_objects, src=0) + + # Barrier to ensure everyone has the ID before NVSHMEM init + dist.barrier() + + # Initialize NVSHMEM with the broadcasted UID + nvshmem.core.init( + device=self.device, + uid=broadcast_objects[0], + rank=rank_id, + nranks=num_ranks, + initializer_method="uid", + ) + + logger.info("NVSHMEM initialized") + + self.my_pe = nvshmem.core.my_pe() + self.n_pes = nvshmem.core.n_pes() + + # Create CUDA streams + self.pack_stream = self.device.create_stream() + self.unpack_stream = self.device.create_stream() + self.send_stream = self.device.create_stream() + self.copy_stream = self.device.create_stream() + + # Get stream pointers and create PyTorch wrappers + _, pack_stream_ptr = self.pack_stream.__cuda_stream__() + _, unpack_stream_ptr = self.unpack_stream.__cuda_stream__() + _, send_stream_ptr = self.send_stream.__cuda_stream__() + _, copy_stream_ptr = self.copy_stream.__cuda_stream__() + + self.torch_pack_stream = torch.cuda.ExternalStream(pack_stream_ptr) + self.torch_unpack_stream = torch.cuda.ExternalStream(unpack_stream_ptr) + self.torch_send_stream = torch.cuda.ExternalStream(send_stream_ptr) + self.torch_copy_stream = torch.cuda.ExternalStream(copy_stream_ptr) + + # Build stream mapping + self._torch_streams = { + "pack": self.torch_pack_stream, + "unpack": self.torch_unpack_stream, + "send": self.torch_send_stream, + "copy": self.torch_copy_stream, + } + + logger.info("Stream mapping built") + + self.initialized = True + + # Initial barrier to ensure all PEs are ready + nvshmem.core.barrier_all(stream=self.send_stream) + + def get_stream(self, name: str): + """ + Get CUDA stream by name. + + Args: + name: Stream name ('pack', 'unpack', 'send', 'copy') + + Returns: + CUDA stream object + """ + streams = { + "pack": self.pack_stream, + "unpack": self.unpack_stream, + "send": self.send_stream, + "copy": self.copy_stream, + } + return streams.get(name) + + def get_torch_stream(self, name: str) -> Optional[torch.cuda.ExternalStream]: + """ + Get PyTorch ExternalStream by name. + + Args: + name: Stream name ('pack', 'unpack', 'send', 'copy') + + Returns: + PyTorch ExternalStream + """ + return self._torch_streams.get(name) + + def create_events(self, num_events: int = 2): + """ + Create double-buffered CUDA events for pack and unpack operations. + + Args: + num_events: Number of events to create for each type + (default: 2 for double buffering) + + Returns: + tuple: (pack_events, unpack_events) lists of torch.cuda.Event + """ + pack_events = [torch.cuda.Event(enable_timing=False) for _ in range(num_events)] + unpack_events = [torch.cuda.Event(enable_timing=False) for _ in range(num_events)] + return pack_events, unpack_events + + def finalize(self) -> None: + """Cleanup resources (streams are automatically managed by CUDA).""" + self.initialized = False + self.my_pe = -1 + self.n_pes = -1 + # Streams are automatically cleaned up when objects are deleted diff --git a/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py new file mode 100644 index 00000000000..4e86d6a9505 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/kernel_launcher.py @@ -0,0 +1,147 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +CUDA kernel management and launching for pack/unpack operations. + +Handles kernel compilation, launching, and stream coordination. +""" + +import os +from typing import Any, Tuple + +try: + import cupy as cp + + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False + +import torch +import torch.cuda.nvtx as nvtx + + +class KernelLauncher: + """Manages CUDA kernel loading and launching for data pack/unpack operations.""" + + def __init__(self): + self.chunked_copy_kernel = None + # Cached CuPy stream wrappers for efficient kernel launching + self.cp_pack_stream = None + self.cp_unpack_stream = None + + def load_kernels(self) -> None: + """Load and compile CUDA kernels from source.""" + if not HAVE_CUPY: + raise RuntimeError("cupy is not available. Please install cupy to use KernelLauncher.") + + current_dir = os.path.dirname(os.path.abspath(__file__)) + kernel_path = os.path.join(current_dir, "..", "kernels", "chunked_kernel.cu") + + with open(kernel_path, "r") as f: + kernel_source = f.read() + + self.chunked_copy_kernel = cp.RawKernel( + kernel_source, "chunked_batched_copy_kernel", options=("-std=c++11",) + ) + + def set_streams(self, pack_stream, unpack_stream) -> None: + """ + Cache CuPy stream wrappers for kernel launching. + + This eliminates per-launch overhead of stream pointer extraction + and CuPy ExternalStream creation. + + Args: + pack_stream: CUDA stream for pack operations + unpack_stream: CUDA stream for unpack operations + """ + _, pack_stream_ptr = pack_stream.__cuda_stream__() + _, unpack_stream_ptr = unpack_stream.__cuda_stream__() + self.cp_pack_stream = cp.cuda.ExternalStream(pack_stream_ptr) + self.cp_unpack_stream = cp.cuda.ExternalStream(unpack_stream_ptr) + + def launch_pack( + self, + gpu_plan: Tuple[Any, Any, Any, int], + pack_stream, + torch_pack_stream: torch.cuda.ExternalStream, + pack_event: torch.cuda.Event, + ) -> None: + """ + Launch pack kernel to copy data from user tensors to send buffer. + + Args: + gpu_plan: Tuple of (cp_src_addrs, cp_dst_addrs, cp_sizes, num_chunks) + as CuPy arrays + pack_stream: CUDA stream (cuda.core.experimental.Stream) - unused, + kept for compatibility + torch_pack_stream: PyTorch external stream wrapper + pack_event: CUDA event to record after kernel launch + """ + nvtx.range_push("Launch Pack Kernel") + if not gpu_plan: + nvtx.range_pop() + return + + # Unpack cached CuPy arrays from gpu_plan + cp_src, cp_dst, cp_sizes, num_chunks = gpu_plan + + # Grid/Block configuration + THREADS_PER_BLOCK = 1024 + NUM_BLOCKS = 75 + + # Launch kernel using cached CuPy stream + assert self.chunked_copy_kernel is not None + assert self.cp_pack_stream is not None + self.chunked_copy_kernel( + (NUM_BLOCKS,), + (THREADS_PER_BLOCK,), + (cp_src, cp_dst, cp_sizes, num_chunks), + stream=self.cp_pack_stream, + ) + nvtx.range_pop() + # Record event on PyTorch stream + pack_event.record(stream=torch_pack_stream) + + def launch_unpack( + self, + gpu_plan: Tuple[Any, Any, Any, int], + unpack_stream, + torch_unpack_stream: torch.cuda.ExternalStream, + unpack_event: torch.cuda.Event, + ) -> None: + """ + Launch unpack kernel to copy data from receive buffer to user tensors. + + Args: + gpu_plan: Tuple of (cp_src_addrs, cp_dst_addrs, cp_sizes, num_chunks) + as CuPy arrays + unpack_stream: CUDA stream (cuda.core.experimental.Stream) - unused, + kept for compatibility + torch_unpack_stream: PyTorch external stream wrapper + unpack_event: CUDA event to record after kernel launch + """ + nvtx.range_push("Launch Unpack Kernel") + if not gpu_plan: + nvtx.range_pop() + return + + # Unpack cached CuPy arrays from gpu_plan + cp_src, cp_dst, cp_sizes, num_chunks = gpu_plan + + # Grid/Block configuration + THREADS_PER_BLOCK = 1024 + NUM_BLOCKS = 75 + + # Launch kernel using cached CuPy stream + assert self.chunked_copy_kernel is not None + assert self.cp_unpack_stream is not None + self.chunked_copy_kernel( + (NUM_BLOCKS,), + (THREADS_PER_BLOCK,), + (cp_src, cp_dst, cp_sizes, num_chunks), + stream=self.cp_unpack_stream, + ) + nvtx.range_pop() + # Record event on PyTorch stream + unpack_event.record(stream=torch_unpack_stream) diff --git a/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py new file mode 100644 index 00000000000..5ba07f9956a --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/core/pipeline_executor.py @@ -0,0 +1,275 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Pipelined communication execution engine. + +Orchestrates the pack/send/unpack pipeline with double-buffering +and proper stream synchronization. +""" + +from typing import Dict, List, Optional + +try: + import nvshmem.core + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + +import torch + +from ..logger import PELogger +from ..memory.double_buffer_manager import DoubleBufferManager +from ..nvshmem_types import ReceiveRequest, ScheduledBatch, SendRequest +from .kernel_launcher import KernelLauncher + + +class PipelineExecutor: + """Executes pipelined NVSHMEM communication with pack/send/unpack overlap.""" + + def __init__( + self, kernel_launcher: KernelLauncher, buffer_manager: DoubleBufferManager, my_pe: int + ): + """ + Initialize pipeline executor. + + Args: + kernel_launcher: KernelLauncher instance for pack/unpack kernels + buffer_manager: DoubleBufferManager for send/recv buffers + my_pe: This PE's rank + """ + self.kernel_launcher = kernel_launcher + self.buffer_manager = buffer_manager + self.my_pe = my_pe + + # Streams (will be set by service) + self.pack_stream = None + self.unpack_stream = None + self.send_stream = None + self.copy_stream = None + + self.torch_pack_stream = None + self.torch_unpack_stream = None + self.torch_copy_stream = None + + # Events for double-buffered synchronization + self.pack_events = [] + self.unpack_events = [] + + def set_streams( + self, + pack_stream, + unpack_stream, + send_stream, + copy_stream, + torch_pack_stream, + torch_unpack_stream, + torch_copy_stream, + ): + """Set CUDA streams for execution.""" + self.pack_stream = pack_stream + self.unpack_stream = unpack_stream + self.send_stream = send_stream + self.copy_stream = copy_stream + + self.torch_pack_stream = torch_pack_stream + self.torch_unpack_stream = torch_unpack_stream + self.torch_copy_stream = torch_copy_stream + + def set_events(self, pack_events: List, unpack_events: List): + """Set double-buffered CUDA events.""" + self.pack_events = pack_events + self.unpack_events = unpack_events + + def execute_pipeline( + self, iter_schedules: List[Dict[str, Optional[ScheduledBatch]]], num_iterations: int + ) -> None: + """ + Execute pipelined communication. + + Pipeline stages: + 1. Pack NEXT iteration (async) + 2. Unpack PRIOR iteration (async) + 3. Send CURRENT iteration (sync) + 4. Barrier + 5. Wait for async pack/unpack to complete + + Args: + iter_schedules: List of iteration schedules + num_iterations: Total number of iterations + """ + PELogger.info(f"Executing pipeline: {num_iterations} iterations") + + # Priming: Pack iteration 0 and WAIT for completion + if num_iterations > 0 and iter_schedules[0]["send"]: + torch.cuda.nvtx.range_push("Priming") + PELogger.debug("Priming: Packing iteration 0") + self._launch_pack(0, iter_schedules[0]["send"]) + self.pack_events[0].synchronize() + torch.cuda.nvtx.range_pop() + + for i in range(num_iterations): + torch.cuda.nvtx.range_push(f"Iteration {i}") + has_send = iter_schedules[i]["send"] is not None + has_recv = iter_schedules[i]["recv"] is not None + has_next_send = i + 1 < num_iterations and iter_schedules[i + 1]["send"] is not None + has_prior_recv = i > 0 and iter_schedules[i - 1]["recv"] is not None + + slot = i % 2 + + # Log iteration start + send_info = ( + f" → PE {iter_schedules[i]['send'].dest_pe} " + f"({iter_schedules[i]['send'].total_size} bytes)" + if has_send + else "" + ) + recv_info = ( + f" ← PE {iter_schedules[i]['recv'].src_pe} " + f"({iter_schedules[i]['recv'].total_size} bytes)" + if has_recv + else "" + ) + PELogger.debug(f"Iteration {i}/{num_iterations}: slot={slot}{send_info}{recv_info}") + + # Step 1: Pack NEXT iteration (async) + if has_next_send: + torch.cuda.nvtx.range_push("Step 1: Pack Next") + next_batch = iter_schedules[i + 1]["send"] + assert next_batch is not None + PELogger.debug( + f" Pack next (iter {i+1}): {len(next_batch.tasks)} tasks " + f"→ PE {next_batch.dest_pe}" + ) + self._launch_pack(i + 1, next_batch) + torch.cuda.nvtx.range_pop() + + # Step 2: Unpack PRIOR iteration (async) + if has_prior_recv: + torch.cuda.nvtx.range_push("Step 2: Unpack Prior") + prior_batch = iter_schedules[i - 1]["recv"] + assert prior_batch is not None + PELogger.debug( + f" Unpack prior (iter {i-1}): {prior_batch.total_size} bytes " + f"← PE {prior_batch.src_pe}" + ) + self._launch_unpack(i - 1, prior_batch) + torch.cuda.nvtx.range_pop() + + # Step 3: Send CURRENT iteration + if has_send: + torch.cuda.nvtx.range_push("Step 3: Send Current") + batch = iter_schedules[i]["send"] + assert batch is not None + transfer_size = batch.total_size + PELogger.debug(f" Send current: {transfer_size} bytes → PE {batch.dest_pe}") + + nvshmem.core.put( + self.buffer_manager.recv_slots[slot][0:transfer_size], + self.buffer_manager.send_slots[slot][0:transfer_size], + batch.dest_pe, + stream=self.send_stream, + ) + torch.cuda.nvtx.range_pop() + + # Ensure send completes + self.send_stream.sync() + nvshmem.core.quiet(stream=self.send_stream) + + # Step 4: Global barrier + torch.cuda.nvtx.range_push("Step 4: Barrier") + nvshmem.core.barrier_all(stream=self.send_stream) + self.send_stream.sync() + torch.cuda.nvtx.range_pop() + + # Step 5: Wait for async pack/unpack to complete + torch.cuda.nvtx.range_push("Step 5: Wait Async") + if has_prior_recv: + self.unpack_events[(i - 1) % 2].synchronize() + if has_next_send: + self.pack_events[(i + 1) % 2].synchronize() + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_pop() + + # Final unpack for last iteration + if num_iterations > 0 and iter_schedules[num_iterations - 1]["recv"]: + torch.cuda.nvtx.range_push("Final Unpack") + PELogger.debug(f"Final unpack: iteration {num_iterations-1}") + last_recv = iter_schedules[num_iterations - 1]["recv"] + assert last_recv is not None + self._launch_unpack(num_iterations - 1, last_recv) + self.unpack_events[(num_iterations - 1) % 2].synchronize() + torch.cuda.nvtx.range_pop() + + PELogger.info(f"Pipeline complete: {num_iterations} iterations") + + def _launch_pack(self, iteration: int, batch: ScheduledBatch) -> None: + """Launch pack kernel for given iteration.""" + if not batch.gpu_plan: + return + + self.kernel_launcher.launch_pack( + batch.gpu_plan, + self.pack_stream, + self.torch_pack_stream, + self.pack_events[iteration % 2], + ) + + def _launch_unpack(self, iteration: int, batch: ScheduledBatch) -> None: + """Launch unpack kernel for given iteration.""" + if not batch.gpu_plan: + return + + self.kernel_launcher.launch_unpack( + batch.gpu_plan, + self.unpack_stream, + self.torch_unpack_stream, + self.unpack_events[iteration % 2], + ) + + def process_self_moves( + self, send_requests: List[SendRequest], receive_requests: List[ReceiveRequest] + ) -> None: + """ + Handle same-PE transfers (where src_pe == dest_pe == my_pe). + + Uses PyTorch copy on the copy stream for efficiency. + + Args: + send_requests: List of send requests + receive_requests: List of receive requests + """ + # Match send/recv requests where src_pe == dest_pe == my_pe + local_sends = {r.task_id: r for r in send_requests if r.dest_pe == self.my_pe} + local_recvs = [r for r in receive_requests if r.src_pe == self.my_pe] + + if local_recvs: + PELogger.debug(f"Processing {len(local_recvs)} self-moves") + + num_processed = 0 + with torch.cuda.stream(self.torch_copy_stream): + for recv_req in local_recvs: + if recv_req.task_id in local_sends: + send_req = local_sends[recv_req.task_id] + PELogger.debug( + " Self-move: task_id=%d, size=%d bytes", recv_req.task_id, send_req.size + ) + + # Create views of the tensors with offsets + src_view = send_req.src_tensor[ + send_req.src_pos : send_req.src_pos + send_req.size + ] + dest_view = recv_req.dest_tensor[ + recv_req.dest_pos : recv_req.dest_pos + send_req.size + ] + + # Async copy on the copy stream + dest_view.copy_(src_view, non_blocking=True) + num_processed += 1 + + # Synchronize the PyTorch stream + self.torch_copy_stream.synchronize() + + if num_processed > 0: + PELogger.info("Self-moves complete: %d transfers", num_processed) diff --git a/megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu b/megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu new file mode 100644 index 00000000000..e5b8fcc9a85 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/kernels/chunked_kernel.cu @@ -0,0 +1,103 @@ + +#include + +// CUDA-compatible types (no C++ standard library headers for NVRTC) +typedef unsigned char uint8_t; +typedef unsigned long long uint64_t; +typedef uint64_t uintptr_t; + +// ============================================================================ +// Kernel Configuration Constants (from ChunkedKernel.h) +// ============================================================================ + +constexpr int CHUNK_SIZE = 128 * 1024; // 128KB per chunk +constexpr int NUM_BLOCKS = 75; // Fixed grid size +constexpr int THREADS_PER_BLOCK = 1024; // Fixed block size +constexpr int FLOAT4_SIZE = 16; // 16 bytes per float4 +constexpr int MAX_CHUNKS_PER_BLOCK = 512; // Max chunks per block for shared memory + +extern "C" { + +/** + * Chunked batched copy kernel implementation + * + * This kernel performs efficient batched memory copies using: + * 1. Contiguous block assignment for better load balancing + * 2. Shared memory prefetching of chunk metadata + * 3. Vectorized float4 (16-byte) copies for aligned data + * 4. Byte-by-byte fallback for unaligned or small data + */ +__global__ void chunked_batched_copy_kernel( + uint8_t** src_addrs, + uint8_t** dst_addrs, + size_t* sizes, + int total_chunks +) { + // Shared memory for metadata prefetching + __shared__ uint8_t* s_src_addrs[MAX_CHUNKS_PER_BLOCK]; + __shared__ uint8_t* s_dst_addrs[MAX_CHUNKS_PER_BLOCK]; + __shared__ size_t s_sizes[MAX_CHUNKS_PER_BLOCK]; + + // Contiguous block assignment: block i processes chunks [start_chunk, end_chunk) + int chunks_per_block = (total_chunks + gridDim.x - 1) / gridDim.x; // Ceiling division + int start_chunk = blockIdx.x * chunks_per_block; + int end_chunk = start_chunk + chunks_per_block; + if (end_chunk > total_chunks) { + end_chunk = total_chunks; + } + int num_chunks_this_block = end_chunk - start_chunk; + + // Phase 1: Cooperative loading of metadata to shared memory + // All 1024 threads cooperate to load metadata from global memory + for (int i = threadIdx.x; i < num_chunks_this_block; i += blockDim.x) { + int global_chunk_id = start_chunk + i; + s_src_addrs[i] = src_addrs[global_chunk_id]; + s_dst_addrs[i] = dst_addrs[global_chunk_id]; + s_sizes[i] = sizes[global_chunk_id]; + } + __syncthreads(); + + // Phase 2: Process each chunk using metadata from shared memory + for (int chunk_id = 0; chunk_id < num_chunks_this_block; chunk_id++) { + uint8_t* src = s_src_addrs[chunk_id]; + uint8_t* dst = s_dst_addrs[chunk_id]; + size_t size = s_sizes[chunk_id]; + + // Check if both src and dst are aligned to 16 bytes for float4 access + uintptr_t src_addr = (uintptr_t)src; + uintptr_t dst_addr = (uintptr_t)dst; + bool is_aligned = ((src_addr % FLOAT4_SIZE) == 0) && ((dst_addr % FLOAT4_SIZE) == 0); + + if (is_aligned && size >= FLOAT4_SIZE) { + // Fast path: vectorized float4 copies + size_t aligned_size = (size / FLOAT4_SIZE) * FLOAT4_SIZE; + + // All 1024 threads cooperate on float4 copies + #pragma unroll 4 + for (size_t offset = threadIdx.x * FLOAT4_SIZE; + offset < aligned_size; + offset += blockDim.x * FLOAT4_SIZE) { + // Vectorized 16-byte load and store + float4 data = *((float4*)(src + offset)); + *((float4*)(dst + offset)) = data; + } + + // Handle remaining bytes (< 16 bytes) with byte-by-byte copy + for (size_t offset = aligned_size + threadIdx.x; + offset < size; + offset += blockDim.x) { + dst[offset] = src[offset]; + } + } else { + // Fallback path: byte-by-byte copy for unaligned addresses + // Still use all threads for parallelism + for (size_t offset = threadIdx.x; offset < size; offset += blockDim.x) { + dst[offset] = src[offset]; + } + } + } +} + +} + + diff --git a/megatron/core/resharding/nvshmem_copy_service/logger.py b/megatron/core/resharding/nvshmem_copy_service/logger.py new file mode 100644 index 00000000000..a3c7c1699ad --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/logger.py @@ -0,0 +1,209 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" + +Per-PE Logger with colored console and file output. + + + +Similar to the C++ Logger implementation, provides: + +- Per-PE colored console output + +- Per-PE file logging + +- Support for TRACE, DEBUG, INFO, SUMMARY, WARN, ERROR levels + +""" + +import logging +import os +from datetime import datetime +from typing import Optional + + +class ColoredFormatter(logging.Formatter): + """Custom formatter that adds color codes for console output.""" + + def __init__(self, fmt: str, pe_id: int, use_color: bool = True): + super().__init__(fmt) + self.pe_id = pe_id + self.use_color = use_color + + # ANSI color codes matching C++ implementation + self.colors = { + 0: "\033[31m", # Red + 1: "\033[32m", # Green + 2: "\033[33m", # Yellow + 3: "\033[34m", # Blue + 4: "\033[35m", # Magenta + 5: "\033[36m", # Cyan + 6: "\033[91m", # Bright Red + 7: "\033[92m", # Bright Green + } + self.reset = "\033[0m" + + def formatTime(self, record, datefmt=None): + ct = self.converter(record.created) + if datefmt: + s = datetime.fromtimestamp(record.created).strftime(datefmt) + # For file logs, replace %f with milliseconds + if "%f" in datefmt: + s = s.replace("%f", f"{int(record.msecs):03d}") + else: + s = datetime.fromtimestamp(record.created).strftime("%H:%M:%S") + s = f"{s}.{int(record.msecs):03d}" + return s + + def format(self, record): + # Save original message + original_msg = record.msg + + if self.use_color and self.pe_id >= 0: + color = self.colors.get(self.pe_id, "\033[37m") # White for others + record.msg = f"{color}{record.msg}{self.reset}" + + result = super().format(record) + + # Restore original message for other handlers + record.msg = original_msg + + return result + + +class PELogger: + """Per-PE logger with colored console and file output.""" + + _logger: Optional[logging.Logger] = None + _pe_id: int = -1 + _level: int = logging.INFO + + @classmethod + def init(cls, pe_id: int, level: str = "INFO", logs_dir: str = "logs"): + """ + Initialize logger for this PE. + + Args: + pe_id: Process element ID + level: Log level (TRACE, DEBUG, INFO, WARN, ERROR) + logs_dir: Directory for log files + """ + cls._pe_id = pe_id + + # Convert level string to logging level + level_map = { + "TRACE": logging.DEBUG - 5, # Custom level below DEBUG + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "SUMMARY": logging.INFO, + "WARN": logging.WARNING, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + cls._level = level_map.get(level.upper(), logging.INFO) + + # Create logs directory if it doesn't exist + os.makedirs(logs_dir, exist_ok=True) + + # Create logger + logger_name = f"PE_{pe_id}" + cls._logger = logging.getLogger(logger_name) + cls._logger.setLevel(cls._level) + cls._logger.propagate = False + + # Remove existing handlers to avoid duplicates + cls._logger.handlers.clear() + + # 1. Console handler with color + console_handler = logging.StreamHandler() + console_handler.setLevel(cls._level) + console_format = "[PE %d] [%%(asctime)s] [%%(levelname)s] %%(message)s" % pe_id + console_formatter = ColoredFormatter(console_format, pe_id, use_color=True) + console_handler.setFormatter(console_formatter) + cls._logger.addHandler(console_handler) + + # 2. File handler without color + log_filename = os.path.join(logs_dir, f"pe_{pe_id}.log") + file_handler = logging.FileHandler(log_filename, mode="w") + file_handler.setLevel(cls._level) + file_format = "[PE %d] [%%(asctime)s] [%%(levelname)s] %%(message)s" % pe_id + file_formatter = ColoredFormatter(file_format, pe_id, use_color=False) + file_handler.setFormatter(file_formatter) + cls._logger.addHandler(file_handler) + + @classmethod + def set_level(cls, level: str): + """Set the logging level.""" + level_map = { + "TRACE": logging.DEBUG - 5, + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "SUMMARY": logging.INFO, + "WARN": logging.WARNING, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + cls._level = level_map.get(level.upper(), logging.INFO) + if cls._logger: + cls._logger.setLevel(cls._level) + for handler in cls._logger.handlers: + handler.setLevel(cls._level) + + @classmethod + def trace(cls, msg: str): + """Log at TRACE level (most detailed).""" + if cls._logger: + cls._logger.log(logging.DEBUG - 5, msg) + + @classmethod + def debug(cls, msg: str): + """Log at DEBUG level.""" + if cls._logger: + cls._logger.debug(msg) + + @classmethod + def info(cls, msg: str): + """Log at INFO level.""" + if cls._logger: + cls._logger.info(msg) + + @classmethod + def summary(cls, msg: str): + """Log summary information (INFO level with [SUMMARY] prefix).""" + if cls._logger: + cls._logger.info(f"[SUMMARY] {msg}") + + @classmethod + def warn(cls, msg: str): + """Log at WARNING level.""" + if cls._logger: + cls._logger.warning(msg) + + @classmethod + def warning(cls, msg: str): + """Log at WARNING level (alias for warn).""" + cls.warn(msg) + + @classmethod + def error(cls, msg: str): + """Log at ERROR level.""" + if cls._logger: + cls._logger.error(msg) + + @classmethod + def critical(cls, msg: str): + """Log at CRITICAL level.""" + if cls._logger: + cls._logger.critical(msg) + + @classmethod + def shutdown(cls): + """Shutdown the logger and flush all handlers.""" + if cls._logger: + for handler in cls._logger.handlers: + handler.flush() + handler.close() + cls._logger.handlers.clear() + cls._logger = None diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py new file mode 100644 index 00000000000..5cd8aac704b --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/memory/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Memory management utilities for NVSHMEM operations.""" + +from .double_buffer_manager import DoubleBufferManager +from .tensor_pointer_utils import TensorPointerExtractor + +__all__ = ["DoubleBufferManager", "TensorPointerExtractor"] diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py new file mode 100644 index 00000000000..079b2c17610 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/memory/double_buffer_manager.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Double buffer management for NVSHMEM symmetric memory. + +Manages send and receive buffers with double-buffering for pipelined communication. +""" + +try: + import nvshmem.core.interop.torch + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + +import torch + +from ..nvshmem_types import MAX_SEGMENT_SIZE + + +class DoubleBufferManager: + """Manages double-buffered NVSHMEM symmetric buffers for send/receive operations.""" + + def __init__(self, slot_size: int = MAX_SEGMENT_SIZE): + """ + Initialize buffer manager. + + Args: + slot_size: Size of each buffer slot in bytes (default: 256MB) + """ + self.slot_size = slot_size + self.send_slots = [None, None] + self.recv_slots = [None, None] + + def allocate(self) -> None: + """Allocate NVSHMEM symmetric buffers for double-buffering.""" + if not HAVE_NVSHMEM: + raise RuntimeError( + "nvshmem.core.interop.torch is not available. " + "Please install nvshmem to use DoubleBufferManager." + ) + + for i in range(2): + self.send_slots[i] = nvshmem.core.interop.torch.bytetensor( + (self.slot_size,), dtype=torch.uint8 + ) + self.recv_slots[i] = nvshmem.core.interop.torch.bytetensor( + (self.slot_size,), dtype=torch.uint8 + ) + # Zero out buffers + self.send_slots[i].zero_() + self.recv_slots[i].zero_() + + def get_send_slot(self, iteration: int): + """ + Get send buffer for given iteration. + + Args: + iteration: Iteration number + + Returns: + NVSHMEM tensor for sending + """ + return self.send_slots[iteration % 2] + + def get_recv_slot(self, iteration: int): + """ + Get receive buffer for given iteration. + + Args: + iteration: Iteration number + + Returns: + NVSHMEM tensor for receiving + """ + return self.recv_slots[iteration % 2] + + def free(self) -> None: + """Free NVSHMEM symmetric buffers.""" + for i in range(2): + if self.send_slots[i] is not None: + nvshmem.core.interop.torch.free_tensor(self.send_slots[i]) + self.send_slots[i] = None + if self.recv_slots[i] is not None: + nvshmem.core.interop.torch.free_tensor(self.recv_slots[i]) + self.recv_slots[i] = None diff --git a/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py new file mode 100644 index 00000000000..ee250618ee7 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/memory/tensor_pointer_utils.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Utilities for extracting data pointers from different tensor types. + +Supports PyTorch tensors, CuPy arrays, and raw integer pointers. +""" + +from typing import Any + +import torch + + +class TensorPointerExtractor: + """Extract memory pointers from various tensor types.""" + + @staticmethod + def get_pointer(tensor: Any) -> int: + """ + Extract the data pointer from a tensor. + + Args: + tensor: Can be torch.Tensor, CuPy array, or raw int pointer + + Returns: + int: Memory address of the tensor data + + Examples: + + >>> import torch + + >>> t = torch.zeros(100, device='cuda') + + >>> ptr = TensorPointerExtractor.get_pointer(t) + + >>> isinstance(ptr, int) + + True + """ + if isinstance(tensor, torch.Tensor): + return tensor.data_ptr() + elif hasattr(tensor, "data"): # CuPy array + return tensor.data.ptr + else: # Assume raw integer pointer + return tensor diff --git a/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py new file mode 100644 index 00000000000..731cace0502 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/nvshmem_types.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass, field +from typing import Any, List + +# Constants +MAX_SEGMENT_SIZE = 256 * 1024 * 1024 # 256MB +MAX_TASKS_PER_BATCH = 10000 + + +@dataclass +class SendRequest: + """Container for a send operation request.""" + + task_id: int + src_tensor: Any # cupy.ndarray or pointer + src_pos: int + size: int + dest_pe: int + + +@dataclass +class ReceiveRequest: + """Container for a receive operation request.""" + + task_id: int + dest_tensor: Any # cupy.ndarray or pointer + dest_pos: int + size: int + src_pe: int + + +@dataclass +class WorkloadGroup: + """Container for a group of send requests to a specific destination PE.""" + + dest_pe: int + tasks: List[SendRequest] = field(default_factory=list) + total_size: int = 0 + + +@dataclass +class ScheduledBatch: + """Metadata for a scheduled communication batch.""" + + src_pe: int + dest_pe: int + batch_index: int + iteration: int + # Metadata for GPU execution + gpu_plan: Any = None # Placeholder for GPU-resident plan + tasks: List[SendRequest] = field(default_factory=list) + total_size: int = 0 + tasks_summary: Any = None # WorkloadSummary + + +@dataclass +class WorkloadSummary: + """Summary of a workload group for communication with other PEs.""" + + total_size: int + task_ids: List[int] + task_sizes: List[int] + + +@dataclass +class TransferMetadata: + """GPU-resident metadata for communication tasks.""" + + ptrs: Any # cupy array of uint64 (pointers) + sizes: Any # cupy array of uint64 (sizes) + num_tasks: int + total_size: int diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py new file mode 100644 index 00000000000..9df0b3ac318 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Planning components for task segmentation, workload packing, and scheduling.""" + +from .communication_scheduler import CommunicationScheduler +from .gpu_execution_planner import GPUExecutionPlanner +from .task_segmenter import TaskSegmenter +from .workload_packer import WorkloadPacker + +__all__ = ["CommunicationScheduler", "GPUExecutionPlanner", "TaskSegmenter", "WorkloadPacker"] diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py new file mode 100644 index 00000000000..0f299a84e40 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/communication_scheduler.py @@ -0,0 +1,181 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import Dict, List, Tuple + +from ..logger import PELogger +from ..nvshmem_types import ScheduledBatch, WorkloadGroup, WorkloadSummary + + +class CommunicationScheduler: + """ + Builds a conflict-free, iteration-based schedule for communication. + Ensures that in any given iteration, a PE is not overloaded. + """ + + def __init__(self): + self.num_iterations = 0 + + def build_schedule( + self, workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int + ) -> Tuple[Dict[int, List[ScheduledBatch]], Dict[Tuple[int, int, int], WorkloadSummary]]: + """ + Main scheduling method. + 1. Exchanges workload info with other PEs. + 2. Assigns batches to iterations. + 3. Returns: + - local schedule (iteration -> list of batches) + - global workload summaries (key: (src, dest, batch_idx) -> summary) + """ + total_local_batches = sum(len(groups) for groups in workloads.values()) + PELogger.info(f"Building schedule: {total_local_batches} local batches, {n_pes} PEs") + + # Step 1: Collect all batches across all PE pairs + PELogger.debug("Collecting batches from all PEs...") + all_batches = self._collect_all_batches(workloads, my_pe, n_pes) + PELogger.debug(f"Collected {len(all_batches)} total batches globally") + + # Step 2: Assign batches to iterations using conflict-free algorithm + PELogger.debug("Assigning batches to iterations...") + self._assign_iterations(all_batches) + PELogger.info(f"Schedule built: {self.num_iterations} iterations") + + # Step 3: Exchange detailed workload summaries (Task IDs/Sizes) + # This is needed for receivers to know what tasks are in each batch + PELogger.debug("Exchanging workload summaries...") + global_summaries = self._exchange_workload_summaries(workloads, my_pe, n_pes) + PELogger.debug(f"Exchanged {len(global_summaries)} workload summaries") + + # Step 4: Build schedule map for this PE + my_batches = [b for b in all_batches if b.src_pe == my_pe or b.dest_pe == my_pe] + my_batches.sort(key=lambda x: x.iteration) + + final_schedule: Dict[int, List[ScheduledBatch]] = {} + for b in my_batches: + final_schedule.setdefault(b.iteration, []).append(b) + + return final_schedule, global_summaries + + def _collect_all_batches( + self, workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int + ) -> List[ScheduledBatch]: + """ + Exchanges batch counts and details with all PEs to build a global view. + Uses torch.distributed for reliable communication. + """ + import torch.distributed as dist + + # Build local batch list + local_batches: List[Tuple[int, int, int]] = [] + for dest_pe, groups in workloads.items(): + if dest_pe == my_pe: + continue + for i, _ in enumerate(groups): + local_batches.append((my_pe, dest_pe, i)) # (src, dest, batch_idx) + + PELogger.debug(f" Local batch count: {len(local_batches)}") + PELogger.debug(f" Local batches: {local_batches}") + + # Gather all batches from all PEs using torch.distributed + all_batches_list: List[List[Tuple[int, int, int]] | None] = [None] * n_pes + dist.all_gather_object(all_batches_list, local_batches) + + # Flatten into global batch list + global_batches: List[ScheduledBatch] = [] + for pe_batches in all_batches_list: + if pe_batches is None: + continue + for src, dest, idx in pe_batches: + global_batches.append( + ScheduledBatch(src_pe=src, dest_pe=dest, batch_index=idx, iteration=-1) + ) + + PELogger.debug(f" Global batches collected: {len(global_batches)} total") + + # Group by source for readability + batches_by_src: Dict[int, List[Tuple[int, int]]] = {} + for b in global_batches: + batches_by_src.setdefault(b.src_pe, []).append((b.dest_pe, b.batch_index)) + for src_pe in sorted(batches_by_src.keys()): + PELogger.debug(f" PE {src_pe} sends to: {batches_by_src[src_pe]}") + + return global_batches + + def _assign_iterations(self, batches: List[ScheduledBatch]): + self.num_iterations = 0 + batches.sort(key=lambda x: (x.src_pe, x.dest_pe, x.batch_index)) + + for batch in batches: + iteration = 0 + assigned = False + while not assigned: + if not self._has_conflict(batch, iteration, batches): + batch.iteration = iteration + self.num_iterations = max(self.num_iterations, iteration + 1) + assigned = True + PELogger.debug( + f" Assigned batch ({batch.src_pe} → {batch.dest_pe}, " + f"idx={batch.batch_index}) to iteration {iteration}" + ) + else: + iteration += 1 + + def _has_conflict( + self, batch: ScheduledBatch, iteration: int, all_batches: List[ScheduledBatch] + ) -> bool: + for other in all_batches: + if other.iteration == iteration and other is not batch: + if other.src_pe == batch.src_pe or other.dest_pe == batch.dest_pe: + return True + return False + + def _exchange_workload_summaries( + self, workloads: Dict[int, List[WorkloadGroup]], my_pe: int, n_pes: int + ) -> Dict[Tuple[int, int, int], WorkloadSummary]: + """ + Exchange detailed workload content using torch.distributed. + Simple and reliable - no NVSHMEM symmetric memory issues. + """ + import torch.distributed as dist + + # Build local summaries as a simple dict: + # (src, dest, batch_idx) -> {total_size, task_ids, task_sizes} + local_summaries: Dict[Tuple[int, int, int], Dict[str, object]] = {} + batch_count = 0 + total_tasks = 0 + + for dest_pe, groups in workloads.items(): + if dest_pe == my_pe: + continue + for batch_idx, group in enumerate(groups): + key = (my_pe, dest_pe, batch_idx) + local_summaries[key] = { + "total_size": group.total_size, + "task_ids": [t.task_id for t in group.tasks], + "task_sizes": [t.size for t in group.tasks], + } + batch_count += 1 + total_tasks += len(group.tasks) + + PELogger.debug(f" Local summaries: {batch_count} batches, {total_tasks} tasks") + + # Gather all summaries from all PEs using torch.distributed + all_summaries_list: List[Dict[Tuple[int, int, int], Dict[str, object]] | None] = [ + None + ] * n_pes + dist.all_gather_object(all_summaries_list, local_summaries) + + # Merge into global map + global_map: Dict[Tuple[int, int, int], WorkloadSummary] = {} + for pe_summaries in all_summaries_list: + if pe_summaries is None: + continue + for key, data in pe_summaries.items(): + summary = WorkloadSummary( + total_size=int(data["total_size"]), + task_ids=list(data["task_ids"]), + task_sizes=list(data["task_sizes"]), + ) + global_map[key] = summary + + PELogger.debug(f" Exchanged {len(global_map)} workload summaries") + return global_map diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py new file mode 100644 index 00000000000..68c4d11d7e5 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/gpu_execution_planner.py @@ -0,0 +1,222 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +GPU execution planning for pack/unpack operations. + +Converts high-level task descriptions into GPU-ready metadata +(pointer arrays, sizes, chunking) for kernel execution. +""" + +from typing import Dict, List, Optional, Tuple + +try: + import cupy as cp + + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False + +import torch + +from ..logger import PELogger +from ..memory.tensor_pointer_utils import TensorPointerExtractor +from ..nvshmem_types import ReceiveRequest, ScheduledBatch + + +class GPUExecutionPlanner: + """Plans GPU kernel execution by building pointer arrays and metadata.""" + + def __init__(self): + self.tensor_utils = TensorPointerExtractor() + self.CHUNK_SIZE = 128 * 1024 # 128KB chunks + + def create_gpu_plans( + self, + iter_schedules: List[Dict[str, Optional[ScheduledBatch]]], + send_slots: List, + recv_slots: List, + receive_requests: List[ReceiveRequest], + ) -> None: + """ + Build GPU execution plans for all iterations. + + Modifies iter_schedules in-place by adding gpu_plan to each batch. + + Args: + iter_schedules: List of iteration schedules (dicts with 'send' and 'recv') + send_slots: List of send buffer slots + recv_slots: List of receive buffer slots + receive_requests: List of all receive requests for matching + """ + if not HAVE_CUPY: + raise RuntimeError( + "cupy is not available. Please install cupy to use GPUExecutionPlanner." + ) + + PELogger.debug(f"Creating GPU plans for {len(iter_schedules)} iterations") + for i, sched in enumerate(iter_schedules): + send_batch = sched["send"] + if send_batch: + # Build Pack Metadata + ptrs: List[int] = [] + positions: List[int] = [] + sizes: List[int] = [] + + for t in send_batch.tasks: + # Extract pointer from tensor + ptr = self.tensor_utils.get_pointer(t.src_tensor) + ptrs.append(ptr) + positions.append(t.src_pos) + sizes.append(t.size) + + # Plan kernel args for packing + send_batch.gpu_plan = self._plan_kernel_args( + ptrs, positions, sizes, is_pack=True, buffer_base=send_slots[i % 2].data_ptr() + ) + task_ids = [t.task_id for t in send_batch.tasks] + PELogger.debug( + f" Iter {i} send plan: {len(send_batch.tasks)} tasks → " + f"PE {send_batch.dest_pe}, {send_batch.total_size} bytes" + ) + displayed_ids = task_ids[:10] if len(task_ids) <= 10 else task_ids[:10] + ["..."] + PELogger.debug(f" Send task IDs: {displayed_ids}") + + recv_batch = sched["recv"] + if recv_batch: + # Build Unpack Metadata + summary = recv_batch.tasks_summary + + # Skip if no summary available (shouldn't happen in normal operation) + if summary is None: + PELogger.error( + f"Iter {i}: recv batch from PE {recv_batch.src_pe} has no " + "tasks_summary - UNPACK WILL BE SKIPPED!" + ) + recv_batch.gpu_plan = None + continue + + PELogger.debug( + f" Iter {i} recv from PE {recv_batch.src_pe}: " + f"{len(summary.task_ids)} tasks in summary" + ) + + ptrs = [] + positions = [] + sizes = [] + + # Create fast lookup map for receive requests + relevant_reqs: Dict[int, ReceiveRequest] = { + r.task_id: r for r in receive_requests if r.src_pe == recv_batch.src_pe + } + + # Match summary tasks with receive requests + matched_task_ids: List[int] = [] + unmatched_task_ids: List[int] = [] + for t_id, t_size in zip(summary.task_ids, summary.task_sizes): + if t_id in relevant_reqs: + req = relevant_reqs[t_id] + ptr = self.tensor_utils.get_pointer(req.dest_tensor) + ptrs.append(ptr) + positions.append(req.dest_pos) + sizes.append(t_size) # Use sender's size + matched_task_ids.append(t_id) + else: + unmatched_task_ids.append(t_id) + PELogger.error( + f"Iter {i}: Unexpected task {t_id} from PE " + f"{recv_batch.src_pe} - no matching recv request!" + ) + + if unmatched_task_ids: + PELogger.error( + f" Iter {i}: {len(unmatched_task_ids)} unmatched tasks " + f"from PE {recv_batch.src_pe}: {unmatched_task_ids[:10]}" + ) + + # Plan kernel args for unpacking + recv_batch.gpu_plan = self._plan_kernel_args( + ptrs, positions, sizes, is_pack=False, buffer_base=recv_slots[i % 2].data_ptr() + ) + + if recv_batch.gpu_plan is None: + PELogger.error( + f" Iter {i} recv plan: FAILED - no gpu_plan created for " + f"{len(sizes)} tasks from PE {recv_batch.src_pe}" + ) + else: + PELogger.debug( + f" Iter {i} recv plan: {len(sizes)} tasks ← " + f"PE {recv_batch.src_pe}, {recv_batch.total_size} bytes" + ) + displayed_recv_ids = ( + matched_task_ids[:10] + if len(matched_task_ids) <= 10 + else matched_task_ids[:10] + ["..."] + ) + PELogger.debug(f" Recv task IDs: {displayed_recv_ids}") + + def _plan_kernel_args( + self, + ptrs: List[int], + positions: List[int], + sizes: List[int], + is_pack: bool, + buffer_base: int, + ) -> Optional[Tuple[object, object, object, int]]: + """ + Generate GPU-ready pointer arrays for kernel execution. + + Applies 128KB chunking to break large transfers into smaller pieces. + + Args: + ptrs: List of tensor data pointers + positions: List of positions within tensors + sizes: List of transfer sizes + is_pack: True for pack (user->buffer), False for unpack (buffer->user) + buffer_base: Base pointer of the buffer + + Returns: + Tuple of (cp_src_addrs, cp_dst_addrs, cp_sizes, num_chunks) as + CuPy arrays, or None if no work. + """ + h_src_addrs: List[int] = [] + h_dst_addrs: List[int] = [] + h_sizes: List[int] = [] + + packed_offset = 0 + + for ptr, pos, size in zip(ptrs, positions, sizes): + num_chunks = (size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE + + for c in range(num_chunks): + chunk_offset = c * self.CHUNK_SIZE + chunk_size = min(self.CHUNK_SIZE, size - chunk_offset) + + if is_pack: + # Pack: user tensor -> buffer + h_src_addrs.append(ptr + pos + chunk_offset) + h_dst_addrs.append(buffer_base + packed_offset + chunk_offset) + else: + # Unpack: buffer -> user tensor + h_src_addrs.append(buffer_base + packed_offset + chunk_offset) + h_dst_addrs.append(ptr + pos + chunk_offset) + + h_sizes.append(chunk_size) + + packed_offset += size + + total_chunks = len(h_sizes) + if total_chunks == 0: + return None + + # Move to GPU using PyTorch, then convert to CuPy for kernel launching + d_src_addrs = torch.tensor(h_src_addrs, dtype=torch.int64, device="cuda") + d_dst_addrs = torch.tensor(h_dst_addrs, dtype=torch.int64, device="cuda") + d_sizes = torch.tensor(h_sizes, dtype=torch.int64, device="cuda") + + # Convert to CuPy arrays (zero-copy) for kernel launching + cp_src_addrs = cp.asarray(d_src_addrs) + cp_dst_addrs = cp.asarray(d_dst_addrs) + cp_sizes = cp.asarray(d_sizes) + + return (cp_src_addrs, cp_dst_addrs, cp_sizes, total_chunks) diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py new file mode 100644 index 00000000000..fdeaea33ae5 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/task_segmenter.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import logging +from typing import List + +from ..nvshmem_types import MAX_SEGMENT_SIZE, ReceiveRequest, SendRequest + +logger = logging.getLogger(__name__) + +# Constants for ID encoding (from C++ implementation) +REQUEST_ID_BASE = 1000000000 +SEGMENT_ID_MULTIPLIER = 1000 +MAX_REQUESTS = 1000000 +MAX_SEGMENTS_PER_REQUEST = 1000 + + +class TaskSegmenter: + """ + Splits large tasks (>256MB) into smaller segments to fit + into the fixed-size communication slots. + """ + + def _encode_segment_id(self, task_id: int, segment_index: int) -> int: + return REQUEST_ID_BASE + (task_id * SEGMENT_ID_MULTIPLIER) + segment_index + + def _calculate_num_segments(self, size: int) -> int: + return (size + MAX_SEGMENT_SIZE - 1) // MAX_SEGMENT_SIZE + + def _validate_segmentation(self, task_id: int, size: int) -> bool: + num_segments = self._calculate_num_segments(size) + if num_segments > MAX_SEGMENTS_PER_REQUEST: + logger.error( + f"Error: Task {task_id} requires {num_segments} segments, " + f"exceeds max {MAX_SEGMENTS_PER_REQUEST}" + ) + return False + if task_id >= MAX_REQUESTS: + logger.error(f"Error: Task ID {task_id} exceeds max {MAX_REQUESTS}") + return False + return True + + def segment_send_request(self, req: SendRequest) -> List[SendRequest]: + """ + Splits a single send request into multiple segments + if larger than MAX_SEGMENT_SIZE. + """ + if req.size <= MAX_SEGMENT_SIZE: + return [req] + + if not self._validate_segmentation(req.task_id, req.size): + raise ValueError(f"Task {req.task_id} validation failed") + + num_segments = self._calculate_num_segments(req.size) + output_requests: List[SendRequest] = [] + + for i in range(num_segments): + segment_offset = i * MAX_SEGMENT_SIZE + segment_size = min(MAX_SEGMENT_SIZE, req.size - segment_offset) + segment_task_id = self._encode_segment_id(req.task_id, i) + + new_req = SendRequest( + task_id=segment_task_id, + src_tensor=req.src_tensor, + src_pos=req.src_pos + segment_offset, + size=segment_size, + dest_pe=req.dest_pe, + ) + output_requests.append(new_req) + + return output_requests + + def segment_receive_request(self, req: ReceiveRequest) -> List[ReceiveRequest]: + """ + Splits a single receive request into multiple segments + if larger than MAX_SEGMENT_SIZE. + """ + if req.size <= MAX_SEGMENT_SIZE: + return [req] + + if not self._validate_segmentation(req.task_id, req.size): + raise ValueError(f"Task {req.task_id} validation failed") + + num_segments = self._calculate_num_segments(req.size) + output_requests: List[ReceiveRequest] = [] + + for i in range(num_segments): + segment_offset = i * MAX_SEGMENT_SIZE + segment_size = min(MAX_SEGMENT_SIZE, req.size - segment_offset) + segment_task_id = self._encode_segment_id(req.task_id, i) + + new_req = ReceiveRequest( + task_id=segment_task_id, + dest_tensor=req.dest_tensor, + dest_pos=req.dest_pos + segment_offset, + size=segment_size, + src_pe=req.src_pe, + ) + output_requests.append(new_req) + + return output_requests diff --git a/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py new file mode 100644 index 00000000000..1f2374bc187 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/planning/workload_packer.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import Dict, List + +from ..logger import PELogger +from ..nvshmem_types import MAX_SEGMENT_SIZE, MAX_TASKS_PER_BATCH, SendRequest, WorkloadGroup + + +class WorkloadPacker: + """ + Packs individual SendRequests into WorkloadGroups (batches) + destined for the same PE, respecting size limits. + """ + + def pack_workloads( + self, send_requests: List[SendRequest], n_pes: int + ) -> Dict[int, List[WorkloadGroup]]: + """ + Groups requests by destination PE and packs them into batches. + Returns a map: dest_pe -> list of batches + """ + PELogger.debug(f"Packing {len(send_requests)} send requests for {n_pes} PEs") + workloads: Dict[int, List[WorkloadGroup]] = {} + + # Group requests by destination PE + tasks_by_dest: Dict[int, List[SendRequest]] = {} + for req in send_requests: + tasks_by_dest.setdefault(req.dest_pe, []).append(req) + + # Pack tasks for each destination + for dest_pe in range(n_pes): + if dest_pe not in tasks_by_dest: + workloads[dest_pe] = [] + PELogger.debug(f" Dest PE {dest_pe}: 0 tasks → 0 batches") + continue + + tasks = tasks_by_dest[dest_pe] + workloads[dest_pe] = self._pack_single_destination(tasks, dest_pe) + + if workloads[dest_pe]: + total_size = sum(b.total_size for b in workloads[dest_pe]) + PELogger.debug( + f" Dest PE {dest_pe}: {len(tasks)} tasks → " + f"{len(workloads[dest_pe])} batches, {total_size} bytes total" + ) + else: + PELogger.debug( + f" Dest PE {dest_pe}: {len(tasks)} tasks → 0 batches (empty after packing)" + ) + + return workloads + + def _pack_single_destination( + self, tasks: List[SendRequest], dest_pe: int + ) -> List[WorkloadGroup]: + if not tasks: + return [] + + # Sort tasks by size (descending) for better bin packing efficiency + tasks.sort(key=lambda x: x.size, reverse=True) + + batches: List[WorkloadGroup] = [] + current_batch = WorkloadGroup(dest_pe=dest_pe, tasks=[], total_size=0) + + for task in tasks: + # Check if adding this task would exceed batch constraints + would_exceed_size = current_batch.total_size + task.size > MAX_SEGMENT_SIZE + would_exceed_task_cap = len(current_batch.tasks) >= MAX_TASKS_PER_BATCH + + if (would_exceed_size or would_exceed_task_cap) and current_batch.tasks: + # Finalize current batch + batches.append(current_batch) + task_first_10_string = ", ".join([str(t.task_id) for t in current_batch.tasks[:10]]) + PELogger.debug( + f" Packed batch to PE {dest_pe} idx {len(batches) - 1}: " + f"{task_first_10_string}... (total {len(current_batch.tasks)} tasks)" + ) + # Start new batch + current_batch = WorkloadGroup(dest_pe=dest_pe, tasks=[], total_size=0) + + # Add task to current batch + current_batch.tasks.append(task) + current_batch.total_size += task.size + + # Add final batch if not empty + if current_batch.tasks: + batches.append(current_batch) + + return batches diff --git a/megatron/core/resharding/nvshmem_copy_service/service.py b/megatron/core/resharding/nvshmem_copy_service/service.py new file mode 100644 index 00000000000..631e63ae41b --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/service.py @@ -0,0 +1,408 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Remote Copy Service - Main orchestrator for NVSHMEM-based GPU-to-GPU transfers. + +This service coordinates task segmentation, workload packing, scheduling, + +GPU resource management, and pipelined execution. +""" + +from typing import Dict, List, Optional, Tuple + +try: + import nvshmem.core + + HAVE_NVSHMEM = True +except ImportError: + HAVE_NVSHMEM = False + +import torch.cuda.nvtx as nvtx + +from .core import GPUResourceManager, KernelLauncher, PipelineExecutor +from .logger import PELogger +from .memory import DoubleBufferManager +from .nvshmem_types import ReceiveRequest, ScheduledBatch, SendRequest, WorkloadSummary +from .planning import CommunicationScheduler, GPUExecutionPlanner, TaskSegmenter, WorkloadPacker + + +class RemoteCopyService: + """ + Main service for managing remote GPU-to-GPU data transfers. + + Provides high-level API for registering transfers, scheduling, + and executing pipelined communication with NVSHMEM. + """ + + def __init__(self): + # Core components + self.gpu_resources = GPUResourceManager() + self.buffer_manager = DoubleBufferManager() + self.kernel_launcher = KernelLauncher() + self.pipeline_executor = None # Created after init + + # Planning components + self.task_segmenter = TaskSegmenter() + self.workload_packer = WorkloadPacker() + self.comm_scheduler = CommunicationScheduler() + self.gpu_planner = GPUExecutionPlanner() + + # State + self.send_requests: List[SendRequest] = [] + self.receive_requests: List[ReceiveRequest] = [] + self.iter_schedules: Optional[List[Dict]] = None + self.num_iterations: int = 0 + + # Events for double-buffering + self.pack_events = [] + self.unpack_events = [] + + @property + def my_pe(self) -> int: + """Get this PE's rank.""" + return self.gpu_resources.my_pe + + @property + def n_pes(self) -> int: + """Get total number of PEs.""" + return self.gpu_resources.n_pes + + @property + def device(self): + """Get CUDA device.""" + return self.gpu_resources.device + + @property + def initialized(self) -> bool: + """Check if service is initialized.""" + return self.gpu_resources.initialized + + def init(self, log_level: str = "INFO") -> None: + """ + Initialize the service. + + Sets up NVSHMEM, CUDA device, streams, buffers, and kernels. + Expects to be launched with torchrun. + + Args: + log_level: Logging level (TRACE, DEBUG, INFO, WARN, ERROR) + """ + if not HAVE_NVSHMEM: + raise RuntimeError( + "nvshmem.core is not available. Please install nvshmem to use NVSHMEMCopyService." + ) + + # Initialize GPU resources (NVSHMEM, device, streams) + self.gpu_resources.init() + + # Initialize logger after PE ID is known + PELogger.init(self.my_pe, level=log_level) + PELogger.info(f"Initializing RemoteCopyService on PE {self.my_pe}/{self.n_pes}") + + # Allocate double-buffered send/recv slots + self.buffer_manager.allocate() + PELogger.debug("Allocated double-buffered send/recv slots") + + # Load CUDA kernels + self.kernel_launcher.load_kernels() + PELogger.debug("Loaded CUDA kernels") + + # Cache CuPy stream wrappers for efficient kernel launching + self.kernel_launcher.set_streams( + self.gpu_resources.pack_stream, self.gpu_resources.unpack_stream + ) + PELogger.debug("Cached CuPy stream wrappers") + + # Create pipeline executor with dependencies + self.pipeline_executor = PipelineExecutor( + self.kernel_launcher, self.buffer_manager, self.my_pe + ) + + # Set streams on pipeline executor + self.pipeline_executor.set_streams( + self.gpu_resources.pack_stream, + self.gpu_resources.unpack_stream, + self.gpu_resources.send_stream, + self.gpu_resources.copy_stream, + self.gpu_resources.torch_pack_stream, + self.gpu_resources.torch_unpack_stream, + self.gpu_resources.torch_copy_stream, + ) + PELogger.info("Initialization complete") + + def register_send( + self, task_id: int, src_tensor, src_pos: int, size: int, dest_pe: int + ) -> None: + """ + Register a send operation. + + Args: + task_id: Unique task identifier + src_tensor: Source tensor (PyTorch/CuPy tensor or pointer) + src_pos: Starting position in source tensor + size: Number of bytes to send + dest_pe: Destination PE rank + """ + if dest_pe >= self.n_pes or dest_pe < 0: + PELogger.error(f"Error: Invalid destination PE {dest_pe}") + return + + req = SendRequest(task_id, src_tensor, src_pos, size, dest_pe) + self.send_requests.append(req) + + def register_receive( + self, task_id: int, dest_tensor, dest_pos: int, size: int, src_pe: int + ) -> None: + """ + Register a receive operation. + + Args: + task_id: Unique task identifier + dest_tensor: Destination tensor (PyTorch/CuPy tensor or pointer) + dest_pos: Starting position in destination tensor + size: Number of bytes to receive + src_pe: Source PE rank + """ + if src_pe >= self.n_pes or src_pe < 0: + PELogger.error(f"Error: Invalid source PE {src_pe}") + return + + req = ReceiveRequest(task_id, dest_tensor, dest_pos, size, src_pe) + self.receive_requests.append(req) + + def schedule(self) -> None: + """ + Build execution schedule. + + Can be called once and followed by multiple run() calls for + repeated execution with the same communication pattern. + + Steps: + 1. Segment large tasks into manageable chunks + 2. Pack tasks into batches + 3. Schedule batches to iterations (conflict-free) + 4. Build GPU execution plans (pointer arrays, chunking) + 5. Create synchronization events + """ + if not self.initialized: + raise RuntimeError("RemoteCopyService not initialized") + + PELogger.info( + f"Starting schedule: {len(self.send_requests)} send requests, " + f"{len(self.receive_requests)} receive requests" + ) + + # Step 1: Segment tasks (break large tasks into chunks) + PELogger.debug("Step 1: Segmenting tasks...") + orig_send_count = len(self.send_requests) + orig_recv_count = len(self.receive_requests) + self._segment_tasks() + PELogger.info( + f"Segmented: {orig_send_count} sends → {len(self.send_requests)} segments, " + f"{orig_recv_count} recvs → {len(self.receive_requests)} segments" + ) + + # Step 2: Pack tasks into workload groups + PELogger.debug("Step 2: Packing workloads...") + workloads = self.workload_packer.pack_workloads(self.send_requests, self.n_pes) + total_batches = sum(len(batches) for batches in workloads.values()) + active_pes = sum(1 for batches in workloads.values() if batches) + PELogger.info(f"Packed: {total_batches} batches across {active_pes} destination PEs") + + # Step 3: Schedule workloads to iterations + PELogger.debug("Step 3: Building communication schedule...") + schedule, global_summaries = self.comm_scheduler.build_schedule( + workloads, self.my_pe, self.n_pes + ) + + self.num_iterations = self.comm_scheduler.num_iterations + PELogger.info(f"Scheduled: {total_batches} batches → {self.num_iterations} iterations") + + # Step 4: Prepare iteration schedules + PELogger.debug("Step 4: Preparing iteration schedules...") + self.iter_schedules = self._prepare_iter_schedules( + schedule, workloads, global_summaries, self.num_iterations + ) + + # Step 5: Build GPU execution plans + PELogger.debug("Step 5: Building GPU execution plans...") + self.gpu_planner.create_gpu_plans( + self.iter_schedules, + self.buffer_manager.send_slots, + self.buffer_manager.recv_slots, + self.receive_requests, + ) + + # Step 6: Create double-buffered events + PELogger.debug("Step 6: Creating synchronization events...") + self.pack_events, self.unpack_events = self.gpu_resources.create_events(num_events=2) + self.pipeline_executor.set_events(self.pack_events, self.unpack_events) + + PELogger.info(f"Schedule complete: {self.num_iterations} iterations ready") + + def run(self) -> None: + """ + Execute the scheduled communication. + + Can be called multiple times after a single schedule() call + to repeat the same communication pattern. + """ + # import torch + # torch.save(self.send_requests, f"send_requests_{torch.distributed.get_rank()}.pt") + # torch.save(self.receive_requests, f"receive_requests_{torch.distributed.get_rank()}.pt") + + if not self.initialized: + raise RuntimeError("RemoteCopyService not initialized") + if self.iter_schedules is None: + raise RuntimeError("Must call schedule() before run()") + + PELogger.info(f"Starting execution: {self.num_iterations} iterations") + + # Start timing + nvtx.range_push("RemoteCopyService.run_total") + + # Global barrier before execution + PELogger.debug("Barrier: Synchronizing all PEs before execution") + nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) + self.gpu_resources.send_stream.sync() + + # Execute pipelined communication + nvtx.range_push("execute_pipeline") + self.pipeline_executor.execute_pipeline(self.iter_schedules, self.num_iterations) + nvtx.range_pop() # execute_pipeline + + # Global barrier after execution + PELogger.debug("Barrier: Synchronizing all PEs after pipeline") + nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) + + # Process same-PE transfers + self.pipeline_executor.process_self_moves(self.send_requests, self.receive_requests) + + # End timing range + nvtx.range_pop() # RemoteCopyService.run_total + + def clear_requests(self) -> None: + """ + Clear registered requests and schedule. + + Call this before registering a new set of transfers. + """ + self.send_requests = [] + self.receive_requests = [] + self.iter_schedules = None + self.num_iterations = 0 + self.pack_events = [] + self.unpack_events = [] + + def finalize(self) -> None: + """Cleanup resources.""" + PELogger.info("Finalizing RemoteCopyService") + + # Barrier to ensure all PEs are ready to finalize + try: + PELogger.debug("Barrier: Synchronizing all PEs before finalize") + nvshmem.core.barrier_all(stream=self.gpu_resources.send_stream) + self.gpu_resources.send_stream.sync() + except Exception as e: + PELogger.error(f"Error in final barrier: {e}") + + # Free buffers + self.buffer_manager.free() + + # Finalize GPU resources (this will call nvshmem.core.finalize internally) + self.gpu_resources.finalize() + + PELogger.info("RemoteCopyService finalized") + PELogger.shutdown() + + def _segment_tasks(self) -> None: + """Segment tasks into manageable chunks.""" + new_sends: List[SendRequest] = [] + for req in self.send_requests: + segments = self.task_segmenter.segment_send_request(req) + new_sends.extend(segments) + if len(segments) > 1: + PELogger.debug( + f" Segmented send task {req.task_id}: " + f"{req.size} bytes → {len(segments)} segments" + ) + self.send_requests = new_sends + + new_recvs: List[ReceiveRequest] = [] + for req in self.receive_requests: + segments = self.task_segmenter.segment_receive_request(req) + new_recvs.extend(segments) + if len(segments) > 1: + PELogger.debug( + f" Segmented recv task {req.task_id}: " + f"{req.size} bytes → {len(segments)} segments" + ) + self.receive_requests = new_recvs + + def _prepare_iter_schedules( + self, + schedule_batches: Dict[int, List[ScheduledBatch]], + workloads: Dict[int, List], + global_summaries: Dict[Tuple[int, int, int], WorkloadSummary], + num_iterations: int, + ) -> List[Dict]: + """ + Organize schedule into iteration-based structure. + + Returns: + List of dicts with 'send' and 'recv' keys for each iteration + """ + iter_schedules: List[Dict[str, Optional[ScheduledBatch]]] = [] + + for i in range(num_iterations): + sched: Dict[str, Optional[ScheduledBatch]] = {"send": None, "recv": None} + + if i in schedule_batches: + batches = schedule_batches[i] + + for b in batches: + # Skip same-PE transfers (handled separately by process_self_moves) + if b.src_pe == b.dest_pe: + PELogger.debug( + f" Iter {i}: Skipping same-PE batch " f"({b.src_pe} → {b.dest_pe})" + ) + continue + + if b.src_pe == self.my_pe: + # This PE sends in this iteration + b.tasks = workloads[b.dest_pe][b.batch_index].tasks + b.total_size = workloads[b.dest_pe][b.batch_index].total_size + sched["send"] = b + PELogger.debug( + f" Iter {i}: Send to PE {b.dest_pe}, batch " + f"{b.batch_index}, {len(b.tasks)} tasks, " + f"{b.total_size} bytes" + ) + + elif b.dest_pe == self.my_pe: + # This PE receives in this iteration + key = (b.src_pe, b.dest_pe, b.batch_index) + if key in global_summaries: + summary = global_summaries[key] + b.tasks_summary = summary + b.total_size = summary.total_size + else: + PELogger.error( + f" Iter {i}: Missing workload summary for " + f"recv from PE {b.src_pe}, batch {b.batch_index}" + ) + PELogger.error( + " Available keys in global_summaries: " + f"{list(global_summaries.keys())}" + ) + b.tasks_summary = None + b.total_size = 0 + sched["recv"] = b + PELogger.debug( + f" Iter {i}: Recv from PE {b.src_pe}, batch " + f"{b.batch_index}, {b.total_size} bytes" + ) + + iter_schedules.append(sched) + + return iter_schedules diff --git a/megatron/core/resharding/nvshmem_copy_service/validation.py b/megatron/core/resharding/nvshmem_copy_service/validation.py new file mode 100644 index 00000000000..fafb1321024 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/validation.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Validation utilities for GPU-to-GPU communication. + +Provides deterministic data generation and validation for verifying + +correctness of communication operations.""" + +from dataclasses import dataclass +from typing import List + +import torch + +from .logger import PELogger + + +@dataclass +class ValidationResult: + """Result of validating a single task.""" + + task_id: int + size: int + passed: bool + src_pe: int = -1 + mismatches: int = 0 + first_mismatch_idx: int = -1 + first_mismatch_expected: int = 0 + first_mismatch_actual: int = 0 + # Scheduling info - which batch/iteration this task was supposed to be handled + batch_index: int = -1 + iteration: int = -1 + + +@dataclass +class ValidationSummary: + """Summary of validation across all tasks.""" + + total_tasks: int + passed_tasks: int + failed_tasks: int + total_bytes: int + results: List[ValidationResult] + + @property + def all_passed(self) -> bool: + """Check if all validated tasks passed.""" + return self.failed_tasks == 0 + + +def generate_deterministic_data(task_id: int, size: int, device: str = "cuda") -> torch.Tensor: + """ + Generate deterministic data pattern for a task. + + Pattern: Each byte = (task_id * 31 + position) % 256 + This creates a unique pattern per task that varies along the data. + + Args: + task_id: Unique task identifier + size: Number of bytes to generate + device: Device to create tensor on ('cuda' or 'cpu') + + Returns: + torch.Tensor of uint8 with deterministic pattern + """ + positions = torch.arange(size, dtype=torch.int64, device=device) + pattern = ((task_id * 31 + positions) % 256).to(torch.uint8) + return pattern + + +def validate_received_data( + task_id: int, tensor: torch.Tensor, size: int, src_pe: int = -1 +) -> ValidationResult: + """ + Validate received data against expected deterministic pattern. + + Args: + task_id: Task identifier to regenerate expected data + tensor: Received tensor to validate + size: Number of bytes to validate + + Returns: + ValidationResult with pass/fail status and details + """ + # Get the data slice to validate + recv_data = tensor[:size] + + # Generate expected pattern on same device + expected = generate_deterministic_data(task_id, size, device=recv_data.device.type) + + # Compare + mismatches_mask = recv_data != expected + num_mismatches = mismatches_mask.sum().item() + + result = ValidationResult( + task_id=task_id, + size=size, + passed=(num_mismatches == 0), + src_pe=src_pe, + mismatches=num_mismatches, + ) + + if num_mismatches > 0: + # Find first mismatch for debugging + first_idx = mismatches_mask.nonzero(as_tuple=True)[0][0].item() + result.first_mismatch_idx = first_idx + result.first_mismatch_expected = expected[first_idx].item() + result.first_mismatch_actual = recv_data[first_idx].item() + + return result + + +def log_validation_summary(summary: ValidationSummary) -> None: + """Log validation summary.""" + if summary.all_passed: + PELogger.info( + "Validation PASSED: %d/%d tasks, %d bytes validated", + summary.passed_tasks, + summary.total_tasks, + summary.total_bytes, + ) + else: + PELogger.error( + "Validation FAILED: %d/%d tasks passed, %d failed", + summary.passed_tasks, + summary.total_tasks, + summary.failed_tasks, + ) + + # Group failures by source PE + failures_by_src = {} + for r in summary.results: + if not r.passed: + failures_by_src.setdefault(r.src_pe, []).append(r) + + PELogger.error(" Failures by source PE:") + for src_pe in sorted(failures_by_src.keys()): + failed_tasks = failures_by_src[src_pe] + task_ids = [r.task_id for r in failed_tasks] + PELogger.error( + " PE %d: %d failed tasks: %s", + src_pe, + len(failed_tasks), + task_ids[:15] if len(task_ids) <= 15 else task_ids[:15] + ["..."], + ) diff --git a/megatron/core/resharding/refit.py b/megatron/core/resharding/refit.py index 491a42b9116..5461b8d3900 100644 --- a/megatron/core/resharding/refit.py +++ b/megatron/core/resharding/refit.py @@ -17,9 +17,45 @@ from .copy_services.base import CopyService from .copy_services.gloo_copy_service import GlooCopyService from .copy_services.nccl_copy_service import NCCLCopyService +from .copy_services.nvshmem_copy_service import NVSHMEMCopyService # Supported refit backend names -RefitBackendName = Literal["nccl", "gloo"] +RefitBackendName = Literal["nccl", "gloo", "nvshmem"] + +# Module-level cache for refit services to avoid repeated allocations +_service_cache: dict[str, CopyService] = {} + + +def get_or_create_service(backend: RefitBackendName) -> CopyService: + """Get or create a cached CopyService instance for the given backend. + + This avoids expensive repeated allocations (especially for NVSHMEM buffers) + when swap_model_weights is called multiple times with the same backend. + """ + if backend in _service_cache: + return _service_cache[backend] + + if backend == "nccl": + service = NCCLCopyService() + elif backend == "gloo": + service = GlooCopyService() + elif backend == "nvshmem": + service = NVSHMEMCopyService() + else: + raise ValueError(f"Unknown backend '{backend}'") + + _service_cache[backend] = service + return service + + +def clear_service_cache(): + """Clear the cached refit services. + + Call this if you need to invalidate the cache, for example when + reinitializing distributed state. + """ + global _service_cache + _service_cache.clear() def swap_model_weights( @@ -37,15 +73,8 @@ def swap_model_weights( service = refit_method reshard_model_weights(src_model, target_model, service=service) elif isinstance(refit_method, str): - if refit_method == "nccl": - service = NCCLCopyService() - reshard_model_weights(src_model, target_model, service=service) - elif refit_method == "gloo": - # Debug / fallback backend: run refit over CPU/Gloo instead of NCCL. - service = GlooCopyService() - reshard_model_weights(src_model, target_model, service=service) - else: - raise ValueError(f"Unknown refit_method '{refit_method}'") + service = get_or_create_service(refit_method) + reshard_model_weights(src_model, target_model, service=service) else: raise TypeError("refit_method must be a str backend name or a CopyService instance") diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 05b2d702aa0..7177ebd00bd 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1977,12 +1977,12 @@ def _add_rl_args(parser): 'Requires --rl-inference-model-unified-memory-level=1.' ), ) - group.add_argument('--refit-method', type=str, default='gloo', - choices=['nccl', 'gloo'], + group.add_argument('--refit-method', type=str, default='nvshmem', + choices=['nccl', 'gloo', 'nvshmem'], help=('Method to refit the model weights between training and inference models during RL. ' 'nccl: use NCCLCopyService to refit using NCCL; ' 'gloo: use GlooCopyService over CPU; ' - )) + 'nvshmem: use NVSHMEMCopyService to refit using the NVSHMEM.')) group.add_argument('--rl-verify-model-weights-swap', action=argparse.BooleanOptionalAction, default=False, help='If set, verify that the model weights were correctly transferred by comparing forward pass outputs on' 'the first swap of model weights.') diff --git a/tests/unit_tests/resharding/test_model_swap.py b/tests/unit_tests/resharding/test_model_swap.py index f5db5cb6185..73296a175ed 100644 --- a/tests/unit_tests/resharding/test_model_swap.py +++ b/tests/unit_tests/resharding/test_model_swap.py @@ -24,6 +24,13 @@ from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils +try: + import nvshmem.core + + has_nvshmem = True +except Exception: + has_nvshmem = False + def _build_pg_collection( tp_size: int, pp_size: int = None, ep_size: int = 1 @@ -116,7 +123,20 @@ def _set_pg_collection(module, tp_group, dp_group): return module -@pytest.mark.parametrize("refit_backend", ["nccl", "gloo"]) +@pytest.mark.parametrize( + "refit_backend", + [ + pytest.param( + "nvshmem", + marks=pytest.mark.skipif( + not has_nvshmem, + reason="nvshmem.core is not available (NVSHMEM Python bindings not installed)", + ), + ), + "nccl", + "gloo", + ], +) @pytest.mark.parametrize( "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", [