diff --git a/sonicmoe/functional/__init__.py b/sonicmoe/functional/__init__.py index 4eccdc6..d0f2e73 100644 --- a/sonicmoe/functional/__init__.py +++ b/sonicmoe/functional/__init__.py @@ -15,6 +15,11 @@ from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward from .utils import enable_quack_gemm, is_using_quack_gemm +from typing import Union, Tuple +import torch.distributed as dist + +from .expert_parallel import ep_dispatch, ep_combine, DeepEPBuffer, DeepEPConfig + def TC_topk_router_metadata( topk_router_indices: torch.Tensor, expert_frequency_offset, K: int @@ -45,6 +50,35 @@ def TC_topk_router_metadata( num_activated_expert_per_token_offset, ) +def expert_parallel_TC_topk_router_metadata( + topk_router_indices: torch.Tensor, expert_frequency_offset, K: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + invalid_tokens = torch.sum(topk_router_indices == -1) + s_scatter_idx = torch.argsort(topk_router_indices.view(-1)).int() + expert_frequency_offset = torch.cat([torch.zeros(1, device=expert_frequency_offset.device, dtype=expert_frequency_offset.dtype), expert_frequency_offset]) + + num_activated_expert_per_token_offset = (topk_router_indices > -1).sum(dim=-1) + num_activated_expert_per_token_offset = torch.cat([torch.tensor([0], device=num_activated_expert_per_token_offset.device, dtype=num_activated_expert_per_token_offset.dtype), num_activated_expert_per_token_offset]) + num_activated_expert_per_token_offset = num_activated_expert_per_token_offset.cumsum(0) + + x_gather_idx = s_scatter_idx // K + + + topk_router_indices_valid = topk_router_indices[topk_router_indices >= 0] + s_scatter_idx_valid = torch.argsort(topk_router_indices_valid.view(-1)).int() + s_reverse_scatter_idx_valid = torch.empty_like(s_scatter_idx_valid) + s_reverse_scatter_idx_valid[s_scatter_idx_valid] = torch.arange( + s_scatter_idx_valid.shape[0], device=s_scatter_idx_valid.device, dtype=s_scatter_idx_valid.dtype + ) + + return ( + expert_frequency_offset, + x_gather_idx[invalid_tokens:], + s_scatter_idx_valid, + s_reverse_scatter_idx_valid, + num_activated_expert_per_token_offset, + ) def general_routing_router_metadata( router_scores_selected: torch.Tensor, sorted_selected_T: torch.Tensor, selected_E: torch.Tensor, T: int, E: int @@ -431,61 +465,109 @@ def moe_TC_softmax_topk_layer( stream_id: int, activation_type: ActivationType | str = ActivationType.SWIGLU, is_inference_mode_enabled: bool = False, + rank: int = 0, + ep_size: int = 0, + ep_group: Union[None, dist.ProcessGroup] = None, + ep_buffer: Union[None, DeepEPBuffer] = None, + ep_config: Union[None, DeepEPConfig] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert ((b1 is None) and (b2 is None)) or ( (b1 is not None) and (b2 is not None) ), "b1 and b2 has to be None or not None at the same time!" - router_logits = F.linear(x, router_w) - topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K) - expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), router_w.size(0), do_cumsum=True) + num_experts = router_w.size(0) + loacl_num_experts = router_w.size(0) // ep_size + router_logits = F.linear(x.to(router_w.dtype), router_w) + + topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, num_experts, K) - ( - expert_frequency_offset, - x_gather_idx, - s_scatter_idx, - s_reverse_scatter_idx, - num_activated_expert_per_token_offset, - ) = TC_topk_router_metadata(topk_indices, expert_frequency_offset, K) + if ep_size > 1: + topk_indices = topk_indices.long() - T = x.size(0) + ( + recv_x, + recv_topk_indices, + recv_topk_weights, + _, + ep_handle + ) = ep_dispatch( + x, + ep_buffer, + ep_config, + topk_indices, + topk_scores, + num_experts, + ) + + x = recv_x + topk_indices = recv_topk_indices.int() + topk_scores = recv_topk_weights + topk_scores = topk_scores[topk_scores > 0] + + expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), loacl_num_experts, do_cumsum=True) + + if ep_size > 1: + expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, num_activated_expert_per_token_offset = expert_parallel_TC_topk_router_metadata( + topk_indices, expert_frequency_offset, K + ) + else: + expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, num_activated_expert_per_token_offset = TC_topk_router_metadata( + topk_indices, expert_frequency_offset, K + ) if type(activation_type) == str: activation_type = ActivationType(activation_type) - y1, z = _UpProjection.apply( - x, - w1, - b1, - expert_frequency_offset, - T * K, - K, - stream_id, - x_gather_idx, - s_scatter_idx, - s_reverse_scatter_idx, - num_activated_expert_per_token_offset, - False, # is_varlen_K - activation_type, - is_inference_mode_enabled, - ) + T = x.size(0) - o = _DownProjection.apply( - y1, - z, - w2, - b2, - topk_scores, - expert_frequency_offset, - T, - K, - stream_id, - x_gather_idx, - s_scatter_idx, - s_reverse_scatter_idx, - num_activated_expert_per_token_offset, - False, # is_varlen_K - activation_type, - ) + if ep_size > 1: + TK = s_scatter_idx.shape[0] + is_varlen_K = True + else: + TK = T * K + is_varlen_K = False + + if T > 0: + y1, z = _UpProjection.apply( + x, + w1, + b1, + expert_frequency_offset, + TK, + K, + stream_id, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + is_varlen_K, + activation_type, + is_inference_mode_enabled, + ) + + + o = _DownProjection.apply( + y1, + z, + w2, + b2, + topk_scores, + expert_frequency_offset, + T, + K, + stream_id, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + is_varlen_K, + activation_type, + ) + + else: + o = x + + if ep_size > 1: + o, _ = ep_combine(o, ep_buffer, ep_config, ep_handle) return o, router_logits, expert_frequency diff --git a/sonicmoe/functional/expert_parallel.py b/sonicmoe/functional/expert_parallel.py new file mode 100644 index 0000000..12cece8 --- /dev/null +++ b/sonicmoe/functional/expert_parallel.py @@ -0,0 +1,214 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Portions of this code are from DeepSeek DeepEP project +# Copyright (c) 2025 DeepSeek +# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE + + +try: + import deep_ep + + DEEPEP_IS_INSTALLED = True +except ImportError: + DEEPEP_IS_INSTALLED = False + +import torch + + +class DeepEPDispatch(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + buffer, + config, + token_indices, + token_probs, + num_experts, + async_finish=False, + allocate_on_comm_stream=False, + ): + + previous_event = None + if async_finish: + previous_event = EventOverlap(EventHandle()) + + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + event, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + ( + recv_x, + recv_token_indices, + recv_token_probs, + num_recv_tokens_per_expert_list, + handle, + after_event_overlap, + ) = buffer.dispatch( + x, + config=config, + topk_idx=token_indices, + topk_weights=token_probs, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + num_tokens_per_rank=num_tokens_per_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + if async_finish: + after_event_overlap.current_stream_wait() + + ctx.buffer = buffer + ctx.config = config + ctx.handle = handle + ctx.async_finish = async_finish + ctx.allocate_on_comm_stream = allocate_on_comm_stream + tokens_per_expert = torch.tensor(num_recv_tokens_per_expert_list) + + return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle) + + @staticmethod + def backward( + ctx, + grad_output, + grad_token_indices, + grad_token_probs, + grad_tokens_per_expert, + grad_handle, + ): + + buffer = ctx.buffer + config = ctx.config + handle = ctx.handle + previous_event = None + if ctx.async_finish: + previous_event = EventOverlap(EventHandle()) + grad_x, grad_token_probs, after_event = buffer.combine( + grad_output.contiguous(), + handle, + config=config, + topk_weights=grad_token_probs.float(), + previous_event=previous_event, + async_finish=ctx.async_finish, + allocate_on_comm_stream=ctx.allocate_on_comm_stream, + ) + + if ctx.async_finish: + after_event.current_stream_wait() + + return grad_x, None, None, None, grad_token_probs, None, None, None + + +class DeepEPCombine(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + buffer, + config, + handle, + async_finish=False, + allocate_on_comm_stream=False, + ): + previous_event = None + if async_finish: + previous_event = EventOverlap(EventHandle()) + + combined_x, _, after_event = buffer.combine( + x, + config=config, + handle=handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + if async_finish: + after_event.current_stream_wait() + + ctx.buffer = buffer + ctx.config = config + ctx.handle = handle + ctx.async_finish = async_finish + ctx.allocate_on_comm_stream = allocate_on_comm_stream + return combined_x, None + + @staticmethod + def backward(ctx, grad_output, previous_event=None): + previous_event = None + if ctx.async_finish: + previous_event = EventOverlap(EventHandle()) + buffer = ctx.buffer + config = ctx.config + grad_x, _, _, _, _, after_event = buffer.dispatch( + grad_output.contiguous(), + config=config, + handle=ctx.handle, + previous_event=previous_event, + async_finish=ctx.async_finish, + allocate_on_comm_stream=ctx.allocate_on_comm_stream, + ) + + if ctx.async_finish: + after_event.current_stream_wait() + return grad_x, None, None, None, None, None + + +if DEEPEP_IS_INSTALLED: + + def ep_dispatch( + x, + buffer, + config, + token_indices, + token_probs, + num_experts, + async_finish=False, + allocate_on_comm_stream=False, + ): + return DeepEPDispatch.apply( + x.contiguous(), + buffer, + config, + token_indices, + token_probs, + num_experts, + async_finish, + allocate_on_comm_stream, + ) + + + def ep_combine( + x, buffer, config, handle, async_finish=False, allocate_on_comm_stream=False + ): + return DeepEPCombine.apply( + x, buffer, config, handle, async_finish, allocate_on_comm_stream + ) + + DeepEPBuffer = deep_ep.Buffer + DeepEPConfig = deep_ep.Config + +else: + + def DeepEPImportErr(): + raise ImportError("Deepep is required by for expert parallel!") + + ep_dispatch = DeepEPImportErr() + ep_combine = DeepEPImportErr() + + DeepEPBuffer = None + DeepEPConfig = None diff --git a/sonicmoe/moe.py b/sonicmoe/moe.py index 08b7686..680c3c2 100644 --- a/sonicmoe/moe.py +++ b/sonicmoe/moe.py @@ -2,11 +2,12 @@ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao # ******************************************************************************** -from typing import Callable +from typing import Callable, Union import torch import torch.nn as nn import torch.nn.functional as F +import torch.distributed as dist from .count_cumsum import count_cumsum from .enums import ActivationType, KernelBackendMoE, is_glu @@ -174,10 +175,20 @@ def __init__( activation_function: ActivationType, add_bias: bool, std: float, + rank: int = 0, + ep_size: int = 0, + ep_group: Union[None, dist.ProcessGroup] = None, + ep_buffer = None, + ep_config = None, ) -> None: super().__init__() - self.num_experts = num_experts + self.rank = rank + self.enable_expert_parallel = ep_size > 1 + self.ep_size = ep_size if ep_size > 0 else 1 + self.ep_group = ep_group + + self.num_experts = num_experts // self.ep_size self.top_k = num_experts_per_tok self.hidden_size = hidden_size @@ -187,8 +198,12 @@ def __init__( self.activation_function = activation_function + if self.ep_size > 1: + with torch.no_grad(): + dist.broadcast(self.router.weight.data, src=0) + self.c_fc = Experts( - num_experts=num_experts, + num_experts=self.num_experts, in_features=self.hidden_size, out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, add_bias=add_bias, @@ -196,7 +211,7 @@ def __init__( ) self.c_proj = Experts( - num_experts=num_experts, + num_experts=self.num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, add_bias=add_bias, @@ -205,21 +220,31 @@ def __init__( self.stream_id = torch.cuda.current_stream().cuda_stream + assert(not (self.enable_expert_parallel and ep_buffer == None)) + assert(not (self.enable_expert_parallel and ep_config == None)) + assert(not (self.enable_expert_parallel and ep_group == None)) + + self.deep_ep_buffer = ep_buffer + self.deep_ep_config = ep_config + + + def forward( self, hidden_states: torch.Tensor, kernel_backend_moe: KernelBackendMoE = KernelBackendMoE.sonicmoe, is_inference_mode: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: + assert (self.ep_size <= 1 or kernel_backend_moe == KernelBackendMoE.sonicmoe), "Expert parallel only surport sonicmoe KernelBackendMoE.backend." original_shape = hidden_states.shape # hidden_states -> (batch_size, query_length, hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size) - + if kernel_backend_moe == KernelBackendMoE.sonicmoe: hidden_states, router_logits, expert_frequency = moe_TC_softmax_topk_layer( hidden_states, - self.router.weight, + self.router.weight.to(torch.float32), self.c_fc.weight.permute(1, 2, 0), self.c_fc.bias, self.c_proj.weight.permute(1, 2, 0), @@ -228,6 +253,11 @@ def forward( self.stream_id, self.activation_function, is_inference_mode or not self.training, + self.rank, + self.ep_size, + self.ep_group, + self.deep_ep_buffer, + self.deep_ep_config ) else: # hidden_states -> (total_q, hidden_size) @@ -248,6 +278,15 @@ def forward( # hidden_states -> (batch_size, query_length, hidden_size) + if self.ep_size > 1: + gathered_router_logits = [torch.zeros_like(router_logits) for _ in range(self.ep_size)] + dist.all_gather(gathered_router_logits, router_logits, group=self.ep_group) + router_logits = torch.cat(gathered_router_logits, dim=0) + + gathered_expert_frequency = [torch.zeros_like(expert_frequency) for _ in range(self.ep_size)] + dist.all_gather(gathered_expert_frequency, expert_frequency, group=self.ep_group) + expert_frequency = torch.cat(gathered_expert_frequency, dim=0) + aux_loss = self._compute_switch_loss( logits=router_logits, probs=F.softmax(router_logits, dim=-1, dtype=torch.float32), @@ -369,3 +408,9 @@ def _get_topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x, indices = x.topk(self.top_k, dim=-1) return x, indices + + def sync_router_grad(self): + if self.ep_size > 1 and self.ep_group is not None: + router_grad = self.router.weight.grad.data.to(torch.float32) + dist.all_reduce(router_grad, op=dist.ReduceOp.SUM, group=self.ep_group) + self.router.weight.grad.copy_(router_grad.to(self.router.weight.grad.dtype)) diff --git a/tests/gen_ep8_res.py b/tests/gen_ep8_res.py new file mode 100644 index 0000000..735db81 --- /dev/null +++ b/tests/gen_ep8_res.py @@ -0,0 +1,114 @@ + +import deep_ep +import torch +import os +import torch.distributed as dist + +from sonicmoe import MoE, KernelBackendMoE +from sonicmoe.enums import ActivationType + +def test_loop(local_rank: int, num_local_ranks: int): + + torch.cuda.manual_seed(0) + num_nodes = int(os.getenv('MLP_WORKER_NUM', 1)) + + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 + + dist.init_process_group( + backend='nccl', + init_method=f'tcp://{ip}:{port}', + world_size=num_nodes * num_local_ranks, + rank=node_rank * num_local_ranks + local_rank + ) + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device('cuda') + torch.cuda.set_device(local_rank) + + rank, num_ranks, group = dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + hidden_size = 512 + intermediate_size = 384 + + + test_ll_compatibility, num_rdma_bytes = False, 0 + ep_buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=False, + num_qps_per_rank=(1)) + + nvl_buffer_size = 256 + num_sms = 24 + ep_config = deep_ep.Config(num_sms, 8, nvl_buffer_size) + + moe = MoE( + num_experts=32, + num_experts_per_tok=8, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation_function=ActivationType.SWIGLU, + add_bias=False, + std=0.02, + rank=rank, + ep_size=num_ranks, + ep_group=group, + ep_buffer=ep_buffer, + ep_config=ep_config + ).to(device=rank, dtype=torch.bfloat16) + + moe.router = moe.router.to(torch.float32) + + x = torch.randn(16, hidden_size, device=rank, dtype=torch.bfloat16) + output, aux_loss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe) + if rank == 0: + print("forward output : ", output) + + eg = torch.ones_like(output) + loss = (output * eg).sum() + loss.backward() + moe.sync_router_grad() + + def torch_allgather(x, size, group, dim=0): + gathered_x = [torch.zeros_like(x) for _ in range(size)] + dist.all_gather(gathered_x, x, group=group) + x = torch.cat(gathered_x, dim=dim) + return x + + router_weight = moe.router.weight + c_fc_weight = moe.c_fc.weight + c_fc_grad = moe.c_fc.weight.grad + c_proj_weight = moe.c_proj.weight + + # router_weight = torch_allgather(router_weight, num_ranks, group) + c_fc_weight = torch_allgather(c_fc_weight, num_ranks, group) + c_proj_weight = torch_allgather(c_proj_weight, num_ranks, group) + c_fc_grad = torch_allgather(c_fc_grad, num_ranks, group) + + x = torch_allgather(x, num_ranks, group) + output = torch_allgather(output, num_ranks, group) + + if rank == 0: + # print("gatherd router weight : ", router_weight.shape) + # print("gatherd c_fc weight : ", c_fc_weight.shape) + # print("gatherd c_proj weight : ", c_proj_weight.shape) + # print("gatherd input : ", x.shape) + # print("gatherd output : ", output.shape) + # print("aux loss : ", aux_loss) + + saved_tensor = {} + saved_tensor["router_w"] = router_weight + saved_tensor["c_fc_w"] = c_fc_weight + saved_tensor["c_fc_g"] = c_fc_grad + saved_tensor["c_proj_w"] = c_proj_weight + saved_tensor["input"] = x + saved_tensor["output"] = output + saved_tensor["aux_loss"] = aux_loss + + saved_tensor["router_grad"] = moe.router.weight.grad + + torch.save(saved_tensor, "moe_ep8_tensor_res.pth") + +if __name__ == '__main__': + num_processes = 8 + torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) diff --git a/tests/test_ep8_intranode.py b/tests/test_ep8_intranode.py new file mode 100644 index 0000000..af00718 --- /dev/null +++ b/tests/test_ep8_intranode.py @@ -0,0 +1,65 @@ +import torch +from sonicmoe import MoE, KernelBackendMoE +from sonicmoe.enums import ActivationType + +import subprocess + +if __name__ == "__main__": + result = subprocess.run(['python', 'gen_ep8_res.py'], stdout=subprocess.PIPE, text=True) + + rank = torch.device("cuda:0") + hidden_size = 512 + intermediate_size = 384 + moe = MoE( + num_experts=32, + num_experts_per_tok=8, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation_function=ActivationType.SWIGLU, + add_bias=False, + std=0.02, + ).to(device=rank, dtype=torch.bfloat16) + + moe.router = moe.router.to(torch.float32) + + saved_tensor = torch.load("moe_ep8_tensor_res.pth") + moe.router.weight.data.copy_(saved_tensor["router_w"]) + moe.c_fc.weight.data.copy_(saved_tensor["c_fc_w"]) + moe.c_proj.weight.data.copy_(saved_tensor["c_proj_w"]) + + x = saved_tensor["input"] + ep1_output, ep1_auxloss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe) + ep8_output = saved_tensor["output"] + # print("ep1 output : ", ep1_output.shape) + # print("ep8 output : ", ep8_output.shape) + + ep8_auxloss = saved_tensor["aux_loss"] + + ep8_router_grad = saved_tensor["router_grad"] + + torch.testing.assert_close(ep1_output, ep8_output, rtol=1.4e-2, atol=2e-2) + print("auxloss : ", ep1_auxloss, ep8_auxloss) + torch.testing.assert_close(ep1_auxloss, ep8_auxloss, rtol=1.4e-2, atol=2e-2) + + eg = torch.ones_like(ep1_output) + loss = (ep1_output * eg).sum() + loss.backward() + ep1_router_grad = moe.router.weight.grad + print("ep1 grad : ", ep1_router_grad) + print("ep8 grad : ", ep8_router_grad) + + # print("ep1 grad : ", ep1_router_grad.shape, ep1_router_grad.sum()) + # print("ep1 weight : ", moe.router.weight.shape, moe.router.weight.sum()) + + torch.testing.assert_close(ep1_router_grad, ep8_router_grad, rtol=1.4e-1, atol=2e-1) + + ep1_c_fc_grad = moe.c_fc.weight.grad + ep8_c_fc_grad = saved_tensor["c_fc_g"] + # print("ep1 grad : ", ep1_c_fc_grad) + # print("ep8 grad : ", ep8_c_fc_grad) + + + torch.testing.assert_close(ep1_c_fc_grad, ep8_c_fc_grad, rtol=1.4e-2, atol=2e-2) + + + \ No newline at end of file