Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
2b267a7
initial test
Nov 18, 2025
2f804db
move eveything into megatron
Nov 24, 2025
8a7b508
clean up
Nov 24, 2025
a881fba
clean up
Nov 24, 2025
f45c32d
more refactor
Nov 24, 2025
7ae543d
more cleanup
Nov 24, 2025
52b8b8d
clean up
Nov 24, 2025
f475327
more tests
Nov 24, 2025
5f2372b
merge main in
Nov 25, 2025
2fbd44d
end2end
Dec 1, 2025
2d93e49
clean up
Dec 3, 2025
4e35a2a
Merge branch 'main' of github.com:wdykas/Megatron-LM into refit
Dec 3, 2025
0960061
add tests
Dec 3, 2025
2190b22
refactor
Dec 3, 2025
9641c38
fix tests
Dec 3, 2025
d5d4c47
check changes
Dec 4, 2025
53298e7
check changes for hao
Dec 4, 2025
c72ec6b
cleanup logging
Dec 4, 2025
940615b
clean up
Dec 4, 2025
12dc7ae
add copyyright
Dec 4, 2025
914f396
merge other MR
Dec 5, 2025
8a13950
fix merge
Dec 5, 2025
70272da
cleanup merge
Dec 5, 2025
cc5b44b
remove unwrap
Dec 5, 2025
16281ef
simplify dp round robin
Dec 5, 2025
16dc911
Address comments
Dec 5, 2025
32d8d9f
add fix
Dec 5, 2025
6cd248b
test
Dec 8, 2025
5948930
fix tests
Dec 8, 2025
6098e5b
clean up
Dec 8, 2025
368cf3c
fix nvshmem, all backends working in tests
Dec 12, 2025
ed449fe
fix test
Dec 17, 2025
1bbe010
fix execution mistake
Dec 17, 2025
83d8f4a
verified with runs
Dec 17, 2025
cbd93e0
merge Main
Dec 22, 2025
6aa223e
fix merge
Dec 22, 2025
a4e7d4f
add offload
Jan 4, 2026
23ec73d
lint
Jan 4, 2026
34990da
lint
Jan 5, 2026
ae71cec
fix copywrite
Jan 5, 2026
3bf9528
lint
Jan 5, 2026
3817e28
fix comment
Jan 5, 2026
5c1c58f
fix import guards
Jan 5, 2026
88959e8
fix import errors
Jan 5, 2026
4306ccd
fix import errors
Jan 5, 2026
53091cb
edit
Jan 5, 2026
aa48fc0
rebase
Jan 8, 2026
e0b3fc4
fix formatting
Jan 8, 2026
f2d434b
Merge branch 'main' of github.com:wdykas/Megatron-LM into nvshmem-refit
Jan 14, 2026
1190730
Merge branch 'main' of github.com:wdykas/Megatron-LM into nvshmem-refit
Jan 14, 2026
86077a2
Merge branch 'main' of github.com:wdykas/Megatron-LM into nvshmem-refit
wdykas Jan 26, 2026
90aaf67
Merge branch 'main' into nvshmem-refit
wdykas Jan 27, 2026
1329aa2
fix cache
wdykas Jan 28, 2026
284751a
Merge branch 'nvshmem-refit' of github.com:wdykas/Megatron-LM into nv…
wdykas Jan 28, 2026
d78b710
Merge branch 'main' into nvshmem-refit
wdykas Jan 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion megatron/core/resharding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .execution import execute_reshard_plan
from .planner import build_centralized_reshard_plan
from .refit import reshard_model_weights, swap_model_weights
from .refit import (
clear_service_cache,
get_or_create_service,
reshard_model_weights,
swap_model_weights,
)
from .utils import ParameterMetadata, ReshardPlan, ShardingDescriptor, TransferOp

