From d18f562c85a932dacdaabf5eee2892bde46b4507 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 11 Feb 2026 08:02:32 +0000 Subject: [PATCH 1/2] add deepseek mhc pre --- dlblas/kernels/ascend/deepseek_mhc/mhc_pre.py | 419 ++++++++++++++++++ 1 file changed, 419 insertions(+) create mode 100644 dlblas/kernels/ascend/deepseek_mhc/mhc_pre.py diff --git a/dlblas/kernels/ascend/deepseek_mhc/mhc_pre.py b/dlblas/kernels/ascend/deepseek_mhc/mhc_pre.py new file mode 100644 index 00000000..9dbd6683 --- /dev/null +++ b/dlblas/kernels/ascend/deepseek_mhc/mhc_pre.py @@ -0,0 +1,419 @@ +import math + +import tilelang +import tilelang.language as T +import torch + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual, + post_mix, + comb_mix, + layer_input, + hidden_size: int, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 16, + hc_mult: int = 4, +): + """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.""" + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + hidden_block = math.gcd(512, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] + hc_scale: T.Tensor[[3], T.float32] + hc_base: T.Tensor[[hc_mult3], T.float32] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] + # outputs + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] + + with T.Kernel(num_tokens, threads=96) as i: + ################################################################## + # _pre_norm_fn_fwd_norm + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + ################################################################## + # _pre_split_mixes_fwd (post & comb) + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = T.sigmoid(mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult]) * hc_post_mult_value + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + hc_base[j * hc_mult + k + hc_mult * 2] + + ################################################################## + # _sinkhorn_fwd + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + # comb = comb.softmax(-1) + eps + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + # comb = comb / (comb.sum(-1) + eps) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + # save comb_mix to global memory + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + ################################################################## + # _pre_split_mixes_fwd (pre) + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + ################################################################### + # _pre_apply_mix_fwd + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + + +@tilelang.jit +def mhc_pre_gemm_sqrsum_tilelang( + x, + fn, + out, + sqrsum, + hc_mult3: int, + hc_hidden_size: int, + token_block: int = 32, + hidden_block: int = 256, +) -> tilelang.JITKernel: + """Not highly optimized TileLang implementation of fused gemm and sqrsum in mHC pre block.""" + assert hc_mult3 <= 32 # should be 24 usually + num_tokens = T.dynamic("num_tokens") + assert hc_hidden_size % hidden_block == 0 + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) + fn: T.Tensor((hc_mult3, hc_hidden_size), T.float32) + out: T.Tensor((num_tokens, hc_mult3), T.float32) + sqrsum: T.Tensor((num_tokens), T.float32) + + with T.Kernel(T.ceildiv(num_tokens, token_block)) as px: + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sqrsum_part = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sqrsum_part) + for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2): + x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout({x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)}) + + T.copy(x[px * token_block, pz * hidden_block], x_smem_16) + T.copy(fn[0, pz * hidden_block], fn_smem) + + x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem_16, x_frag_16) + x_frag = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_frag_16, x_frag) + + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j] + + # should be TF32 gemm + T.gemm( + x_frag, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + wg_wait=0, + clear_accum=False, + ) + sqrsum_l = T.alloc_fragment(token_block, T.float32) + T.reduce_sum(sqrsum_part, sqrsum_l) + for i in T.Parallel(token_block): + sqrsum[px * token_block + i] = sqrsum_l[i] + for i, j in T.Parallel(token_block, 32): + if j < hc_mult3: + out[px * token_block + i, j] = out_frag[i, j] + + +def mhc_pre( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for mHC pre block. + + Args: + residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 + fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 + hc_scale: shape (3,), dtype torch.float32 + hc_base: shape (hc_mult3,), dtype torch.float32 + rms_eps: RMS normalization epsilon + hc_pre_eps: pre-mix epsilon + hc_sinkhorn_eps: sinkhorn epsilon + hc_post_mult_value: post-mix multiplier value + sinkhorn_repeat: number of sinkhorn iterations + n_splits: split-k factor; TileLang version of mhc_pre_gemm_sqrsum doesn't support this + + Returns: + post_mix: shape (..., hc_mult), dtype torch.float32 + comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 + layer_input: shape (..., hidden_size), dtype torch.bfloat16 + """ + + # Validate shapes + assert residual.dtype == torch.bfloat16 + assert fn.dtype == torch.float32 + assert hc_scale.dtype == torch.float32 + assert hc_base.dtype == torch.float32 + + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + hc_hidden_size = hc_mult * hidden_size + assert fn.shape[0] == hc_mult3 + assert fn.shape[1] == hc_hidden_size + assert hc_scale.shape == (3,) + assert hc_base.shape == (hc_mult3,) + + outer_shape = residual.shape[:-2] + + residual_flat = residual.view(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + fn_flat = fn + + post_mix = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=residual.device) + comb_mix = torch.empty(num_tokens, hc_mult2, dtype=torch.float32, device=residual.device) + layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device) + + gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device) + gemm_out_sqrsum = torch.empty(n_splits, num_tokens, dtype=torch.float32, device=residual.device) + assert n_splits == 1, "The simple TileLang version gemm_sqrsum doesn't support split-k" + mhc_pre_gemm_sqrsum_tilelang( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + hc_mult3, + hc_mult * hidden_size, + ) + + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + hc_mult, + ) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + + return post_mix, comb_mix, layer_input + + +def sinkhorn_normalize_ref(x: torch.Tensor, repeat: int, eps: float) -> torch.Tensor: + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x + + +def mhc_pre_ref( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hc_mult = residual.shape[-2] + + residual_flat = residual.flatten(-2, -1).float() + sqrsum = residual_flat.square().sum(-1) + mixes = residual_flat @ fn.T * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt() + + hc_scale = torch.cat( + [ + hc_scale[0].expand(hc_mult), + hc_scale[1].expand(hc_mult), + hc_scale[2].expand(hc_mult * hc_mult), + ], + ) + mixes = mixes * hc_scale + hc_base + + pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps + post_mix = (mixes[:, hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value).unsqueeze(-1) + res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult) + + res_mix = sinkhorn_normalize_ref(res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps) + + layer_input = (residual * pre_mix).sum(-2).bfloat16() + + return post_mix, res_mix, layer_input + + +def generate_test_data( + n: int, + hc_mult: int, + hidden_size: int, + rms_eps: float = 1e-6, + hc_pre_eps: float = 1e-6, + hc_sinkhorn_eps: float = 1e-6, + hc_post_mult_value: float = 1.0, + sinkhorn_repeat: int = 10, +) -> dict[str, torch.Tensor | float]: + """Generate test data for big fuse operator.""" + torch.random.manual_seed(42) + + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + device = "cuda" + + residual = ( + torch.randn((n, hc_mult, hidden_size), dtype=torch.float, device=device) + .mul(1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + .bfloat16() + ) + + fn = ( + torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float, device=device) + * 1e-4 + * (1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + ).flatten(1, 2) + + hc_scale = torch.randn((3,), dtype=torch.float, device=device) * 0.1 + + hc_base = torch.randn((hc_mult3,), dtype=torch.float, device=device) * 0.1 + + return { + "residual": residual, + "fn": fn, + "hc_scale": hc_scale, + "hc_base": hc_base, + "rms_eps": rms_eps, + "hc_pre_eps": hc_pre_eps, + "hc_sinkhorn_eps": hc_sinkhorn_eps, + "hc_post_mult_value": hc_post_mult_value, + "sinkhorn_repeat": sinkhorn_repeat, + } + + +def test(n: int, hidden_size: int, hc_mult: int) -> None: + print(f"Testing mhc_pre with {n=} {hidden_size=} {hc_mult=}") + test_data = generate_test_data( + n=n, + hc_mult=hc_mult, + hidden_size=hidden_size, + ) + + # Forward pass with big fuse + post_mix_fused, comb_mix_fused, layer_input_fused = mhc_pre(**test_data) + + # Forward pass with reference + post_mix_ref, comb_mix_ref, layer_input_ref = mhc_pre_ref(**test_data) + + # Compare outputs + torch.testing.assert_close(post_mix_fused, post_mix_ref) + torch.testing.assert_close(comb_mix_fused, comb_mix_ref) + torch.testing.assert_close(layer_input_fused, layer_input_ref) + + +def main(): + for n1 in [512, 1024, 2048, 8192]: + for hidden_size in [1280, 2560, 4096]: + for hc_mult in [4]: + test(n=n1, hidden_size=hidden_size, hc_mult=hc_mult) + + +if __name__ == "__main__": + main() From af4f31582bb4d95db90cd04876f41b44917333fe Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Wed, 11 Feb 2026 08:23:10 +0000 Subject: [PATCH 2/2] Add mhc_pre triton version. --- .../ascend/deepseek_mhc/mhc_pre_triton.py | 226 ++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 dlblas/kernels/ascend/deepseek_mhc/mhc_pre_triton.py diff --git a/dlblas/kernels/ascend/deepseek_mhc/mhc_pre_triton.py b/dlblas/kernels/ascend/deepseek_mhc/mhc_pre_triton.py new file mode 100644 index 00000000..f24d4066 --- /dev/null +++ b/dlblas/kernels/ascend/deepseek_mhc/mhc_pre_triton.py @@ -0,0 +1,226 @@ +import triton +import triton.language as tl +import torch +import torch_npu + + +@triton.jit +def mhc_pre_gemm_sqrsum_triton_kernel( + # Pointers to matrices + x_ptr, fn_ptr, out_ptr, sqrsum_ptr, + # Matrix dimensions + num_tokens, hidden_size, hc_mult3, + # Stride variables + stride_x_token, stride_x_hidden, + stride_fn_mult3, stride_fn_hidden, + stride_out_token, + # Block sizes + token_block: tl.constexpr = 32, + hidden_block: tl.constexpr = 64, + # Meta-parameter for BLOCK_SIZE_N + BLOCK_SIZE_K: tl.constexpr = 64, + BLOCK_SIZE_N: tl.constexpr = 32, +): + """ + Triton kernel for fused GEMM and sqrsum computation. + + Args: + x_ptr: Input tensor of shape (num_tokens, hc_hidden_size) + fn_ptr: Weight tensor of shape (hc_mult3, hc_hidden_size) + out_ptr: Output tensor of shape (num_tokens, hc_mult3) + sqrsum_ptr: Square sum output of shape (num_tokens,) + num_tokens: Number of input tokens + hc_hidden_size: Hidden dimension size + hc_mult3: Size of the third dimension + """ + # Program IDs + token_pid = tl.program_id(axis=0) # Token block index + + # Token range for this program + token_offset = token_pid * token_block + tl.arange(0, token_block) + + # Initialize accumulators for square sums + sqrsum_acc = tl.zeros((1,), dtype=tl.float32) # token_block + thirty = tl.arange(0, 32) + + for hidden_idx in range(0, hidden_size, hidden_block): + hidden_offset = hidden_idx + tl.arange(0, hidden_block) + x_offset = token_offset[:, None] + hidden_offset[None, :] + + x = tl.load( + x_ptr + token_offset[:, None] + hidden_offset[None, :] + ).to(tl.bfloat16) + x = x.to(tl.float32) + sqrsum_prt = x * x + sqrsum_acc += tl.sum(sqrsum_prt) + + fn_offset = tl.arange(0, 32) + fn = tl.load( + fn_ptr + fn_offset[:, None] * stride_fn_mult3 + hidden_offset[None, :] * stride_fn_hidden + # mask=fn_mask[:, None] & (hidden_offsets[None, :] < hc_hidden_size), + # other=0.0 + ).to(tl.float32) + + # !TODO transpose fn + fn = tl.trans(fn, 1, 0) + out_val = tl.dot(x, fn) + + out_token_offset = token_offset[:, None] + out_mult3_offset = thirty[None, :] + tl.store( + out_ptr + out_token_offset * stride_out_token + out_mult3_offset * 1, + out_val, + ) + + # Write square sum results + sqrsum_val = tl.full((32, 32), 1, dtype=tl.float32) * sqrsum_acc + thirty = tl.arange(0, 32) + sqrsum_offset = thirty[:, None] + thirty[None, :] + tl.store( + sqrsum_ptr + sqrsum_offset, + sqrsum_val + ) + + +def mhc_pre_gemm_sqrsum_triton( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + hc_mult3: int, + hc_hidden_size: int, + token_block: int = 32, + hidden_block: int = 64, +): + """ + Triton implementation of fused gemm and sqrsum in mHC pre block. + + Args: + x: Input tensor of shape (num_tokens, hc_hidden_size), dtype torch.bfloat16 + fn: Weight tensor of shape (hc_mult3, hc_hidden_size), dtype torch.float32 + out: Output tensor of shape (num_tokens, hc_mult3), dtype torch.float32 + sqrsum: Square sum tensor of shape (num_tokens,), dtype torch.float32 + hc_mult3: Size of the third dimension + hc_hidden_size: Hidden size dimension + token_block: Size of token block for tiling + hidden_block: Size of hidden block for tiling + """ + # Get tensor dimensions + num_tokens = x.shape[0] + + # Define strides + stride_x_token = x.stride(0) + stride_x_hidden = x.stride(1) + stride_fn_mult3 = fn.stride(0) + stride_fn_hidden = fn.stride(1) + stride_out_token = out.stride(0) + + # Launch grid + grid = lambda META: (1,) + + # Set appropriate block sizes based on actual dimensions + BLOCK_SIZE_N = min(32, hc_mult3) + BLOCK_SIZE_K = hidden_block + + # Launch kernel + import time + start = time.time() + mhc_pre_gemm_sqrsum_triton_kernel[grid]( + x, fn, out, sqrsum, + num_tokens, hc_hidden_size, hc_mult3, + stride_x_token, stride_x_hidden, + stride_fn_mult3, stride_fn_hidden, + stride_out_token, + token_block=token_block, + hidden_block=hidden_block, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + print(time.time() - start) + + return out, sqrsum + + +def generate_test_data( + n: int, + hc_mult: int, + hidden_size: int, + rms_eps: float = 1e-6, + hc_pre_eps: float = 1e-6, + hc_sinkhorn_eps: float = 1e-6, + hc_post_mult_value: float = 1.0, + sinkhorn_repeat: int = 10, +): + """Generate test data for big fuse operator.""" + torch.random.manual_seed(0) + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + device = "npu" + + residual = ( + torch.ones((n, hc_mult, hidden_size), dtype=torch.float, device=device) + .mul(1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + .bfloat16() + ) + + fn = ( + torch.ones((hc_mult3, hc_mult, hidden_size), dtype=torch.float, device=device) + * 1e-4 + * (1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + ).flatten(1, 2) + + hc_scale = torch.ones((3,), dtype=torch.float, device=device) * 0.1 + + hc_base = torch.ones((hc_mult3,), dtype=torch.float, device=device) * 0.1 + + return { + "residual": residual, + "fn": fn, + "hc_scale": hc_scale, + "hc_base": hc_base, + "rms_eps": rms_eps, + "hc_pre_eps": hc_pre_eps, + "hc_sinkhorn_eps": hc_sinkhorn_eps, + "hc_post_mult_value": hc_post_mult_value, + "sinkhorn_repeat": sinkhorn_repeat, + } + + +def main(): + """ + Main function to test the fused moe kernel. + """ + # Parameters + dhidden = 1280 # 7168 + dexpert = 2048 + n_routed_experts = 8 + n_shared_experts = 1 + n_experts_per_token = 4 + batch_size = 1 + seq_len = 512 # 8192 + + # Generate test data + test_data = generate_test_data( + n=batch_size * seq_len, + hc_mult=n_experts_per_token, + hidden_size=dhidden, + rms_eps=1e-6, + hc_pre_eps=1e-6, + hc_sinkhorn_eps=1e-6, + hc_post_mult_value=1.0, + sinkhorn_repeat=10, + ) + + out, sqrsum = mhc_pre_gemm_sqrsum_triton( + test_data["residual"], + test_data["fn"], + torch.empty_like(test_data["residual"]), + torch.empty_like(test_data["residual"][:, 0]), + n_experts_per_token, + dhidden, + token_block=32, + hidden_block=64, + ) + +if __name__ == "__main__": + main()