diff --git a/tests/ap/__main__.py b/tests/ap/__main__.py index 9301b01..e9448b2 100644 --- a/tests/ap/__main__.py +++ b/tests/ap/__main__.py @@ -2,3 +2,4 @@ # import test_binary_trivial_reduce import test_matmul_binary import test_matmul_epilogue +import test_zip_binary diff --git a/tests/ap/make_axpr.sh b/tests/ap/make_axpr.sh index 77588e5..ee9d1aa 100644 --- a/tests/ap/make_axpr.sh +++ b/tests/ap/make_axpr.sh @@ -22,6 +22,8 @@ FILENAMES_ARRAY=( "matmul_epilogue_pass" "test_matmul_binary" "test_matmul_epilogue" + "zip_variadic_tpl" + "test_zip_binary" ) for filename in "${FILENAMES_ARRAY[@]}" do diff --git a/tests/ap/paddle-tests/test_zip.py b/tests/ap/paddle-tests/test_zip.py new file mode 100644 index 0000000..b69a1a6 --- /dev/null +++ b/tests/ap/paddle-tests/test_zip.py @@ -0,0 +1,163 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from os.path import dirname + +sys.path.append(dirname(__file__)) + +import unittest +import utils + +import paddle +from paddle.static import InputSpec + + +def moe_zip( + unzipped_tokens, + zipped_expertwise_rowmap, + expert_routemap_topk, + unzipped_token_probs, +): + zipped_tokens, zipped_prob_topk = paddle._C_ops._moe_zip( + unzipped_tokens, + zipped_expertwise_rowmap, + expert_routemap_topk, + unzipped_token_probs, + ) + return zipped_tokens, zipped_prob_topk + + +class CINNSubGraphNet(paddle.nn.Layer): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x1, x2, x3, x4): + zipped_tokens, zipped_prob_topk = self.fn(x1, x2, x3, x4) + return zipped_tokens, zipped_prob_topk + + +class TestAPZip(unittest.TestCase): + """ + Test Pir API + @to_static + CINN. + """ + + def setUp(self): + paddle.seed(2022) + self.prepare_data() + + def prepare_data(self): + u_seqlen = 4 + token_len = 8 + seqlen = 3 + num_experts = 4 + topk = 8 + unzipped_tokens_data = [ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0], + [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ] + self.unzipped_tokens_shape = [u_seqlen, token_len] + self.unzipped_tokens_dtype = "bfloat16" + self.unzipped_tokens = paddle.to_tensor( + unzipped_tokens_data, dtype=self.unzipped_tokens_dtype + ) + self.unzipped_tokens.stop_gradient = False + + zipped_expertwise_rowmap_data = [ + [0, 3, -1, -1], + [-1, 1, -1, -1], + [2, -1, -1, -1], + ] + self.zipped_expertwise_rowmap_shape = [seqlen, num_experts] + self.zipped_expertwise_rowmap_dtype = "int32" + self.zipped_expertwise_rowmap = paddle.to_tensor( + zipped_expertwise_rowmap_data, self.zipped_expertwise_rowmap_dtype + ) + self.zipped_expertwise_rowmap.stop_gradient = False + + routemap_topk_data = [ + [-1, -1, 0, 1, -1, -1, -1, -1], + [1, -1, -1, -1, -1, -1, -1, -1], + [-1, 0, -1, -1, -1, -1, -1, -1], + ] + self.expert_routemap_topk_shape = [seqlen, topk] + self.expert_routemap_topk_dtype = "int32" + self.expert_routemap_topk = paddle.to_tensor( + routemap_topk_data, dtype=self.expert_routemap_topk_dtype + ) + self.expert_routemap_topk.stop_gradient = False + + unzipped_token_probs_data = [[0.50000000], [1.0], [1.0], [0.50000000]] + self.unzipped_token_probs_shape = [u_seqlen, 1] + self.unzipped_token_probs_dtype = "float32" + self.unzipped_token_probs = paddle.to_tensor( + unzipped_token_probs_data, self.unzipped_token_probs_dtype + ) + self.unzipped_token_probs.stop_gradient = False + self.zipped_tokens_type = "bfloat16" + self.zipped_prob_topk_type = "float32" + + def eval_symbolic(self, net, use_cinn, profile): + input_spec = [ + InputSpec( + shape=self.unzipped_tokens_shape, dtype=self.unzipped_tokens_dtype + ), + InputSpec( + shape=self.zipped_expertwise_rowmap_shape, + dtype=self.zipped_expertwise_rowmap_dtype, + ), + InputSpec( + shape=self.expert_routemap_topk_shape, + dtype=self.expert_routemap_topk_dtype, + ), + InputSpec( + shape=self.unzipped_token_probs_shape, + dtype=self.unzipped_token_probs_dtype, + ), + ] + net = utils.apply_to_static(net, use_cinn, input_spec) + net.eval() + zipped_tokens, zipped_prob_topk = utils.run_with_profile( + profile, + net, + self.unzipped_tokens, + self.zipped_expertwise_rowmap, + self.expert_routemap_topk, + self.unzipped_token_probs, + ) + return zipped_tokens, zipped_prob_topk + + def test_pure_zip(self): + profile = False + net = CINNSubGraphNet(moe_zip) + cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile) + dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) + if not profile: + utils.check_result( + self.zipped_tokens_type, cinn_out[0].numpy(), dy2st_out[0].numpy(), True + ) + + utils.check_result( + self.zipped_prob_topk_type, + cinn_out[1].numpy(), + dy2st_out[1].numpy(), + True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/ap/test_zip.sh b/tests/ap/test_zip.sh new file mode 100644 index 0000000..1c0f230 --- /dev/null +++ b/tests/ap/test_zip.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES="2" +export NVIDIA_TF32_OVERRIDE=0 + +sh make_axpr.sh + +# AP specific settings +export FLAGS_enable_ap=1 +export AP_WORKSPACE_DIR=$(pwd)/ap_workspace +export AP_PATH=$(pwd)/ + +# CINN related settings +export FLAGS_check_infer_symbolic=1 +export FLAGS_enable_pir_api=1 +export FLAGS_cinn_bucket_compile=True +export FLAGS_prim_enable_dynamic=true +export FLAGS_prim_all=True +export FLAGS_pir_apply_shape_optimization_pass=1 +export FLAGS_group_schedule_tiling_first=1 +export FLAGS_cinn_new_group_scheduler=1 + +export GLOG_vmodule=ap_generic_drr_pass=6 + +python $(pwd)/paddle-tests/test_zip.py diff --git a/tests/ap/test_zip_binary.py b/tests/ap/test_zip_binary.py new file mode 100644 index 0000000..f6b0c06 --- /dev/null +++ b/tests/ap/test_zip_binary.py @@ -0,0 +1,79 @@ +import pir +import abstract_drr + +import zip_variadic_tpl +import kernel_arg_id_util +import program_translator_util +import op_compute_translator_util + + +@abstract_drr.register_drr_pass("pure_zip_fuse", nice=0) +class PureZipFuse(abstract_drr.DrrPass): + + def source_pattern(self, o, t): + print("in source pattern") + o.moe_zip_op = o.ap_native_op("pd_op._moe_zip") + o.moe_zip_op( + [ + t.unzipped_tokens, + t.zipped_expertwise_rowmap, + t.expert_routemap_topk, + t.unzipped_token_probs, + ], + [t.zipped_tokens, t.zipped_probs_topk], + ) + + def constraint(self, o, t): + return True + + def result_pattern(self, o, t): + o.fustion_op = o.ap_pattern_fusion_op(self.code_gen) + o.fustion_op( + [ + t.unzipped_tokens, + t.zipped_expertwise_rowmap, + t.expert_routemap_topk, + t.unzipped_token_probs, + ], + [t.zipped_tokens, t.zipped_probs_topk], + ) + + def code_gen(self, ctx, o, t): + mut_kernel_arg_id_registry = kernel_arg_id_util.KernelArgIdNameRegistry( + code_gen_ctx=ctx, tensor_match_ctx=t, name_prefix="" + ) + template_module = zip_variadic_tpl.ZipVariadicTemplate( + mut_kernel_arg_id_registry=mut_kernel_arg_id_registry + ) + return template_module.compile( + unzipped_tokens_in_karg=ctx.in_tensor_data_ptr_kernel_arg_id( + t.unzipped_tokens + ), + zipped_expertwise_rowmap_in_karg=ctx.in_tensor_data_ptr_kernel_arg_id( + t.zipped_expertwise_rowmap + ), + expert_routemap_topk_in_karg=ctx.in_tensor_data_ptr_kernel_arg_id( + t.expert_routemap_topk + ), + unzipped_token_probs_in_karg=ctx.in_tensor_data_ptr_kernel_arg_id( + t.unzipped_token_probs + ), + zipped_tokens_out_karg=ctx.out_tensor_data_ptr_kernel_arg_id( + t.zipped_tokens + ), + zipped_probs_topk_out_karg=ctx.out_tensor_data_ptr_kernel_arg_id( + t.zipped_probs_topk + ), + topk_kargs=ctx.dim_expr_kernel_arg_id( + t.expert_routemap_topk.symbolic_shape_to_list()[1] + ), + num_experts_kargs=ctx.dim_expr_kernel_arg_id( + t.zipped_expertwise_rowmap.symbolic_shape_to_list()[1] + ), + token_length_kargs=ctx.dim_expr_kernel_arg_id( + t.unzipped_tokens.symbolic_shape_to_list()[1] + ), + total_zipped_tokens_num_kargs=ctx.dim_expr_kernel_arg_id( + t.zipped_tokens.symbolic_shape_to_list()[0] + ), + ) diff --git a/tests/ap/zip/naive_zip.cuh b/tests/ap/zip/naive_zip.cuh new file mode 100644 index 0000000..db86fdf --- /dev/null +++ b/tests/ap/zip/naive_zip.cuh @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include +#include + +namespace ap { + +template +__global__ void zip_naive_kernel( + const half *__restrict__ unzipped_tokens_in, + const int *__restrict__ zipped_expertwise_rowmap, + const int *__restrict__ expert_routemap_topk, + const float *__restrict__ unzipped_token_probs, + half *__restrict__ zipped_tokens_out , + float *__restrict__ zipped_probs_topk, + const int token_length, + const int total_zipped_tokens_num) { + const int this_row = blockIdx.x; + if (this_row >= total_zipped_tokens_num) return; + + const __nv_bfloat16 *unzipped_tokens = + reinterpret_cast(unzipped_tokens_in); + __nv_bfloat16 *zipped_tokens = + reinterpret_cast<__nv_bfloat16 *>(zipped_tokens_out); + + int local_row_fetchlist[num_experts]; + +// -------------------------初始化任务表 ------------------------ +#pragma unroll + for (int expert = 0; expert < num_experts; ++expert) { + const int fetch_row = + zipped_expertwise_rowmap[this_row * num_experts + expert]; + local_row_fetchlist[expert] = fetch_row; + } + +#pragma unroll + for (int k = 0; k < topk; ++k) { + const int expert_idx = expert_routemap_topk[this_row * topk + k]; + if (expert_idx < 0) [[likely]] + continue; + const int expert_fetch_row = local_row_fetchlist[expert_idx]; + zipped_probs_topk[this_row * topk + k] = + unzipped_token_probs[expert_fetch_row]; + } + constexpr int vecSize = 2; // __nv_bfloat162 = 2 x bfloat16 + const int num_full_vec = token_length / vecSize; + const int remaining_elems = token_length % vecSize; + const int thread_stride = blockDim.x * vecSize; + + if constexpr (MP) { + // ------------------------ 手动混合精度 --------------------------------- + // 齐整区域向量化搬移 + for (int x_offset = threadIdx.x * vecSize; + x_offset < num_full_vec * vecSize; + x_offset += thread_stride) { + float2 sum = {0.0f, 0.0f}; + __nv_bfloat162 *out_ptr = reinterpret_cast<__nv_bfloat162 *>( + &zipped_tokens[this_row * token_length + x_offset]); +#pragma unroll + for (int expert = 0; expert < num_experts; ++expert) { + const int fetch_row = local_row_fetchlist[expert]; + if (fetch_row < 0) continue; + // 手动类型提升 + float2 token_vec = + __bfloat1622float2(*reinterpret_cast( + &unzipped_tokens[fetch_row * token_length + x_offset])); + sum.x = __fadd_rn(token_vec.x, sum.x); + sum.y = __fadd_rn(token_vec.y, sum.y); + } + // 类型下降为原有精度 + *out_ptr = __float22bfloat162_rn(sum); + } + + // 剩余元素处理 + for (int i = num_full_vec * vecSize + threadIdx.x; i < token_length; + i += blockDim.x) { + float sum = 0.0f; +#pragma unroll + for (int expert = 0; expert < num_experts; ++expert) { + int fetch_row = local_row_fetchlist[expert]; + if (fetch_row < 0) continue; + float token_val = + __bfloat162float(unzipped_tokens[fetch_row * token_length + i]); + sum = __fadd_rn(token_val, sum); + } + zipped_tokens[this_row * token_length + i] = __float2bfloat16_rn(sum); + } + } else { + // ------------------------ BF16 intrinsics 加权累加 ----------------------- + // 齐整区域向量化搬移 + for (int x_offset = threadIdx.x * vecSize; + x_offset < num_full_vec * vecSize; + x_offset += thread_stride) { + __nv_bfloat162 sum = {0, 0}; + __nv_bfloat162 *out_ptr = reinterpret_cast<__nv_bfloat162 *>( + &zipped_tokens[this_row * token_length + x_offset]); +#pragma unroll + for (int expert = 0; expert < num_experts; ++expert) { + const int fetch_row = local_row_fetchlist[expert]; + if (fetch_row < 0) continue; + __nv_bfloat162 token_vec = *reinterpret_cast( + &unzipped_tokens[fetch_row * token_length + x_offset]); + sum = __hadd2(sum, token_vec); + } + *out_ptr = sum; + } + + // 剩余元素处理 + for (int i = num_full_vec * vecSize + threadIdx.x; i < token_length; + i += blockDim.x) { + __nv_bfloat16 sum = (__nv_bfloat16)0; +#pragma unroll + for (int expert = 0; expert < num_experts; ++expert) { + int fetch_row = local_row_fetchlist[expert]; + if (fetch_row < 0) continue; + __nv_bfloat16 token_val = unzipped_tokens[fetch_row * token_length + i]; + sum = __hadd(sum, token_val); + } + zipped_tokens[this_row * token_length + i] = sum; + } + } +} + +} // namespace ap \ No newline at end of file diff --git a/tests/ap/zip_variadic_tpl.py b/tests/ap/zip_variadic_tpl.py new file mode 100644 index 0000000..d55ad60 --- /dev/null +++ b/tests/ap/zip_variadic_tpl.py @@ -0,0 +1,183 @@ +import low_level_ir_code_gen_ctx_util +import kernel_arg_translator_util + + +def make_kernel_arg_translator(): + return kernel_arg_translator_util.KernelArgTranslator(param_struct_name="args") + + +class ZipVariadicTemplate: + + def __init__(self, mut_kernel_arg_id_registry): + self.mut_kernel_arg_id_registry = mut_kernel_arg_id_registry + self.kernel_arg_translator = make_kernel_arg_translator() + self.dtype2type_name = OrderedDict( + [ + [PointerType.const_float_ptr, "const float*"], + [PointerType.const_float16_ptr, "const half*"], + [PointerType.const_bfloat16_ptr, "const half*"], + [PointerType.const_int32_t_ptr, "const int32_t*"], + [PointerType.float_ptr, "float*"], + [PointerType.float16_ptr, "half*"], + [PointerType.bfloat16_ptr, "half*"], + [PointerType.int32_t_ptr, "int32_t*"], + [DataType.float, "float"], + [DataType.float16, "half"], + [DataType.int64_t, "int64_t"], + ] + ) + self.kernel_name = "MoeZipVariadicKernel" + self.library_name = "moe_zip_variadic_kernel" + + def _register_name(self, pair): + registry = self.mut_kernel_arg_id_registry + registry.get_or_create_kernel_arg_id_manul_var_name( + kernel_arg_id=pair[0], cpp_var_name=pair[1] + ) + + def compile( + self, + unzipped_tokens_in_karg, + zipped_expertwise_rowmap_in_karg, + expert_routemap_topk_in_karg, + unzipped_token_probs_in_karg, + zipped_tokens_out_karg, + zipped_probs_topk_out_karg, + topk_kargs, + num_experts_kargs, + token_length_kargs, + total_zipped_tokens_num_kargs, + ): + kargs_name_pair_list = [ + [unzipped_tokens_in_karg, "unzipped_tokens_in"], + [zipped_expertwise_rowmap_in_karg, "zipped_expertwise_rowmap"], + [expert_routemap_topk_in_karg, "expert_routemap_topk"], + [unzipped_token_probs_in_karg, "unzipped_token_probs"], + [zipped_tokens_out_karg, "zipped_tokens_out"], + [zipped_probs_topk_out_karg, "zipped_probs_topk_out"], + ] + map(self._register_name, kargs_name_pair_list) + project_module = self.make_project() + return CodeGenResult( + module=project_module, + kernel_dispatch_func=KernelDispatch, + kernel_dispatch_const_data=BuiltinSerializableAttrMap( + kernel_args_getters=[ + *self.get_kernel_arg_runtime_getters(), + topk_kargs.runtime_getter, + num_experts_kargs.runtime_getter, + token_length_kargs.runtime_getter, + total_zipped_tokens_num_kargs.runtime_getter, + ] + ), + ) + + def get_kernel_arg_types(self): + all_kernel_arg_id_and_unique_names = ( + self.mut_kernel_arg_id_registry.all_kernel_arg_id2unique_name.items() + ) + return map(lambda pair: pair[0].type, all_kernel_arg_id_and_unique_names) + + def get_kernel_arg_runtime_getters(self): + all_kernel_arg_id_and_unique_names = ( + self.mut_kernel_arg_id_registry.all_kernel_arg_id2unique_name.items() + ) + return map( + lambda pair: pair[0].runtime_getter, all_kernel_arg_id_and_unique_names + ) + + def get_kernel_arg_list_str(self, for_declare): + + def declare_epilogue_arguments_field(pair): + kernel_arg_id = pair[0] + var_name = pair[1] + field_name = self.kernel_arg_translator.get_param_struct_field_name( + var_name + ) + dtype = kernel_arg_id.type + type_name = self.dtype2type_name[dtype] + return f"{type_name} {field_name}" if for_declare else f"{field_name}" + + all_kernel_arg_id_and_names = ( + self.mut_kernel_arg_id_registry.all_kernel_arg_id2unique_name.items() + ) + return ", ".join( + map(declare_epilogue_arguments_field, all_kernel_arg_id_and_names) + ) + + def make_project(self): + code_template = """ +// auto generated codes +#include "naive_zip.cuh" +namespace ap{ + +static void RunZipWithVariadicKernel(cudaStream_t cuda_stream_ptr, ${AP_KERNEL_ARGS_DECLARE}, const int64_t topk, const int64_t num_experts, const int64_t token_length, const int64_t total_zipped_tokens_num) { + dim3 grid, block; + grid.x = total_zipped_tokens_num; + block.x = 256; + if(topk == 8 && num_experts == 4) { + zip_naive_kernel<8, 4><<>>(${AP_KERNEL_ARGS_CALL}, token_length, total_zipped_tokens_num); + } + return; +} + +} // namespace ap + +extern "C" { +void ${kernel_name}(void* stream_ptr, ${AP_KERNEL_ARGS_DECLARE}, const int64_t topk, const int64_t num_experts, const int64_t token_length, const int64_t total_zipped_tokens_num) { + cudaStream_t* cuda_stream_ptr = reinterpret_cast(stream_ptr); + ap::RunZipWithVariadicKernel(*cuda_stream_ptr, ${AP_KERNEL_ARGS_CALL}, topk, num_experts, token_length, total_zipped_tokens_num); +} +} + """ + code = ( + code_template.replace( + "${AP_KERNEL_ARGS_DECLARE}", + self.get_kernel_arg_list_str(for_declare=True), + ) + .replace( + "${AP_KERNEL_ARGS_CALL}", + self.get_kernel_arg_list_str(for_declare=False), + ) + .replace("${kernel_name}", self.kernel_name) + ) + source_dir = "/project/AP/Athena/tests/ap/zip" + compile_cmd = ( + "nvcc -std=c++17 -O3 -Xcompiler=-fPIC -arch=sm_80 --expt-relaxed-constexpr" + ) + compile_cmd = compile_cmd + " -I " + source_dir + compile_cmd = compile_cmd + " -DAP_ENABLE_AUTOTUNE=1 -DAP_ENABLE_DEBUG=0" + compile_cmd = ( + compile_cmd + + f" --shared {self.library_name}.cu -o lib{self.library_name}.so" + ) + return CodeModule( + FuncDeclare( + DataType.void, + self.kernel_name, + [ + PointerType.void_ptr, + *self.get_kernel_arg_types(), + DataType.const_int64_t, + DataType.const_int64_t, + DataType.const_int64_t, + DataType.const_int64_t, + ], + ), + Project( + nested_files=Project.Directory( + [f"{self.library_name}.cu", Project.FileContent(code)], + ["make.sh", Project.FileContent(compile_cmd)], + ), + compile_cmd="sh make.sh", + so_relative_path=f"lib{self.library_name}.so", + ), + ) + + +def KernelDispatch(ctx): + so_func = ctx.get_so_function("MoeZipVariadicKernel") + stream_ptr = ctx.device_ctx.get_stream_addr_as_void_ptr() + getters = ctx.kernel_dispatch_const_data.kernel_args_getters + args = [stream_ptr, *map(lambda getter: getter(ctx), getters)] + apply(so_func, args)