__all__ = [
"build_centralized_reshard_plan",
"execute_reshard_plan",
"swap_model_weights",
"reshard_model_weights",
"get_or_create_service",
"clear_service_cache",
"ParameterMetadata",
"ShardingDescriptor",
"TransferOp",
Expand Down
3 changes: 2 additions & 1 deletion megatron/core/resharding/copy_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@

from .base import CopyService
from .nccl_copy_service import NCCLCopyService
from .nvshmem_copy_service import NVSHMEMCopyService

__all__ = ["CopyService", "NCCLCopyService"]
__all__ = ["CopyService", "NCCLCopyService", "NVSHMEMCopyService"]
173 changes: 173 additions & 0 deletions megatron/core/resharding/copy_services/nvshmem_copy_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

import logging
from typing import Dict

import torch
import torch.distributed as dist

from ..nvshmem_copy_service import RemoteCopyService
from .base import CopyService

logger = logging.getLogger(__name__)


class NVSHMEMCopyService(CopyService):
"""CopyService implementation backed by NVSHMEM RemoteCopyService."""

def __init__(self):
if not dist.is_initialized():
raise RuntimeError("torch.distributed must be initialized before NVSHMEMCopyService()")

self.rank = dist.get_rank()
self._remote = RemoteCopyService()
# Lazily initialized on first use to avoid side effects at import time
self._initialized = False

# NOTE: keep the original typed tensors here (not uint8 views) so local copies
# preserve shape/strides semantics and avoid byte-offset pitfalls.
self._local_send_ops: Dict[int, torch.Tensor] = {}
self._local_recv_ops: Dict[int, torch.Tensor] = {}
self._local_copy_stream = torch.cuda.Stream()

logger.info("NVSHMEMCopyService constructed")

def _ensure_initialized(self):
if not self._initialized:
self._remote.init(log_level="INFO")
self._initialized = True
logger.info(
"NVSHMEMCopyService initialized: PE %d / %d", self._remote.my_pe, self._remote.n_pes
)

def submit_send(self, src_tensor: torch.Tensor, dest_rank: int):
"""
Basic CopyService API is not rich enough to drive the NVSHMEM planner
(it lacks a globally shared task identifier), so this method is kept
only for interface compatibility and should not be used directly.

The resharding path calls into NVSHMEMCopyService via the
submit_send_with_id/submit_recv_with_id helpers instead.
"""
raise RuntimeError(
"NVSHMEMCopyService.submit_send() is not supported; "
"use submit_send_with_id(...) from execute_reshard_plan."
)

def submit_recv(self, dest_tensor: torch.Tensor, src_rank: int):
raise RuntimeError(
"NVSHMEMCopyService.submit_recv() is not supported; "
"use submit_recv_with_id(...) from execute_reshard_plan."
)

#
# New helper API used from execute_reshard_plan via monkey-patching:
# we avoid changing the existing execute_reshard_plan signature by adding
# a small adapter layer that batches up matched send/recv slices.
#

def submit_send_with_id(self, task_id: int, src_tensor: torch.Tensor, dest_rank: int):
"""Register a send with an explicit, globally shared task_id."""
self._ensure_initialized()

if not src_tensor.is_contiguous():
src_tensor = src_tensor.contiguous()

# Local transfers: keep them out of RemoteCopyService entirely.
if dest_rank == self.rank:
self._local_send_ops[task_id] = src_tensor
return

num_bytes = src_tensor.numel() * src_tensor.element_size()
src_bytes = src_tensor.view(torch.uint8)

logger.debug(
"NVSHMEMCopyService: register_send task_id=%d, %d bytes (%d → %d)",
task_id,
num_bytes,
self.rank,
dest_rank,
)

# Use public API on RemoteCopyService
self._remote.register_send(
task_id=task_id, src_tensor=src_bytes, src_pos=0, size=num_bytes, dest_pe=dest_rank
)

def submit_recv_with_id(self, task_id: int, dest_tensor: torch.Tensor, src_rank: int):
"""Register a recv with an explicit, globally shared task_id."""
self._ensure_initialized()

if not dest_tensor.is_contiguous():
dest_tensor = dest_tensor.contiguous()

# Local transfers: keep them out of RemoteCopyService entirely.
if src_rank == self.rank:
self._local_recv_ops[task_id] = dest_tensor
return

num_bytes = dest_tensor.numel() * dest_tensor.element_size()
dst_bytes = dest_tensor.view(torch.uint8)

logger.debug(
"NVSHMEMCopyService: register_recv task_id=%d, %d bytes (%d ← %d)",
task_id,
num_bytes,
self.rank,
src_rank,
)

self._remote.register_receive(
task_id=task_id, dest_tensor=dst_bytes, dest_pos=0, size=num_bytes, src_pe=src_rank
)

def run(self):
"""
Execute all registered transfer pairs via NVSHMEM.

This converts the registered pairs into RemoteCopyService send/receive
requests, builds a schedule, runs the pipelined NVSHMEM transfer, and
then clears internal state.
"""
self._ensure_initialized()

# 1) Run same-rank copies (match by task_id), like NCCL backend.
if self._local_send_ops or self._local_recv_ops:
missing_sends = set(self._local_recv_ops.keys()) - set(self._local_send_ops.keys())
missing_recvs = set(self._local_send_ops.keys()) - set(self._local_recv_ops.keys())
if missing_sends or missing_recvs:
raise RuntimeError(
"NVSHMEMCopyService: unmatched local ops on rank "
f"{self.rank}: missing_sends={sorted(list(missing_sends))[:10]} "
f"missing_recvs={sorted(list(missing_recvs))[:10]}"
)

with torch.no_grad():
with torch.cuda.stream(self._local_copy_stream):
for task_id, dst in self._local_recv_ops.items():
src = self._local_send_ops[task_id]
if src.numel() != dst.numel() or src.element_size() != dst.element_size():
raise RuntimeError(
"NVSHMEMCopyService: local copy size mismatch on rank "
f"{self.rank} task_id={task_id}: "
f"src=({tuple(src.shape)}, {src.dtype}) "
f"dst=({tuple(dst.shape)}, {dst.dtype})"
)
dst.copy_(src, non_blocking=True)

torch.cuda.current_stream().wait_stream(self._local_copy_stream)
self._local_send_ops.clear()
self._local_recv_ops.clear()

# 2) Execute remote schedule (if any remote sends/recvs were registered).
if not self._remote.send_requests and not self._remote.receive_requests:
logger.info("NVSHMEMCopyService: no remote requests; local copies complete")
return

logger.info("NVSHMEMCopyService: building NVSHMEM schedule and executing")
self._remote.schedule()
self._remote.run()
self._remote.clear_requests()
logger.info("NVSHMEMCopyService: NVSHMEM transfers complete")
29 changes: 29 additions & 0 deletions megatron/core/resharding/nvshmem_copy_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""
NVSHMEM-based remote copy service and supporting components.

This package is an in-tree integration of the standalone
`nvshmem_copy_service/python` implementation so that Megatron
can use it without relying on an external library.
"""

from . import nvshmem_types
from .core import GPUResourceManager, KernelLauncher, PipelineExecutor
from .memory import DoubleBufferManager, TensorPointerExtractor
from .planning import CommunicationScheduler, GPUExecutionPlanner, TaskSegmenter, WorkloadPacker
from .service import RemoteCopyService

__all__ = [
"RemoteCopyService",
"nvshmem_types",
"GPUResourceManager",
"KernelLauncher",
"PipelineExecutor",
"DoubleBufferManager",
"TensorPointerExtractor",
"CommunicationScheduler",
"GPUExecutionPlanner",
"TaskSegmenter",
"WorkloadPacker",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Core execution components for NVSHMEM operations."""

from .gpu_resource_manager import GPUResourceManager
from .kernel_launcher import KernelLauncher
from .pipeline_executor import PipelineExecutor

__all__ = ["GPUResourceManager", "KernelLauncher", "PipelineExecutor"]
Loading