Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 126 additions & 44 deletions sonicmoe/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
from .utils import enable_quack_gemm, is_using_quack_gemm

from typing import Union, Tuple
import torch.distributed as dist

from .expert_parallel import ep_dispatch, ep_combine, DeepEPBuffer, DeepEPConfig


def TC_topk_router_metadata(
topk_router_indices: torch.Tensor, expert_frequency_offset, K: int
Expand Down Expand Up @@ -45,6 +50,35 @@ def TC_topk_router_metadata(
num_activated_expert_per_token_offset,
)

def expert_parallel_TC_topk_router_metadata(
    topk_router_indices: torch.Tensor, expert_frequency_offset, K: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build routing metadata from top-k indices that may contain unused slots.

    Entries of ``topk_router_indices`` that are ``>= 0`` are treated as expert
    ids; negative entries (``-1``) are treated as unused slots and dropped.
    All returned index tensors refer to the *compacted* list of valid entries,
    taken in row-major order.

    Args:
        topk_router_indices: 2-D integer tensor with one row per token.
            NOTE(review): presumably shaped (T, K) as received from the
            expert-parallel dispatch -- confirm against the caller.
        expert_frequency_offset: 1-D tensor of per-expert offsets; it is
            returned with a single leading zero prepended.
        K: divisor used to turn a flattened slot position into a row (token)
            index.

    Returns:
        A 5-tuple ``(expert_frequency_offset, x_gather_idx, s_scatter_idx,
        s_reverse_scatter_idx, num_activated_expert_per_token_offset)``:

        * ``expert_frequency_offset`` -- the input with a leading zero.
        * ``x_gather_idx`` -- int32 row index of each valid entry, ordered by
          expert id.
        * ``s_scatter_idx`` -- int32 permutation that sorts the compacted valid
          entries by expert id.
        * ``s_reverse_scatter_idx`` -- int32 inverse of ``s_scatter_idx``.
        * ``num_activated_expert_per_token_offset`` -- exclusive prefix sum
          (length ``rows + 1``) of the number of valid entries per row.
    """
    flat_indices = topk_router_indices.reshape(-1)

    # Flattened position of every valid entry, in row-major order. Using one
    # validity test (>= 0) everywhere keeps all outputs mutually consistent.
    valid_positions = torch.nonzero(flat_indices >= 0).squeeze(1)

    # A single *stable* sort drives both the scatter permutation and the gather
    # index. Two independent unstable sorts (one over all entries, one over the
    # valid entries only) may order equal expert ids differently, which would
    # make x_gather_idx and s_scatter_idx describe different permutations.
    order = torch.argsort(flat_indices[valid_positions], stable=True)
    s_scatter_idx_valid = order.int()
    x_gather_idx = (valid_positions[order] // K).int()

    # Inverse permutation: position of each compacted entry in the sorted order.
    s_reverse_scatter_idx_valid = torch.empty_like(s_scatter_idx_valid)
    s_reverse_scatter_idx_valid[order] = torch.arange(
        order.shape[0], device=order.device, dtype=s_scatter_idx_valid.dtype
    )

    # Exclusive prefix sum of the per-row count of valid entries.
    valid_per_token = (topk_router_indices >= 0).sum(dim=-1)
    num_activated_expert_per_token_offset = torch.cat(
        [valid_per_token.new_zeros(1), valid_per_token]
    ).cumsum(0)

    expert_frequency_offset = torch.cat(
        [expert_frequency_offset.new_zeros(1), expert_frequency_offset]
    )

    return (
        expert_frequency_offset,
        x_gather_idx,
        s_scatter_idx_valid,
        s_reverse_scatter_idx_valid,
        num_activated_expert_per_token_offset,
    )

def general_routing_router_metadata(
router_scores_selected: torch.Tensor, sorted_selected_T: torch.Tensor, selected_E: torch.Tensor, T: int, E: int
Expand Down Expand Up @@ -431,61 +465,109 @@ def moe_TC_softmax_topk_layer(
stream_id: int,
activation_type: ActivationType | str = ActivationType.SWIGLU,
is_inference_mode_enabled: bool = False,
rank: int = 0,
ep_size: int = 0,
ep_group: Union[None, dist.ProcessGroup] = None,
ep_buffer: Union[None, DeepEPBuffer] = None,
ep_config: Union[None, DeepEPConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert ((b1 is None) and (b2 is None)) or (
(b1 is not None) and (b2 is not None)
), "b1 and b2 has to be None or not None at the same time!"
router_logits = F.linear(x, router_w)
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K)
expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), router_w.size(0), do_cumsum=True)
num_experts = router_w.size(0)
loacl_num_experts = router_w.size(0) // ep_size
router_logits = F.linear(x.to(router_w.dtype), router_w)

topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, num_experts, K)

(
expert_frequency_offset,
x_gather_idx,
s_scatter_idx,
s_reverse_scatter_idx,
num_activated_expert_per_token_offset,
) = TC_topk_router_metadata(topk_indices, expert_frequency_offset, K)
if ep_size > 1:
topk_indices = topk_indices.long()

T = x.size(0)
(
recv_x,
recv_topk_indices,
recv_topk_weights,
_,
ep_handle
) = ep_dispatch(
x,
ep_buffer,
ep_config,
topk_indices,
topk_scores,
num_experts,
)

x = recv_x
topk_indices = recv_topk_indices.int()
topk_scores = recv_topk_weights
topk_scores = topk_scores[topk_scores > 0]

expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), loacl_num_experts, do_cumsum=True)

if ep_size > 1:
expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, num_activated_expert_per_token_offset = expert_parallel_TC_topk_router_metadata(
topk_indices, expert_frequency_offset, K
)
else:
expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx, num_activated_expert_per_token_offset = TC_topk_router_metadata(
topk_indices, expert_frequency_offset, K
)

if type(activation_type) == str:
activation_type = ActivationType(activation_type)

y1, z = _UpProjection.apply(
x,
w1,
b1,
expert_frequency_offset,
T * K,
K,
stream_id,
x_gather_idx,
s_scatter_idx,
s_reverse_scatter_idx,
num_activated_expert_per_token_offset,
False, # is_varlen_K
activation_type,
is_inference_mode_enabled,
)
T = x.size(0)

o = _DownProjection.apply(
y1,
z,
w2,
b2,
topk_scores,
expert_frequency_offset,
T,
K,
stream_id,
x_gather_idx,
s_scatter_idx,
s_reverse_scatter_idx,
num_activated_expert_per_token_offset,
False, # is_varlen_K
activation_type,
)
if ep_size > 1:
TK = s_scatter_idx.shape[0]
is_varlen_K = True
else:
TK = T * K
is_varlen_K = False

if T > 0:
y1, z = _UpProjection.apply(
x,
w1,
b1,
expert_frequency_offset,
TK,
K,
stream_id,
x_gather_idx,
s_scatter_idx,
s_reverse_scatter_idx,
num_activated_expert_per_token_offset,
is_varlen_K,
activation_type,
is_inference_mode_enabled,
)


o = _DownProjection.apply(
y1,
z,
w2,
b2,
topk_scores,
expert_frequency_offset,
T,
K,
stream_id,
x_gather_idx,
s_scatter_idx,
s_reverse_scatter_idx,
num_activated_expert_per_token_offset,
is_varlen_K,
activation_type,
)

else:
o = x

if ep_size > 1:
o, _ = ep_combine(o, ep_buffer, ep_config, ep_handle)

return o, router_logits, expert_frequency

Expand Down
Loading