-
Notifications
You must be signed in to change notification settings - Fork 1
feat: support multi-node TP/EP/DP/PP from training side, and TP/EP/DP from rollout side, with rdma, for models of deepseek arch #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from collections.abc import Callable, Mapping, Sequence | ||
|
|
||
| import ray | ||
| import sglang.srt.distributed.parallel_state as sglang_parallel_state | ||
| import sglang.srt.layers.dp_attention as sglang_dp_attention | ||
| import sglang.srt.server_args as sglang_server_args | ||
| import torch | ||
|
|
@@ -167,13 +168,29 @@ def connect_rollout_engines( | |
| session_id = targets_to_session_id[(target.engine_ind, target.engine_rank)] | ||
| remote_info = RemoteWeightInfo(session_id, self.remote_weight_infos_by_session_id[session_id]) | ||
| # Instantiate the local model replicas and a corresponding transfer engine with memory registry for each type of rollout shard. | ||
| # TODO verify: | ||
| # - if sglang dp is enabled, then attn_tp is equal to tp // dp | ||
| # - if sglang ep is enabled, then moe-tp is equal to tp // ep | ||
| # generally tp * pp should be equal to the world_size | ||
| if target.engine_rank not in self.engines: | ||
| transfer_engine = self._create_transfer_engine() | ||
| parallel_rank_dict = self.transfer_plan.tp_conversion(target.engine_rank) | ||
| logger.info( | ||
| f"[RDMA] Creating model replica for engine rank {target.engine_rank} with rank dict {parallel_rank_dict}" | ||
| ) | ||
| model_replica = self._create_inference_replica( | ||
| self.args.hf_checkpoint, | ||
| pp_shard=target.source_shard, | ||
| target_rank=target.engine_rank, | ||
| target_rank=target.engine_rank, # NOTE: here we assume that sglang_tp == world_size when pp_size == 1 | ||
| target_tp=self.args.rollout_num_gpus_per_engine, | ||
| dp_rank=parallel_rank_dict["dp_rank"], | ||
| dp_size=self.transfer_plan._rollout_dp_size, | ||
| attn_tp_rank=parallel_rank_dict["attn_tp_rank"], | ||
| attn_tp_size=self.transfer_plan._rollout_attn_tp_size, | ||
| ep_rank=parallel_rank_dict["ep_rank"], | ||
| ep_size=self.transfer_plan._rollout_ep_size, | ||
| moe_tp_rank=parallel_rank_dict["moe_tp_rank"], | ||
| moe_tp_size=self.transfer_plan._rollout_moe_tp_size, | ||
| server_args=self.session_id_to_server_args[session_id], | ||
| ) | ||
| print_memory(f"[RDMA] After model replica at {target.engine_rank}") | ||
|
|
@@ -216,7 +233,20 @@ def _create_transfer_engine(self) -> TransferEngine: | |
| return transfer_engine | ||
|
|
||
| def _create_inference_replica( | ||
| self, model_path: str, pp_shard: int, target_rank: int, target_tp: int, server_args: ServerArgs | ||
| self, | ||
| model_path: str, | ||
| pp_shard: int, | ||
| target_rank: int, | ||
| target_tp: int, | ||
| dp_rank: int, | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The dp rank and dp size are irrelevant here.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Considering we pass [inline code lost in extraction] […]. One potential point of concern is this part: [code snippet lost in extraction]. Do we need to change it into [code snippet lost in extraction]? Or we could delete all of them once sglang's mocking context is enabled. |
||
| dp_size: int, | ||
| attn_tp_rank: int, | ||
| attn_tp_size: int, | ||
| ep_rank: int, | ||
| ep_size: int, | ||
| moe_tp_rank: int, | ||
| moe_tp_size: int, | ||
| server_args: ServerArgs, | ||
| ): | ||
| """ | ||
| Create model replica for target rank with correct tp settings. | ||
|
|
@@ -235,9 +265,24 @@ def _create_inference_replica( | |
|
|
||
| # Mock the distributed environment to get correct weight shapes | ||
| logger.info( | ||
| f" Engine replica: {target_rank} tp {target_tp} pp_shard {pp_shard}, model pp sharding not implemented " | ||
| f" Engine replica: {target_rank} tp {target_tp} pp_shard {pp_shard}, model pp sharding not implemented, " | ||
| f" dp_rank {dp_rank}/{dp_size}, attn_tp_rank {attn_tp_rank}/{attn_tp_size}, " | ||
| f" ep_rank {ep_rank}/{ep_size}, moe_tp_rank {moe_tp_rank}/{moe_tp_size} " | ||
| ) | ||
| with MockSglangDistributedContext(tp_size=target_tp, tp_rank=target_rank, server_args=server_args): | ||
| # TODO: should take attn_tp/ep/dp into account in the future. | ||
| with MockSglangDistributedContext( | ||
| tp_size=target_tp, | ||
| tp_rank=target_rank, | ||
| dp_rank=dp_rank, | ||
| dp_size=dp_size, | ||
| attn_tp_rank=attn_tp_rank, | ||
| attn_tp_size=attn_tp_size, | ||
| ep_rank=ep_rank, | ||
| ep_size=ep_size, | ||
| moe_tp_rank=moe_tp_rank, | ||
| moe_tp_size=moe_tp_size, | ||
| server_args=server_args, | ||
| ): | ||
| model = get_model( | ||
| model_config=model_config, | ||
| load_config=load_config, | ||
|
|
@@ -297,16 +342,35 @@ def finish_transfer_task(self) -> None: | |
|
|
||
|
|
||
| class MockSglangDistributedContext: | ||
| def __init__(self, tp_size: int, tp_rank: int, server_args: ServerArgs): | ||
| def __init__( | ||
| self, | ||
| tp_size: int, | ||
| tp_rank: int, | ||
| dp_rank: int, | ||
| dp_size: int, | ||
| attn_tp_rank: int, | ||
| attn_tp_size: int, | ||
| ep_rank: int, | ||
| ep_size: int, | ||
| moe_tp_rank: int, | ||
| moe_tp_size: int, | ||
| server_args: ServerArgs, | ||
| ): | ||
| """ | ||
| TODO: Extend this to support ep, and dp attention? | ||
| """ | ||
| self.tp_size = tp_size | ||
| self.tp_rank = tp_rank | ||
| self.pp_size = 1 | ||
| self.pp_rank = 0 | ||
| self.attn_tp_size = tp_size | ||
| self.attn_tp_rank = tp_rank | ||
| self.attn_tp_size = attn_tp_size | ||
| self.attn_tp_rank = attn_tp_rank | ||
| self.dp_rank = dp_rank | ||
| self.dp_size = dp_size | ||
| self.ep_rank = ep_rank | ||
| self.ep_size = ep_size | ||
| self.moe_tp_rank = moe_tp_rank | ||
| self.moe_tp_size = moe_tp_size | ||
| self.server_args = server_args | ||
| # Store active patches for cleanup | ||
| self._patches = [] | ||
|
|
@@ -320,41 +384,113 @@ def __enter__(self): | |
| mock_group.world_size = self.tp_size | ||
| mock_group.rank_in_group = self.tp_rank | ||
|
|
||
| # Mock Attn TP group | ||
| mock_attn_tp_group = MagicMock() | ||
| mock_attn_tp_group.world_size = self.attn_tp_size | ||
| mock_attn_tp_group.rank_in_group = self.attn_tp_rank | ||
|
|
||
| # Mock PP group with proper attributes | ||
| mock_pp_group = MagicMock() | ||
| mock_pp_group.rank_in_group = self.pp_rank | ||
| mock_pp_group.world_size = self.pp_size | ||
|
|
||
| # Mock MoE EP group | ||
| mock_ep_group = MagicMock() | ||
| mock_ep_group.world_size = self.ep_size | ||
| mock_ep_group.rank_in_group = self.ep_rank | ||
|
|
||
| # Mock Moe-tp group | ||
| mock_moe_tp_group = MagicMock() | ||
| mock_moe_tp_group.world_size = self.moe_tp_size | ||
| mock_moe_tp_group.rank_in_group = self.moe_tp_rank | ||
|
|
||
| sglang_parallel_state._MOE_TP = mock_moe_tp_group | ||
| sglang_parallel_state._MOE_EP = mock_ep_group | ||
|
|
||
| # IMPORTANT: Set global variables FIRST, before any patches or model loading. | ||
| # The get_attention_tp_rank() function reads from _ATTN_TP_RANK global variable. | ||
| # Setting this BEFORE model loading ensures the correct value is used. | ||
| sglang_server_args._global_server_args = self.server_args | ||
| sglang_dp_attention._ATTN_TP_RANK = self.attn_tp_rank | ||
| sglang_dp_attention._ATTN_TP_SIZE = self.attn_tp_size | ||
| sglang_dp_attention._ATTN_DP_RANK = 0 | ||
| sglang_dp_attention._ATTN_DP_SIZE = 1 | ||
| sglang_dp_attention._ATTN_DP_RANK = self.dp_rank | ||
| sglang_dp_attention._ATTN_DP_SIZE = self.dp_size | ||
|
|
||
| # Mock parallelism getters | ||
| # IMPORTANT: We need to patch functions at BOTH locations: | ||
| # 1. Where they are defined (sglang.srt.layers.dp_attention) | ||
| # 2. Where they are imported and used (sglang.srt.models.qwen3, etc.) | ||
| # This is because Python's import creates a local reference in the importing module. | ||
|
|
||
| self._patches = [ | ||
| patch("sglang.srt.distributed.parallel_state.get_tp_group", return_value=mock_group), | ||
| patch("sglang.srt.distributed.get_pp_group", return_value=mock_pp_group), | ||
| patch("sglang.srt.distributed.parallel_state.get_moe_expert_parallel_rank", return_value=self.ep_rank), | ||
| patch( | ||
| "sglang.srt.distributed.parallel_state.get_moe_expert_parallel_world_size", return_value=self.ep_size | ||
| ), | ||
| patch("sglang.srt.distributed.parallel_state.get_moe_tensor_parallel_rank", return_value=self.moe_tp_rank), | ||
| patch( | ||
| "sglang.srt.distributed.parallel_state.get_moe_tensor_parallel_world_size", | ||
| return_value=self.moe_tp_size, | ||
| ), | ||
| patch( | ||
| "sglang.srt.distributed.get_pp_group", return_value=mock_pp_group | ||
| ), # TODO: redundant. Delete pp group setting in the future | ||
| patch("sglang.srt.distributed.get_moe_tp_group", return_value=mock_moe_tp_group), | ||
| patch("sglang.srt.distributed.get_tp_group", return_value=mock_group), | ||
| patch("sglang.srt.distributed.get_moe_expert_parallel_rank", return_value=self.ep_rank), | ||
| patch("sglang.srt.distributed.get_moe_expert_parallel_world_size", return_value=self.ep_size), | ||
| patch("sglang.srt.distributed.get_moe_tensor_parallel_rank", return_value=self.moe_tp_rank), | ||
| patch("sglang.srt.distributed.get_moe_tensor_parallel_world_size", return_value=self.moe_tp_size), | ||
| patch( | ||
| "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_world_size", return_value=self.tp_size | ||
| ), | ||
| patch("sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank", return_value=self.tp_rank), | ||
| # Patch at definition location | ||
| patch("sglang.srt.layers.dp_attention.get_attention_tp_rank", return_value=self.attn_tp_rank), | ||
| patch("sglang.srt.layers.dp_attention.get_attention_tp_size", return_value=self.attn_tp_size), | ||
| patch("sglang.srt.layers.dp_attention.get_attention_tp_group", return_value=mock_attn_tp_group), | ||
| # Patch at import locations in model files - these are critical! | ||
| patch("sglang.srt.models.qwen3.get_attention_tp_rank", return_value=self.attn_tp_rank), | ||
| patch("sglang.srt.models.qwen3.get_attention_tp_size", return_value=self.attn_tp_size), | ||
| patch("sglang.srt.models.qwen3.get_pp_group", return_value=mock_pp_group), | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Many of these patches seem redundant. I will modify the mock context part to enable this for most models — we shouldn't be changing the imports at the model level.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree. it's annoying... |
||
| # Patch at import locations in DeepSeek V2 model | ||
| patch("sglang.srt.models.deepseek_v2.get_attention_tp_rank", return_value=self.attn_tp_rank), | ||
| patch("sglang.srt.models.deepseek_v2.get_attention_tp_size", return_value=self.attn_tp_size), | ||
| patch("sglang.srt.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=self.tp_size), | ||
| patch("sglang.srt.models.deepseek_v2.get_pp_group", return_value=mock_pp_group), | ||
| patch("sglang.srt.models.deepseek_v2.get_moe_expert_parallel_world_size", return_value=self.ep_size), | ||
| # Patch moe layers | ||
| patch( | ||
| "sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_expert_parallel_rank", return_value=self.ep_rank | ||
| ), | ||
| patch( | ||
| "sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_expert_parallel_world_size", | ||
| return_value=self.ep_size, | ||
| ), | ||
| patch("sglang.srt.layers.moe.fused_moe_triton.layer.get_tp_group", return_value=mock_group), | ||
| patch( | ||
| "sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_tensor_parallel_rank", | ||
| return_value=self.moe_tp_rank, | ||
| ), | ||
| patch( | ||
| "sglang.srt.layers.moe.fused_moe_triton.layer.get_moe_tensor_parallel_world_size", | ||
| return_value=self.moe_tp_size, | ||
| ), | ||
| # Patch at import locations in MoE token dispatcher | ||
| patch( | ||
| "sglang.srt.layers.moe.token_dispatcher.standard.get_moe_expert_parallel_rank", | ||
| return_value=self.ep_rank, | ||
| ), | ||
| patch( | ||
| "sglang.srt.layers.moe.token_dispatcher.standard.get_moe_expert_parallel_world_size", | ||
| return_value=self.ep_size, | ||
| ), | ||
| patch("sglang.srt.layers.moe.token_dispatcher.standard.get_tp_group", return_value=mock_group), | ||
| # Also patch in distributed module where get_tensor_model_parallel_rank may be imported | ||
| patch("sglang.srt.distributed.get_tensor_model_parallel_rank", return_value=self.tp_rank), | ||
| patch("sglang.srt.distributed.get_tensor_model_parallel_world_size", return_value=self.tp_size), | ||
| patch("sglang.srt.distributed.get_moe_expert_parallel_world_size", return_value=self.ep_size), | ||
| ] | ||
|
|
||
| # Start all patches | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.