Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions slime/backends/megatron_utils/update_weight/remote_transfer_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,21 @@ def _get_parallelism(self, args: Namespace) -> None:
self._dp_rank, self._dp_size = mpu.get_data_parallel_rank(
with_context_parallel=True
), mpu.get_data_parallel_world_size(with_context_parallel=True)
self._edp_rank, self._edp_size = mpu.get_expert_data_parallel_rank(), mpu.get_expert_data_parallel_world_size()

# Gather the target (rollout engine count and parallelism) information.
self._rollout_tp_size = args.sglang_tp_size
self._rollout_dp_size = args.sglang_dp_size
self._rollout_ep_size = args.sglang_ep_size
self._rollout_attn_tp_size = self._rollout_tp_size // self._rollout_dp_size
self._rollout_moe_tp_size = self._rollout_tp_size // self._rollout_ep_size

# EP and PP sizes are not tested and likely miss functionalities.
self._rollout_pp_size = args.sglang_pp_size
if self._rollout_ep_size != 1 or self._rollout_pp_size != 1:
if self._rollout_pp_size != 1:
raise NotImplementedError("Rollout expert and pipeline parallelisms are not supported yet.")
self._num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
self._rollout_engine_count = args.rollout_num_gpus // self._num_gpu_per_engine
self._rollout_num_gpu_per_engine = args.rollout_num_gpus_per_engine
self._rollout_engine_count = args.rollout_num_gpus // self._rollout_num_gpu_per_engine
self._rollout_num_gpus = args.rollout_num_gpus
logger.info(
f"RemoteTransferPlan initialized: mode={self.mode}, pp_rank={self._pp_rank}/{self._pp_size}, tp_rank={self._tp_rank}/{self._tp_size}, "
Expand All @@ -99,14 +103,15 @@ def _get_parallelism(self, args: Namespace) -> None:
logger.info(
f"Rollout engine count: {self._rollout_engine_count}, tp_size={self._rollout_tp_size}, ep_size={self._rollout_ep_size}, dp_size={self._rollout_dp_size}"
)
# Calculate the non-expert dp/ expert dp from training side
# Reference: `Megatron-LM/megatron/core/parallel_state.py`

self._gathered_dp_size = self._dp_size * self._tp_size
self._gathered_dp_rank = self._dp_rank * self._tp_size + self._tp_rank
# TODO: If I understand correctly the final size should be same as we now only have pp - dp dimensions for both param groups?
expert_tp_size = self._ep_size * self._etp_size
self._gathered_expert_dp_size = self._dp_size * expert_tp_size
self._gathered_expert_dp_size = self._edp_size * expert_tp_size
Comment thread
JensenFire marked this conversation as resolved.
self._gathered_expert_dp_rank = (
self._dp_rank * expert_tp_size + self._ep_rank * self._etp_size + self._etp_rank
self._edp_rank * expert_tp_size + self._ep_rank * self._etp_size + self._etp_rank
)
logger.info(
f"Gathered dp_size={self._gathered_dp_size}, gathered expert dp_size={self._gathered_expert_dp_size}"
Expand All @@ -116,6 +121,7 @@ def _get_parallelism(self, args: Namespace) -> None:
)

self._rank = self._gathered_dp_rank
self._size = self._gathered_dp_size

def get_nccl_group(self) -> str:
"""
Expand Down Expand Up @@ -148,19 +154,21 @@ def plan_p2p(self) -> list[TransferTaskP2PMeta]:
"""

all_targets = [
(m_idx, k_idx) for m_idx in range(self._rollout_engine_count) for k_idx in range(self._num_gpu_per_engine)
(m_idx, k_idx)
for m_idx in range(self._rollout_engine_count)
for k_idx in range(self._rollout_num_gpu_per_engine)
]
# Assignments: source_rank -> {engine_rank: [engine_indices]}
assignements = defaultdict(lambda: defaultdict(list))
# First round robin assignment
i = -1
for source_rank, (idx, target) in zip(range(self._gathered_dp_size), enumerate(all_targets), strict=False):
for source_rank, (idx, target) in zip(range(self._size), enumerate(all_targets), strict=False):
i = idx
m_idx, k_idx = target
assignements[source_rank][k_idx].append(m_idx)

def count_engine_index_assignments(k_idx: int) -> int:
return [len(assignements[source][k_idx]) for source in range(self._gathered_dp_size)]
return [len(assignements[source][k_idx]) for source in range(self._size)]

# Remainder assignment by least_assigned_source
cur_source_index = 0
Expand All @@ -174,7 +182,7 @@ def count_engine_index_assignments(k_idx: int) -> int:
_, select_source = min((val, idx) for (idx, val) in enumerate(counted) if val > 0)
# Else go back to round robin.
else:
select_source = cur_source_index % self._gathered_dp_size
select_source = cur_source_index % self._size
cur_source_index += 1
assignements[select_source][k_idx].append(m_idx)

Expand All @@ -191,6 +199,20 @@ def count_engine_index_assignments(k_idx: int) -> int:
)
return transfer_tasks

