From ff2a9f89ebfc83d84323377159338677109317fe Mon Sep 17 00:00:00 2001 From: JensenFire Date: Sun, 18 Jan 2026 11:10:43 +0000 Subject: [PATCH] async --- .../update_weight/update_weight_from_rdma.py | 148 +++++++++++++++++- tests/test_weight_transfer_moe_multinode.py | 6 + ...test_weight_transfer_multinode_h100_80g.sh | 34 ++-- 3 files changed, 167 insertions(+), 21 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py index 26cb9068a9..d01c068bff 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py @@ -1,5 +1,7 @@ import dataclasses import logging +import queue +import threading from argparse import Namespace from collections.abc import Callable, Mapping, Sequence @@ -40,6 +42,103 @@ class RemoteWeightInfo: weights_info: dict[str, tuple[int, int, int]] # name -> (remote_address, numel, element_size) +@dataclasses.dataclass +class TransferTask: + """Represents a queued RDMA transfer task.""" + + session_id: str + source_ptrs: list[int] + target_ptrs: list[int] + source_lens: list[int] + engine: TransferEngine + + +class ExecutableQueue: + """ + Asynchronous queue for executing transfer_bundle.execute_each() operations. + Allows overlapping weight loading with RDMA transfer execution. + """ + + def __init__(self): + self._queue = queue.Queue() + self._background_thread = None + self._shutdown_event = threading.Event() + self._tasks_completed = threading.Event() + self._active_tasks = 0 + self._lock = threading.Lock() + + def _background_worker(self): + """Background thread worker that processes queued transfer tasks.""" + while not self._shutdown_event.is_set(): + try: + # Get task with timeout to allow periodic shutdown checks + task = self._queue.get(timeout=0.1) + try: + # Execute the RDMA transfer + logger.info(f"[RDMA] Executing transfer task for session {task.session_id}...") + ret = task.engine.batch_transfer_async_write( + task.session_id, task.source_ptrs, task.target_ptrs, task.source_lens + ) + logger.info(f"[RDMA] Executing transfer task for session {task.session_id} done") + if ret < 0: + logging.error(f"RDMA transfer failed with error code {ret} for session {task.session_id}") + finally: + self._queue.task_done() + with self._lock: + self._active_tasks -= 1 + if self._active_tasks == 0: + self._tasks_completed.set() + + except queue.Empty: + continue + except Exception as e: + logging.error(f"Error in background worker: {e}") + self._queue.task_done() + with self._lock: + self._active_tasks -= 1 + if self._active_tasks == 0: + self._tasks_completed.set() + + def start(self): + """Start the background worker thread.""" + if self._background_thread is None or not self._background_thread.is_alive(): + self._shutdown_event.clear() + self._tasks_completed.clear() + self._background_thread = threading.Thread(target=self._background_worker, daemon=True) + self._background_thread.start() + + def enqueue_task(self, task: TransferTask): + """Add a transfer task to the queue.""" + with self._lock: + self._active_tasks += 1 + self._tasks_completed.clear() + self._queue.put(task) + + def wait_all_complete(self, timeout=30.0): + """Wait for all queued tasks to complete before proceeding.""" + if self._active_tasks == 0: + return True + + # Wait for the completion event first + if not self._tasks_completed.wait(timeout): + return False + + # Additionally wait for the queue to be fully processed to avoid race conditions + # This ensures all tasks have been processed by calling task_done() + try: + self._queue.join() # Wait until all items in the queue have been processed + return True + except Exception as e: + logging.error(f"Error during queue join: {e}") + return False + + def shutdown(self): + """Shutdown the background worker thread.""" + self._shutdown_event.set() + if self._background_thread and self._background_thread.is_alive(): + self._background_thread.join(timeout=5.0) + + @dataclasses.dataclass class TransferBundle: model_replica: Sequence[torch.nn.Module] @@ -50,7 +149,12 @@ class TransferBundle: def add_remote_session(self, remote_info: RemoteWeightInfo) -> None: self.remote_weight_infos.append(remote_info) - def execute_each(self, names: Sequence[str]) -> None: + def execute_each(self, names: Sequence[str], executable_queue: ExecutableQueue = None) -> None: + """ + Execute transfer for specific parameter names. + If executable_queue is provided, tasks are queued for async execution. + Otherwise, falls back to immediate execution for backward compatibility. + """ # Find local pointers and lengths for the given names source_ptrs, source_lens = [], [] for name in names: @@ -66,8 +170,19 @@ def execute_each(self, names: Sequence[str]) -> None: for name in names: target_ptrs.append(remote_weights_info[name][0]) # remote address - # Batch transfer weights through RDMA - _ = self.engine.batch_transfer_async_write(session_id, source_ptrs, target_ptrs, source_lens) + if executable_queue is not None: + # Queue the transfer task for async execution + task = TransferTask( + session_id=session_id, + source_ptrs=source_ptrs.copy(), + target_ptrs=target_ptrs.copy(), + source_lens=source_lens.copy(), + engine=self.engine, + ) + executable_queue.enqueue_task(task) + else: + # Immediate execution (backward compatibility) + _ = self.engine.batch_transfer_async_write(session_id, source_ptrs, target_ptrs, source_lens) def execute(self) -> None: # Execute transfer for each target session using this replica. @@ -122,6 +237,10 @@ def __init__( self._model_on_cpu = False self.pipelined_transfer = args.rdma_pipelined_transfer + # Initialize executable queue for async transfer operations + self.executable_queue = ExecutableQueue() + self.executable_queue.start() + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle ) -> None: @@ -308,6 +427,8 @@ def _update_bucket_weights_from_remote( ) -> None: """ The RDMA P2P weight update is implemented as a single side write, meaning the trainer writes its weights directly to the rollout engines' memory. + Now uses an executable queue to make transfer_bundle.execute_each() operations asynchronous, + allowing overlap between weight loading and RDMA transfers. """ if not self._is_source or not converted_named_tensors: @@ -320,17 +441,36 @@ def _update_bucket_weights_from_remote( for transfer_bundle in self.engines.values(): updated_name = transfer_bundle.model_replica.load_weights(converted_named_tensors) if self.pipelined_transfer: - transfer_bundle.execute_each(updated_name) + # Use executable queue for async transfer operations + transfer_bundle.execute_each(updated_name, self.executable_queue) converted_named_tensors.clear() + def __del__(self): + """Cleanup resources when the instance is destroyed.""" + if hasattr(self, "executable_queue"): + self.executable_queue.shutdown() + def finish_transfer_task(self) -> None: if not self._is_source: return + # Execute transfer for each engine replica. if not self.pipelined_transfer: for transfer_bundle in self.engines.values(): transfer_bundle.execute() + else: + # Wait for all queued transfer tasks to complete before cpu offloading + logging.info("[RDMA] Waiting for all queued transfer tasks to complete...") + # NOTE: set the timeout? + assert self.executable_queue.wait_all_complete( + timeout=30.0 + ), "[RDMA] Some transfer tasks may not have completed within timeout" + + # Add CUDA synchronization to ensure all asynchronous RDMA operations are complete + # This is critical to prevent race conditions with memory offloading + logging.info("[RDMA] Synchronizing CUDA to ensure all asynchronous operations complete...") + torch.cuda.synchronize() # Offload model replicas from memory after transfer. if not self._model_on_cpu: diff --git a/tests/test_weight_transfer_moe_multinode.py b/tests/test_weight_transfer_moe_multinode.py index ec92efb40c..5f1f3d9a51 100644 --- a/tests/test_weight_transfer_moe_multinode.py +++ b/tests/test_weight_transfer_moe_multinode.py @@ -206,6 +206,12 @@ def execute(args: ScriptArgs): "--profile-update-weight-end 3 " "--tensorboard-dir /root/profiler_logs/ " ) + profile_args = ( + "--use-pytorch-profiler-update-weight " + "--profile-update-weight-start 0 " + "--profile-update-weight-end 6 " + "--tensorboard-dir /root/profiler_logs/ " + ) train_args = ( f"{ckpt_args} " diff --git a/tests/test_weight_transfer_multinode_h100_80g.sh b/tests/test_weight_transfer_multinode_h100_80g.sh index e83bdbeb71..c5bbf8de5e 100644 --- a/tests/test_weight_transfer_multinode_h100_80g.sh +++ b/tests/test_weight_transfer_multinode_h100_80g.sh @@ -3,51 +3,51 @@ # 1 training node, 1 rollout node # NODE 0: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 2>&1 | tee temp2_moe_2node.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 2>&1 | tee temp2_moe_2node.log # enable training dp = 2 # NODE 0: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 --train-tp 4 2>&1 | tee temp2_moe_2node.log # enable training ep = 8, dp = 2, attn-tp = 4 # NODE 0: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 --train-tp 4 --train-ep 8 --train-etp 1 2>&1 | tee temp2_moe_2node.log # enable training pp = 2 -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 0 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --head-node-ip h100-069-001 --node-rank 1 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --nnodes 2 --train-pp 2 --train-tp 4 --train-etp 4 --decoder-last-pipeline-num-layers 14 2>&1 | tee temp2_moe_2node.log # 2 training nodes, 1 rollout node # NODE 0: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 0 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 1 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log # NODE 2: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 2 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 2 --num-train-gpus 16 --train-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_2training.log # 1 training node, 2 rollout nodes # NODE 0: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log # NODE 2: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log # 1 training node, 2 rollout nodes, rollout_ep = 16 # NODE 0: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 0 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log # NODE 1: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 1 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log # NODE 2: -MASTER_ADDR=h100-069-001 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --mode rdma --head-node-ip h100-069-001 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log +MASTER_ADDR=h100-139-003 python /root/slime/tests/test_weight_transfer_moe_multinode.py --multinode --mode rdma --pipelined-transfer --head-node-ip h100-139-003 --node-rank 2 --num-rollout-gpus 16 --sglang-tp 16 --sglang-ep 16 --nnodes 3 2>&1 | tee temp2_moe_3node_1training.log