Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 144 additions & 4 deletions slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses
import logging
import queue
import threading
from argparse import Namespace
from collections.abc import Callable, Mapping, Sequence

Expand Down Expand Up @@ -40,6 +42,103 @@ class RemoteWeightInfo:
weights_info: dict[str, tuple[int, int, int]] # name -> (remote_address, numel, element_size)


@dataclasses.dataclass
class TransferTask:
    """One batched RDMA write waiting to be executed by the transfer worker.

    The fields are, in order, the arguments later handed to
    ``engine.batch_transfer_async_write`` when the task is run.
    """

    # Identifier of the remote session the write is addressed to.
    session_id: str
    # Local source addresses, one entry per transferred buffer.
    source_ptrs: list[int]
    # Remote destination addresses, parallel to ``source_ptrs``.
    target_ptrs: list[int]
    # Lengths of the source buffers, parallel to ``source_ptrs``.
    source_lens: list[int]
    # Engine whose ``batch_transfer_async_write`` performs the write.
    engine: TransferEngine


class ExecutableQueue:
    """
    Asynchronous queue for executing transfer_bundle.execute_each() operations.
    Allows overlapping weight loading with RDMA transfer execution.

    Tasks are submitted with ``enqueue_task`` and executed one at a time, in
    FIFO order, by a single daemon thread started with ``start``.
    ``wait_all_complete`` blocks until every submitted task has been processed.

    A task counts as processed whether its transfer succeeded, returned a
    negative error code, or raised: failures are logged, not propagated.
    """

    def __init__(self):
        self._queue = queue.Queue()
        self._background_thread = None
        self._shutdown_event = threading.Event()
        # Set whenever the number of outstanding tasks drops to zero.
        self._tasks_completed = threading.Event()
        # Number of tasks enqueued but not yet fully processed; guarded by _lock.
        self._active_tasks = 0
        self._lock = threading.Lock()

    def _mark_task_finished(self):
        """Account for exactly one processed task and signal when none remain."""
        self._queue.task_done()
        with self._lock:
            self._active_tasks -= 1
            if self._active_tasks == 0:
                self._tasks_completed.set()

    def _background_worker(self):
        """Background thread worker that processes queued transfer tasks."""
        while not self._shutdown_event.is_set():
            try:
                # Get task with timeout to allow periodic shutdown checks
                task = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                # Execute the RDMA transfer
                logger.info(f"[RDMA] Executing transfer task for session {task.session_id}...")
                ret = task.engine.batch_transfer_async_write(
                    task.session_id, task.source_ptrs, task.target_ptrs, task.source_lens
                )
                logger.info(f"[RDMA] Executing transfer task for session {task.session_id} done")
                if ret < 0:
                    logger.error(f"RDMA transfer failed with error code {ret} for session {task.session_id}")
            except Exception as e:
                # Keep the worker alive: one failing task must not stop later ones.
                logger.exception(f"Error in background worker: {e}")
            finally:
                # Bookkeeping happens exactly once per dequeued task. Doing it in
                # both a ``finally`` and an ``except`` would double-decrement the
                # counter and over-call ``task_done()``.
                self._mark_task_finished()

    def start(self):
        """Start the background worker thread if it is not already running."""
        if self._background_thread is None or not self._background_thread.is_alive():
            self._shutdown_event.clear()
            self._tasks_completed.clear()
            self._background_thread = threading.Thread(target=self._background_worker, daemon=True)
            self._background_thread.start()

    def enqueue_task(self, task: "TransferTask"):
        """Add a transfer task to the queue."""
        with self._lock:
            self._active_tasks += 1
            self._tasks_completed.clear()
        self._queue.put(task)

    def wait_all_complete(self, timeout=30.0):
        """
        Wait for all queued tasks to complete before proceeding.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            True if no task is outstanding (or all finished within ``timeout``),
            False if the timeout expired first.
        """
        with self._lock:
            if self._active_tasks == 0:
                return True

        # The worker calls task_done() before it decrements the counter, so once
        # the event is set the queue is fully drained as well; no separate
        # (unbounded) queue.join() is needed and the timeout is honoured.
        return self._tasks_completed.wait(timeout)

    def shutdown(self):
        """Shutdown the background worker thread, waiting up to 5 seconds for it to exit."""
        self._shutdown_event.set()
        if self._background_thread and self._background_thread.is_alive():
            self._background_thread.join(timeout=5.0)


@dataclasses.dataclass
class TransferBundle:
model_replica: Sequence[torch.nn.Module]
Expand All @@ -50,7 +149,12 @@ class TransferBundle:
def add_remote_session(self, remote_info: RemoteWeightInfo) -> None:
self.remote_weight_infos.append(remote_info)