def tp_conversion(self, targeted_tp_rank: int) -> dict[str, int]:
    """
    Decompose a rollout TP rank into attention-TP / DP and MoE-TP / EP ranks.

    Args:
        targeted_tp_rank: TP rank on the rollout engine to decompose.

    Returns:
        A dict with the keys ``attn_tp_rank``, ``dp_rank``, ``moe_tp_rank``
        and ``ep_rank``.
    """
    # NOTE(review): presumably, in sglang, the number of GPUs per engine equals
    # the targeted TP size -- confirm.
    # Attention view: the quotient is the DP rank, the remainder the attn-TP rank.
    dp_index, attn_tp_index = divmod(targeted_tp_rank, self._rollout_attn_tp_size)
    # MoE view: the quotient is the EP rank, the remainder the MoE-TP rank.
    ep_index, moe_tp_index = divmod(targeted_tp_rank, self._rollout_moe_tp_size)
    return {
        "attn_tp_rank": attn_tp_index,
        "dp_rank": dp_index,
        "moe_tp_rank": moe_tp_index,
        "ep_rank": ep_index,
    }

def is_source(self) -> bool:
"""
Determine if the current rank needs to initiate weight transfer.
Expand Down
156 changes: 146 additions & 10 deletions slime/backends/megatron_utils/update_weight/update_weight_from_rdma.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable, Mapping, Sequence

import ray
import sglang.srt.distributed.parallel_state as sglang_parallel_state
import sglang.srt.layers.dp_attention as sglang_dp_attention
import sglang.srt.server_args as sglang_server_args
import torch
Expand Down Expand Up @@ -167,13 +168,29 @@ def connect_rollout_engines(
session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)]
remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id])
# Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard.
# TODO verify:
# - if sglang dp is enabled, then attn_tp is equal to tp // dp
# - if sglang ep is enabled, then moe-tp is equal to tp // ep
# generally tp * pp should be equal to the world_size
if target.engine_rank not in self.engines:
transfer_engine = self._create_transfer_engine()
parallel_rank_dict = self.transfer_plan.tp_conversion(target.engine_rank)
logger.info(
f"[RDMA] Creating model replica for engine rank {target.engine_rank} with rank dict {parallel_rank_dict}"
)
model_replica = self._create_inference_replica(
self.args.hf_checkpoint,
pp_shard=target.source_shard,
target_rank=target.engine_rank,
target_rank=target.engine_rank, # NOTE: here we assume that sglang_tp == world_size when pp_size == 1
target_tp=self.args.rollout_num_gpus_per_engine,
dp_rank=parallel_rank_dict["dp_rank"],
dp_size=self.transfer_plan._rollout_dp_size,
attn_tp_rank=parallel_rank_dict["attn_tp_rank"],
attn_tp_size=self.transfer_plan._rollout_attn_tp_size,
ep_rank=parallel_rank_dict["ep_rank"],
ep_size=self.transfer_plan._rollout_ep_size,
moe_tp_rank=parallel_rank_dict["moe_tp_rank"],
moe_tp_size=self.transfer_plan._rollout_moe_tp_size,
server_args=self.session_id_to_server_args[session_id],
)
print_memory(f"[RDMA] After model replica at {target.engine_rank}")
Expand Down Expand Up @@ -216,7 +233,20 @@ def _create_transfer_engine(self) -> TransferEngine:
return transfer_engine

def _create_inference_replica(
self, model_path: str, pp_shard: int, target_rank: int, target_tp: int, server_args: ServerArgs
self,
model_path: str,
pp_shard: int,
target_rank: int,
target_tp: int,
dp_rank: int,
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dp rank and size are irrelevant here

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering we already pass attn_tp_rank/attn_tp_size, it's possible that we could delete them.

One potential concern is this part:

sglang_dp_attention._ATTN_DP_RANK = 0
sglang_dp_attention._ATTN_DP_SIZE = 1

Should we change it to:

sglang_dp_attention._ATTN_DP_RANK = self.dp_rank
sglang_dp_attention._ATTN_DP_SIZE = self.dp_size

Or we could delete all of them once sglang's mocking context is enabled.

dp_size: int,
attn_tp_rank: int,
attn_tp_size: int,
ep_rank: int,
ep_size: int,
moe_tp_rank: int,
moe_tp_size: int,
server_args: ServerArgs,
):
"""
Create model replica for target rank with correct tp settings.
Expand All @@ -235,9 +265,24 @@ def _create_inference_replica(

# Mock the distributed environment to get correct weight shapes
logger.info(
f" Engine replica: {target_rank} tp {target_tp} pp_shard {pp_shard}, model pp sharding not implemented "
f" Engine replica: {target_rank} tp {target_tp} pp_shard {pp_shard}, model pp sharding not implemented, "
f" dp_rank {dp_rank}/{dp_size}, attn_tp_rank {attn_tp_rank}/{attn_tp_size}, "
f" ep_rank {ep_rank}/{ep_size}, moe_tp_rank {moe_tp_rank}/{moe_tp_size} "
)
with MockSglangDistributedContext(tp_size=target_tp, tp_rank=target_rank, server_args=server_args):
# TODO: should take attn_tp/ep/dp into account in the future.
with MockSglangDistributedContext(
tp_size=target_tp,
tp_rank=target_rank,
dp_rank=dp_rank,
dp_size=dp_size,
attn_tp_rank=attn_tp_rank,
attn_tp_size=attn_tp_size,
ep_rank=ep_rank,
ep_size=ep_size,
moe_tp_rank=moe_tp_rank,
moe_tp_size=moe_tp_size,
server_args=server_args,
):
model = get_model(
model_config=model_config,
load_config=load_config,
Expand Down Expand Up @@ -297,16 +342,35 @@ def finish_transfer_task(self) -> None:


class MockSglangDistributedContext:
def __init__(self, tp_size: int, tp_rank: int, server_args: ServerArgs):
def __init__(
self,
tp_size: int,
tp_rank: int,
dp_rank: int,
dp_size: int,
attn_tp_rank: int,
attn_tp_size: int,
ep_rank: int,
ep_size: int,
moe_tp_rank: int,
moe_tp_size: int,
server_args: ServerArgs,
):
"""
TODO: Extend this to support ep, and dp attention?
"""
self.tp_size = tp_size
self.tp_rank = tp_rank
self.pp_size = 1
self.pp_rank = 0
self.attn_tp_size = tp_size
self.attn_tp_rank = tp_rank
self.attn_tp_size = attn_tp_size
self.attn_tp_rank = attn_tp_rank
self.dp_rank = dp_rank
self.dp_size = dp_size
self.ep_rank = ep_rank
self.ep_size = ep_size
self.moe_tp_rank = moe_tp_rank
self.moe_tp_size = moe_tp_size
self.server_args = server_args
# Store active patches for cleanup
self._patches = []
Expand All @@ -320,41 +384,113 @@ def __enter__(self):
mock_group.world_size = self.tp_size
mock_group.rank_in_group = self.tp_rank

# Mock Attn TP group
mock_attn_tp_group = MagicMock()
mock_attn_tp_group.world_size = self.attn_tp_size
mock_attn_tp_group.rank_in_group = self.attn_tp_rank

# Mock PP group with proper attributes
mock_pp_group = MagicMock()
mock_pp_group.rank_in_group = self.pp_rank
mock_pp_group.world_size = self.pp_size

# Mock MoE EP group
mock_ep_group = MagicMock()
mock_ep_group.world_size = self.ep_size
mock_ep_group.rank_in_group = self.ep_rank

# Mock Moe-tp group
mock_moe_tp_group = MagicMock()
mock_moe_tp_group.world_size = self.moe_tp_size
mock_moe_tp_group.rank_in_group = self.moe_tp_rank

sglang_parallel_state._MOE_TP = mock_moe_tp_group
sglang_parallel_state._MOE_EP = mock_ep_group

# IMPORTANT: Set global variables FIRST, before any patches or model loading.
# The get_attention_tp_rank() function reads from _ATTN_TP_RANK global variable.
# Setting this BEFORE model loading ensures the correct value is used.
sglang_server_args._global_server_args = self.server_args
sglang_dp_attention._ATTN_TP_RANK = self.attn_tp_rank
sglang_dp_attention._ATTN_TP_SIZE = self.attn_tp_size
sglang_dp_attention._ATTN_DP_RANK = 0
sglang_dp_attention._ATTN_DP_SIZE = 1
sglang_dp_attention._ATTN_DP_RANK = self.dp_rank
sglang_dp_attention._ATTN_DP_SIZE = self.dp_size

# Mock parallelism getters
# IMPORTANT: We need to patch functions at BOTH locations:
# 1. Where they are defined (sglang.srt.layers.dp_attention)
# 2. Where they are imported and used (sglang.srt.models.qwen3, etc.)
# This is because Python's import creates a local reference in the importing module.

self._patches = [
patch("sglang.srt.distributed.parallel_state.get_tp_group", return_value=mock_group),
patch("sglang.srt.distributed.get_pp_group", return_value=mock_pp_group),
patch("sglang.srt.distributed.parallel_state.get_moe_expert_parallel_rank", return_value=self.ep_rank),
patch(
"sglang.srt.distributed.parallel_state.get_moe_expert_parallel_world_size", return_value=self.ep_size
),
patch("sglang.srt.distributed.parallel_state.get_moe_tensor_parallel_rank", return_value=self.moe_tp_rank),
patch(
"sglang.srt.distributed.parallel_state.get_moe_tensor_parallel_world_size",
return_value=self.moe_tp_size,
),
patch(
"sglang.srt.distributed.get_pp_group", return_value=mock_pp_group
), # TODO: redundant. Delete pp group setting in the future
patch("sglang.srt.distributed.get_moe_tp_group", return_value=mock_moe_tp_group),
patch("sglang.srt.distributed.get_tp_group", return_value=mock_group),
patch("sglang.srt.distributed.get_moe_expert_parallel_rank", return_value=self.ep_rank),
patch("sglang.srt.distributed.get_moe_expert_parallel_world_size", return_value=self.ep_size),
patch("sglang.srt.distributed.get_moe_tensor_parallel_rank", return_value=self.moe_tp_rank),
patch("sglang.srt.distributed.get_moe_tensor_parallel_world_size", return_value=self.moe_tp_size),
patch(
"sglang.srt.distributed.parallel_state.get_tensor_model_parallel_world_size", return_value=self.tp_size
),
patch("sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank", return_value=self.tp_rank),
# Patch at definition location
patch("sglang.srt.layers.dp_attention.get_attention_tp_rank", return_value=self.attn_tp_rank),
patch("sglang.srt.layers.dp_attention.get_attention_tp_size", return_value=self.attn_tp_size),
patch("sglang.srt.layers.dp_attention.get_attention_tp_group", return_value=mock_attn_tp_group),
# Patch at import locations in model files - these are critical!
patch("sglang.srt.models.qwen3.get_attention_tp_rank", return_value=self.attn_tp_rank),
patch("sglang.srt.models.qwen3.get_attention_tp_size", return_value=self.attn_tp_size),
patch("sglang.srt.models.qwen3.get_pp_group", return_value=mock_pp_group),
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many of these seem redundant. I will modify the mock context part to enable this for most models --- we shouldn't be changing the imports at the model level.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree. it's annoying...

# Patch at import locations in DeepSeek V2 model
patch("sglang.srt.models.deepseek_v2.get_attention_tp_rank", return_value=self.attn_tp_rank),
patch("sglang.srt.models.deepseek_v2.get_attention_tp_size", return_value=self.attn_tp_size),
patch("sglang.srt.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=self.tp_size),
patch("sglang.srt.models.deepseek_v2.get_pp_group", return_value=mock_pp_group),
patch("sglang.srt.models.deepseek_v2.get_moe_expert_parallel_world_size", return_value=self.ep_size),
# Patch moe layers
patch(
"sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_expert_parallel_rank", return_value=self.ep_rank
),
patch(
"sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_expert_parallel_world_size",
return_value=self.ep_size,
),
patch("sglang.srt.layers.moe.fused_moe_triton.layer.get_tp_group", return_value=mock_group),
patch(
"sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_tensor_parallel_rank",
return_value=self.moe_tp_rank,
),
patch(
"sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_tensor_parallel_world_size",
return_value=self.moe_tp_size,
),
# Patch at import locations in MoE token dispatcher
patch(
"sglang.srt.layers.moe.token_dispatcher.standard.get_moe_expert_parallel_rank",
return_value=self.ep_rank,
),
patch(
"sglang.srt.layers.moe.token_dispatcher.standard.get_moe_expert_parallel_world_size",
return_value=self.ep_size,
),
patch("sglang.srt.layers.moe.token_dispatcher.standard.get_tp_group", return_value=mock_group),
# Also patch in distributed module where get_tensor_model_parallel_rank may be imported
patch("sglang.srt.distributed.get_tensor_model_parallel_rank", return_value=self.tp_rank),
patch("sglang.srt.distributed.get_tensor_model_parallel_world_size", return_value=self.tp_size),
patch("sglang.srt.distributed.get_moe_expert_parallel_world_size", return_value=self.ep_size),
]

# Start all patches
Expand Down
Loading