def execute_each(self, names: Sequence[str]) -> None:
def execute_each(self, names: Sequence[str], executable_queue: ExecutableQueue = None) -> None:
"""
Execute transfer for specific parameter names.
If executable_queue is provided, tasks are queued for async execution.
Otherwise, falls back to immediate execution for backward compatibility.
"""
# Find local pointers and lengths for the given names
source_ptrs, source_lens = [], []
for name in names:
Expand All @@ -66,8 +170,19 @@ def execute_each(self, names: Sequence[str]) -> None:
for name in names:
target_ptrs.append(remote_weights_info[name][0]) # remote address

# Batch transfer weights through RDMA
_ = self.engine.batch_transfer_async_write(session_id, source_ptrs, target_ptrs, source_lens)
if executable_queue is not None:
# Queue the transfer task for async execution
task = TransferTask(
session_id=session_id,
source_ptrs=source_ptrs.copy(),
target_ptrs=target_ptrs.copy(),
source_lens=source_lens.copy(),
engine=self.engine,
)
executable_queue.enqueue_task(task)
else:
# Immediate execution (backward compatibility)
_ = self.engine.batch_transfer_async_write(session_id, source_ptrs, target_ptrs, source_lens)

def execute(self) -> None:
# Execute transfer for each target session using this replica.
Expand Down Expand Up @@ -122,6 +237,10 @@ def __init__(
self._model_on_cpu = False
self.pipelined_transfer = args.rdma_pipelined_transfer

# Initialize executable queue for async transfer operations
self.executable_queue = ExecutableQueue()
self.executable_queue.start()

def connect_rollout_engines(
self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle
) -> None:
Expand Down Expand Up @@ -308,6 +427,8 @@ def _update_bucket_weights_from_remote(
) -> None:
"""
The RDMA P2P weight update is implemented as a single side write, meaning the trainer writes its weights directly to the rollout engines' memory.
Now uses an executable queue to make transfer_bundle.execute_each() operations asynchronous,
allowing overlap between weight loading and RDMA transfers.
"""

if not self._is_source or not converted_named_tensors:
Expand All @@ -320,17 +441,36 @@ def _update_bucket_weights_from_remote(
for transfer_bundle in self.engines.values():
updated_name = transfer_bundle.model_replica.load_weights(converted_named_tensors)
if self.pipelined_transfer:
transfer_bundle.execute_each(updated_name)
# Use executable queue for async transfer operations
transfer_bundle.execute_each(updated_name, self.executable_queue)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make queued execution the default if it's clearly better.


converted_named_tensors.clear()

def __del__(self):
    """Stop the transfer worker when this instance is garbage-collected."""
    try:
        pending_transfers = self.executable_queue
    except AttributeError:
        # The attribute was never set on this instance, so there is nothing to stop.
        return
    pending_transfers.shutdown()

def finish_transfer_task(self) -> None:
if not self._is_source:
return

# Execute transfer for each engine replica.
if not self.pipelined_transfer:
for transfer_bundle in self.engines.values():
transfer_bundle.execute()
else:
# Wait for all queued transfer tasks to complete before cpu offloading
logging.info("[RDMA] Waiting for all queued transfer tasks to complete...")
# NOTE: set the timeout?
assert self.executable_queue.wait_all_complete(
timeout=30.0
), "[RDMA] Some transfer tasks may not have completed within timeout"

# Add CUDA synchronization to ensure all asynchronous RDMA operations are complete
# This is critical to prevent race conditions with memory offloading
logging.info("[RDMA] Synchronizing CUDA to ensure all asynchronous operations complete...")
torch.cuda.synchronize()

# Offload model replicas from memory after transfer.
if not self._model_on_cpu:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_weight_transfer_moe_multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ def execute(args: ScriptArgs):
"--profile-update-weight-end 3 "
"--tensorboard-dir /root/profiler_logs/ "
)
profile_args = (
"--use-pytorch-profiler-update-weight "
"--profile-update-weight-start 0 "
"--profile-update-weight-end 6 "
"--tensorboard-dir /root/profiler_logs/ "
)

train_args = (
f"{ckpt_args} "
Expand Down
34 changes: 17 additions & 17 deletions tests/test_weight_transfer_multinode_h100_80g.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,51 @@

# 1 training node, 1 rollout node
# NODE 0:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 2>&1 | tee temp2_moe_2node.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 2>&1 | tee temp2_moe_2node.log

# enable training dp = 2
# NODE 0:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log

# enable training ep = 8, dp = 2, attn-tp = 4
# NODE 0:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log

# enable training pp = 2
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log


# 2 training nodes, 1 rollout node
# NODE 0:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 0 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 1 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log
# NODE 2:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 2 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 2 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log


# 1 training node, 2 rollout nodes
# NODE 0:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
# NODE 2:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log

# 1 training node, 2 rollout nodes, rollout_ep = 16
# NODE 0:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
# NODE 1:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
# NODE 2:
MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log
MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log