From 68837fea33eee29d6958b74876b9308584c1e303 Mon Sep 17 00:00:00 2001 From: hxzd5568 <3257591325@qq.com> Date: Thu, 10 Apr 2025 13:20:11 +0800 Subject: [PATCH] support arbitrary input/output and non-3-dim mm_out --- tests/ap/__main__.py | 4 +- tests/ap/kernel_arg_id_util.py | 2 + tests/ap/make_axpr.sh | 4 +- tests/ap/matmul/matmul.h | 13 +- tests/ap/matmul/tests/matmul_binary_kernel.cu | 6 +- tests/ap/matmul_binary_tpl.py | 76 ++- ...atmul_epilogue_pass_of_remove_functions.py | 158 ++++++ tests/ap/op_compute_translator_util.py | 68 +++ tests/ap/op_index_translator_util.py | 21 +- tests/ap/test_matmul_epilogue.py | 537 ++++++++++++++++++ tests/ap/topo_drr_pass.py | 54 +- tests/ap/umprime.py | 66 +++ 12 files changed, 982 insertions(+), 27 deletions(-) create mode 100644 tests/ap/matmul_epilogue_pass_of_remove_functions.py create mode 100644 tests/ap/test_matmul_epilogue.py create mode 100644 tests/ap/umprime.py diff --git a/tests/ap/__main__.py b/tests/ap/__main__.py index b22c965..b3bbff3 100644 --- a/tests/ap/__main__.py +++ b/tests/ap/__main__.py @@ -1,3 +1 @@ -# import test_trivial_reduce -# import test_binary_trivial_reduce -import test_matmul_binary +import test_matmul_epilogue diff --git a/tests/ap/kernel_arg_id_util.py b/tests/ap/kernel_arg_id_util.py index 3b53ab5..6ac7169 100644 --- a/tests/ap/kernel_arg_id_util.py +++ b/tests/ap/kernel_arg_id_util.py @@ -15,6 +15,7 @@ def get_or_create_kernel_arg_id_manul_var_name(self, kernel_arg_id, cpp_var_name return self.all_kernel_arg_id2unique_name.get_or_create(kernel_arg_id, create) def get_in_tensor_data_ptr_var_name(self, in_ir_value_name): + print('in_ir_value_name: ', in_ir_value_name) ir_value = getattr(self.tensor_match_ctx, in_ir_value_name) kernel_arg_id = self.code_gen_ctx.in_tensor_data_ptr_kernel_arg_id(ir_value) create = self._get_creator(kernel_arg_id, self._create_in_tensor_data_ptr_var_name) @@ -29,6 +30,7 @@ def _create_in_tensor_data_ptr_var_name(self): return name def 
get_out_tensor_data_ptr_var_name(self, out_ir_value_name): + out_ir_value_name = out_ir_value_name.replace("out", "output") ir_value = getattr(self.tensor_match_ctx, out_ir_value_name) kernel_arg_id = self.code_gen_ctx.out_tensor_data_ptr_kernel_arg_id(ir_value) create = self._get_creator(kernel_arg_id, self._create_out_tensor_data_ptr_var_name) diff --git a/tests/ap/make_axpr.sh b/tests/ap/make_axpr.sh index ff37b9f..5f476a6 100644 --- a/tests/ap/make_axpr.sh +++ b/tests/ap/make_axpr.sh @@ -14,11 +14,13 @@ FILENAMES_ARRAY=( "__main__" "topo_drr_pass" "op_convertion_drr_pass" + "umprime" "access_topo_drr" "abstract_drr" + "matmul_epilogue_pass_of_remove_functions" "ap_tpl_codegen" "matmul_binary_tpl" - "test_matmul_binary" + "test_matmul_epilogue" ) for filename in "${FILENAMES_ARRAY[@]}" do diff --git a/tests/ap/matmul/matmul.h b/tests/ap/matmul/matmul.h index 3ac0fd5..82b5543 100644 --- a/tests/ap/matmul/matmul.h +++ b/tests/ap/matmul/matmul.h @@ -89,7 +89,9 @@ struct GemmEpilogueParams { std::vector input0_shape; std::vector input1_shape; std::vector epilogue_in_ptrs; + std::vector epilogue_out_ptrs; std::vector> epilogue_in_shapes; + std::vector> epilogue_out_shapes; GemmEpilogueParams() {} GemmEpilogueParams(cudaStream_t stream, const void *input, const void *weight, @@ -156,16 +158,23 @@ struct GemmEpilogueParams { shape_args.ldc_bias = (!bias || is_C_bias) ? 
0 : n; } - void SetEpilogues(const std::vector &in_ptrs) { + void SetEpilogues(const std::vector &in_ptrs, + const std::vector< void *> &out_ptrs) { epilogue_in_ptrs = in_ptrs; + epilogue_out_ptrs = out_ptrs; } void SetEpilogueAndShapes(const std::vector &in_ptrs, - const std::vector> &in_shapes) { + const std::vector> &in_shapes, + const std::vector &out_ptrs, + const std::vector> &out_shapes) { ASSERT_CHECK(in_ptrs.size() == in_shapes.size()); epilogue_in_ptrs = in_ptrs; epilogue_in_shapes = in_shapes; + ASSERT_CHECK(out_ptrs.size() == out_shapes.size()); + epilogue_out_ptrs = out_ptrs; + epilogue_out_shapes = out_shapes; } }; diff --git a/tests/ap/matmul/tests/matmul_binary_kernel.cu b/tests/ap/matmul/tests/matmul_binary_kernel.cu index 1d0cb55..5fc8b99 100644 --- a/tests/ap/matmul/tests/matmul_binary_kernel.cu +++ b/tests/ap/matmul/tests/matmul_binary_kernel.cu @@ -47,13 +47,15 @@ void MatmulAddBinaryKernel( cudaStream_t *stream, const void *input, const void *weight, const void *bias, void *output, const std::vector &epilogue_ins, + const std::vector &epilogue_outs, const std::vector &input_shape, const std::vector &weight_shape, const std::vector &bias_shape, - const std::vector> &epilogue_shapes) { + const std::vector> &epilogue_in_shapes, + const std::vector> &epilogue_out_shapes) { GemmEpilogueParams params(*stream, input, weight, bias, output, input_shape, weight_shape, bias_shape); - params.SetEpilogueAndShapes(epilogue_ins, epilogue_shapes); + params.SetEpilogueAndShapes(epilogue_ins, epilogue_in_shapes, epilogue_outs, epilogue_out_shapes); #if AP_ENABLE_AUTOTUNE #if AP_USE_FLOAT16 diff --git a/tests/ap/matmul_binary_tpl.py b/tests/ap/matmul_binary_tpl.py index dc92853..61aeb06 100644 --- a/tests/ap/matmul_binary_tpl.py +++ b/tests/ap/matmul_binary_tpl.py @@ -16,6 +16,11 @@ def is_in_tensor_karg(kernel_arg_id): ) return kernel_arg_id_type_name == "InTensorDataPtrKernelArgId" +def is_out_tensor_karg(kernel_arg_id): + kernel_arg_id_type_name = 
f"{type(kernel_arg_id)}".replace("", "" + ) + return kernel_arg_id_type_name == "OutTensorDataPtrKernelArgId" class MatmulBinaryTemplate: def __init__( @@ -39,6 +44,7 @@ def __init__( ) self.input_dim_karg_to_shape_access = MutableOrderedDict() self.input_tensor_karg_to_shape_access = MutableOrderedDict() + self.output_tensor_karg_to_shape_access = MutableOrderedDict() self.kernel_name = "MatmulBinaryKernel" self.library_name = "matmul_binary_kernel" @@ -105,6 +111,11 @@ def get_kernel_arg_runtime_getters(self): lambda pair: pair[0].runtime_getter, all_kernel_arg_id_and_unique_names ) + def init_outputs(self): + out_tensor_data_nums = self.mut_kernel_arg_id_registry.out_tensor_data_ptr_seq_no + stmt = map(lambda i: f"out{i}", range(out_tensor_data_nums + 1)) + return "T " + f", ".join(stmt) + ";" + def get_kernel_arg_types(self): all_kernel_arg_id_and_unique_names = ( self.mut_kernel_arg_id_registry.all_kernel_arg_id2unique_name.items() @@ -159,6 +170,7 @@ def get_epilogue_arguments_init_str( def declare_epilogue_arguments_assign(pair): kernel_arg_id = pair[0] is_in_tensor_type = is_in_tensor_karg(kernel_arg_id) + is_out_tensor_type = is_out_tensor_karg(kernel_arg_id) var_name = pair[1] field_name = self.kernel_arg_translator.get_param_struct_field_name( @@ -169,6 +181,10 @@ def get_in_tensor_statement(): param_name_for_var = self.input_tensor_karg_to_shape_access[var_name] return f"reinterpret_cast({params_name}.{param_name_for_var})" + def get_out_tensor_statement(): + param_name_for_var = self.output_tensor_karg_to_shape_access[var_name] + return f"reinterpret_cast<{output_dtype} *>({params_name}.{param_name_for_var})" + def get_dim_expr_statement(): param_name_for_var = self.input_dim_karg_to_shape_access[var_name] return f"{params_name}.{param_name_for_var}" @@ -176,23 +192,25 @@ def get_dim_expr_statement(): statement = ( get_in_tensor_statement() if is_in_tensor_type - else get_dim_expr_statement() + else get_out_tensor_statement() + if is_out_tensor_type + 
else get_dim_expr_statement() ) return f"{obj_name}.{field_name} = {statement};" generated_kernel_arg_id_and_names = ( self.mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() ) + return f"\n{indent}".join( map(declare_epilogue_arguments_assign, generated_kernel_arg_id_and_names) ) - def get_params_epilogue_ptrs_init_str(self, obj_name, indent): + def get_params_epilogue_ptrs_init_str(self, in_obj_name, out_obj_name, indent): in_tensor_id = 0 - - def declare_params_epilogue_arguments_assign(pair): + def declare_in_params_epilogue_arguments_assign(pair): def get_creator(): - return f"{obj_name}[{in_tensor_id}]" + return f"{in_obj_name}[{in_tensor_id}]" kernel_arg_id = pair[0] is_in_tensor_type = is_in_tensor_karg(kernel_arg_id) @@ -201,7 +219,7 @@ def generate_statement(): self.input_tensor_karg_to_shape_access.get_or_create( pair[1], get_creator ) - statement = f"{obj_name}.push_back({pair[1]});" + statement = f"{in_obj_name}.push_back({pair[1]});" in_tensor_id = in_tensor_id + 1 return statement @@ -210,13 +228,39 @@ def generate_statement(): generated_kernel_arg_id_and_names = ( self.mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() ) - return f"\n{indent}".join( - map( - declare_params_epilogue_arguments_assign, + in_str_list = map( + declare_in_params_epilogue_arguments_assign, generated_kernel_arg_id_and_names, - ) ) + out_tensor_id = 0 + def declare_out_params_epilogue_arguments_assign(pair): + def get_creator(): + return f"{out_obj_name}[{out_tensor_id}]" + + kernel_arg_id = pair[0] + is_out_tensor_type = is_out_tensor_karg(kernel_arg_id) + + def generate_statement(): + self.output_tensor_karg_to_shape_access.get_or_create( + pair[1], get_creator + ) + statement = f"{out_obj_name}.push_back({pair[1]});" + out_tensor_id = out_tensor_id + 1 + return statement + + return generate_statement() if is_out_tensor_type else "" + + out_str_list = map( + declare_out_params_epilogue_arguments_assign, + 
generated_kernel_arg_id_and_names, + ) + str_list = filter( + lambda ss: ss != "", + [*in_str_list, *out_str_list] + ) + return f"\n{indent}".join(str_list) + def get_params_input_shape_init_str(self, input_name, input_shape_kargs, indent): def init_input_shape_with_args(i): def get_creator(): @@ -264,9 +308,9 @@ def make_project( // Note: need to support vectorized operation __forceinline__ __host__ __device__ T operator()(T x, const Arguments& args, const MatrixCoord& coord) const { - T out; + AP_OUTPUTS_INIT AP_GENERATED_BINARY_EPILOGUE_STRING - return out; + return out0; } }; @@ -303,9 +347,10 @@ def make_project( *cuda_stream_ptr, ${input0}, ${input1}, nullptr, ${output}, ${input0}_shape, ${input1}_shape, std::vector{}); std::vector epilogue_in_ptrs; + std::vector epilogue_out_ptrs; AP_PARAMS_EPILOGUE_PTRS_INIT - params.SetEpilogues(epilogue_in_ptrs); + params.SetEpilogues(epilogue_in_ptrs, epilogue_out_ptrs); #if AP_ENABLE_AUTOTUNE AP_AUTOTUNE_${output_dtype}(ap::RunMatmulWithVariadicKernel); @@ -321,6 +366,7 @@ def make_project( code_template.replace( "AP_GENERATED_BINARY_EPILOGUE_STRING", trivial_code_str ) + .replace("AP_OUTPUTS_INIT", self.init_outputs()) .replace("AP_KERNEL_ARGS_DECLARE", self.get_kernel_arg_list_str()) .replace( "AP_PARAMS_INPUT0_SHAPE_INIT", @@ -336,7 +382,7 @@ def make_project( ) .replace( "AP_PARAMS_EPILOGUE_PTRS_INIT", - self.get_params_epilogue_ptrs_init_str("epilogue_in_ptrs", indent=" "), + self.get_params_epilogue_ptrs_init_str("epilogue_in_ptrs", "epilogue_out_ptrs", indent=" "), ) .replace( "AP_EPILOGUE_ARGUMENTS_FIELDS", @@ -356,7 +402,7 @@ def make_project( .replace("${k_value}", f"{input0_shape_kargs[-1].value}") .replace("${n_value}", f"{input1_shape_kargs[-1].value}") ) - + print('cuda code is: ', code) source_dir = "/work/abstract_pass/Athena/tests/ap/matmul" cutlass_dir = "/work/abstract_pass/Athena/tests/ap/matmul/cutlass" compile_cmd = ( diff --git a/tests/ap/matmul_epilogue_pass_of_remove_functions.py 
b/tests/ap/matmul_epilogue_pass_of_remove_functions.py new file mode 100644 index 0000000..d3e3c03 --- /dev/null +++ b/tests/ap/matmul_epilogue_pass_of_remove_functions.py @@ -0,0 +1,158 @@ +import access_topo_drr +import pir + +class RemoveDataOpPairPass(access_topo_drr.DrrPass): + def __init__(self, src_data_op_name, dst_data_op_name): + self.src_data_op_name = pir.a_str(src_data_op_name) + self.dst_data_op_name = pir.a_str(dst_data_op_name) + def source_pattern(self, o, t): + o.src_data_op = o.ap_native_op("pd_op.data") + o.src_data_op( + [], + [t.input0] + ) + o.dst_data_op = o.ap_native_op("pd_op.data") + o.dst_data_op( + [], + [t.input1] + ) + o.up_spider_op = o.ap_native_op("ap_op.up_spider") + o.up_spider_op( + [t.input0, t.input1], + [] + ) + def constraint(self, o, t): + return [o.src_data_op.name, o.dst_data_op.name] == [self.src_data_op_name, self.dst_data_op_name] + def result_pattern(self, o, t): + pass + +class RemoveDataOp2SumOp2DataOpPass(access_topo_drr.DrrPass): + def __init__(self, src_data_op_name, dst_data_op_name): + self.src_data_op_name = pir.a_str(src_data_op_name) + self.dst_data_op_name = pir.a_str(dst_data_op_name) + + def source_pattern(self, o, t): + o.src_data_op = o.ap_native_op("pd_op.data") + o.src_data_op.name = self.src_data_op_name + o.src_data_op( + [], + [t.input0] + ) + o.full_int_array_op = o.ap_native_op("pd_op.full_int_array") + o.full_int_array_op( + [], + [t.axis] + ) + o.sum_op = o.ap_native_op("pd_op.sum") + o.sum_op( + [t.input0, t.axis], + [t.sum_out] + ) + o.dst_data_op = o.ap_native_op("pd_op.data") + o.dst_data_op.name = self.dst_data_op_name + o.dst_data_op( + [], + [t.input1] + ) + o.up_spider_op = o.ap_native_op("ap_op.up_spider") + o.up_spider_op( + [t.sum_out, t.input1], + [] + ) + + def result_pattern(self, o, t): + pass + +class RemoveElementInputIndexPass(access_topo_drr.DrrPass): + + def __init__(self, src_data_op_name, dst_load_from_global_op_name): + self.src_data_op_name = pir.a_str(src_data_op_name) 
+ self.dst_load_from_global_op_name = pir.a_str(dst_load_from_global_op_name) + + def source_pattern(self, o, t): + o.src_data_op = o.ap_native_op("pd_op.data") + o.src_data_op.name = self.src_data_op_name + o.src_data_op( + [], + [t.src_input] + ) + + o.dst_load_from_global_op = o.ap_native_op("ap_op.load_from_global") + o.dst_load_from_global_op.index_func_unique_id = self.dst_load_from_global_op_name + o.dst_load_from_global_op( + [t.dst_input], + [t.dst_load_from_global_output] + ) + o.up_spider_op = o.ap_native_op("ap_op.up_spider") + o.up_spider_op( + [t.src_input, t.dst_load_from_global_output], + [] + ) + + def result_pattern(self, o, t): + pass + +class RemoveBroadcastInputIndexPass(access_topo_drr.DrrPass): + def __init__(self, src_data_op_name, dst_load_from_global_op_name): + self.src_data_op_name = pir.a_str(src_data_op_name) + self.dst_load_from_global_op_name = pir.a_str(dst_load_from_global_op_name) + + def source_pattern(self, o, t): + o.src_data_op = o.ap_native_op("pd_op.data") + o.src_data_op.name = self.src_data_op_name + o.src_data_op( + [], + [t.input0] + ) + o.full_int_array_op = o.ap_native_op("pd_op.full_int_array") + o.full_int_array_op( + [], + [t.axis] + ) + o.sum_op = o.ap_native_op("pd_op.sum") + o.sum_op( + [t.input0, t.axis], + [t.sum_out] + ) + o.dst_load_from_global_op = o.ap_native_op("ap_op.load_from_global") + o.dst_load_from_global_op.index_func_unique_id = self.dst_load_from_global_op_name + o.dst_load_from_global_op( + [t.dst_input], + [t.dst_load_from_global_output] + ) + o.up_spider_op = o.ap_native_op("ap_op.up_spider") + o.up_spider_op( + [t.sum_out, t.dst_load_from_global_output], + [] + ) + + def result_pattern(self, o, t): + pass + +class RemoveOutputIndexPass(access_topo_drr.DrrPass): + + def __init__(self, src_data_op_name, dst_store_to_global_op_name): + self.src_data_op_name = pir.a_str(src_data_op_name) + self.dst_store_to_global_op_name = pir.a_str(dst_store_to_global_op_name) + + def source_pattern(self, o, t): 
+ o.src_data_op = o.ap_native_op("pd_op.data") + o.src_data_op.name = self.src_data_op_name + o.src_data_op( + [], + [t.src_input] + ) + o.down_spider_op = o.ap_native_op("ap_op.down_spider") + o.down_spider_op( + [t.src_input], + [t.dst_output_val] + ) + o.dst_store_to_global_op = o.ap_native_op("ap_op.store_to_global") + o.dst_store_to_global_op.index_func_unique_id = self.dst_store_to_global_op_name + o.dst_store_to_global_op( + [t.dst_output, t.dst_output_val], + [] + ) + + def result_pattern(self, o, t): + pass diff --git a/tests/ap/op_compute_translator_util.py b/tests/ap/op_compute_translator_util.py index 16b2f5a..42b6042 100644 --- a/tests/ap/op_compute_translator_util.py +++ b/tests/ap/op_compute_translator_util.py @@ -48,8 +48,10 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): mut_lir_code_gen_ctx=mut_lir_code_gen_ctx, ) data_op_name = inputs[0].var_name + print('data_name of OpLoadFromGlobal is: ', data_op_name) arg_name = mut_kernel_arg_id_registry.get_in_tensor_data_ptr_var_name(data_op_name) ptr_var_name = self.kernel_arg_translator.get_use_name(arg_name) + print('ptr_var of OpLoadFromGlobal is: ', ptr_var_name) out = self.get_out_cg_val(0) mut_lir_code_gen_ctx.let(out, f"{ptr_var_name}[{offset_var_name}]") return [out] @@ -73,9 +75,58 @@ def __init__(self, self.output_properties = output_properties self.kernel_arg_translator = kernel_arg_translator self.index_program_translator_map = index_program_translator_map + self.dtype2type_name = OrderedDict( + [ + [PointerType.const_float_ptr, "const float*"], + [PointerType.const_float16_ptr, "const half*"], + [PointerType.float_ptr, "float*"], + [PointerType.float16_ptr, "half*"], + [DataType.float, "float"], + [DataType.float16, "half"], + [DataType.int64_t, "int64_t"], + ] + ) + self.ptr2type = OrderedDict( + [ + ["float*", "float"], + ["half*", "half"], + ] + ) + # TODO: replace the dictionary with more robust method + # to find the corresponding global variable name of 
the output register + self.local_to_glb_out = OrderedDict( + map(lambda i: [f"out{i+1}", f"args.out_ptr_{i}"], range(30)) + ) def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): mut_lir_code_gen_ctx.stmts.append(f"{self.get_out_var_name()} = {inputs[0].var_name};") + out_name = self.get_out_var_name() + mut_kernel_arg_id_registry.get_out_tensor_data_ptr_var_name(out_name) if out_name != "out0" else None + + def generate_store_stmt(): + index_func_unique_id_attr = self.op_property.attributes.name + index_func_unique_id = index_func_unique_id_attr.match(a_str=lambda x:x) + output_seq_name = f"out_ptr_{mut_kernel_arg_id_registry.out_tensor_data_ptr_seq_no}" + generated_kernel_arg_id_and_names = ( + mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() + ) + kernel_arg_id = filter( + lambda item: item[1] == output_seq_name, + generated_kernel_arg_id_and_names + )[0] + dtype = kernel_arg_id[0].type + type_name = self.dtype2type_name[dtype] + data_type = self.ptr2type[type_name] + offset_var_name = self.index_program_translator_map.get_offset_var_name( + index_func_unique_id=index_func_unique_id, + mut_kernel_arg_id_registry=mut_kernel_arg_id_registry, + mut_lir_code_gen_ctx=mut_lir_code_gen_ctx, + ) + ptr_name = self.local_to_glb_out[out_name] + mut_lir_code_gen_ctx.stmts.append( + f"{ptr_name}[{offset_var_name}] = static_cast<{data_type}>({out_name});" + ) + generate_store_stmt() if out_name != "out0" else None return [] def get_out_var_name(self): @@ -493,6 +544,22 @@ def __init__(self, def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return inputs +class CinnOpExpandCodeGen: + def __init__(self, + op_property, + input_properties, + output_properties, + kernel_arg_translator, + index_program_translator_map): + self.op_property = op_property + self.input_properties = input_properties + self.output_properties = output_properties + self.kernel_arg_translator = kernel_arg_translator + 
self.index_program_translator_map = index_program_translator_map + + def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): + return [inputs[0]] + class CinnOpGenerateShapeCodeGen: def __init__(self, op_property, @@ -539,6 +606,7 @@ def __init__(self): ["pd_op.maximum", PdOpMaximumCodeGen], ["cinn_op.yield_store", CinnOpYieldStoreCodeGen], ["cinn_op.broadcast", CinnOpBroadcastCodeGen], + ["pd_op.expand", CinnOpExpandCodeGen], ["cinn_op.generate_shape", CinnOpGenerateShapeCodeGen] ]) diff --git a/tests/ap/op_index_translator_util.py b/tests/ap/op_index_translator_util.py index bea4d80..ba0c28b 100644 --- a/tests/ap/op_index_translator_util.py +++ b/tests/ap/op_index_translator_util.py @@ -62,8 +62,16 @@ def __init__(self, def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): input_iter_var_names = inputs[0].iter_var_names + # TODO: Only applicable to matrix multiplication cases of b-m-n mode + # handle the cases with mm_out's dim != 3 + reduced_axes = map( + lambda index: inputs[1].const_data[index], range(2) + ) if len(inputs[1].const_data) >= 3 + else [0, 1] if len(inputs[1].const_data) == 1 + else inputs[1].const_data + reduced_axes_set = OrderedDict( - map(lambda x: [int(x), True], inputs[1].const_data) + map(lambda x: [int(x), True], reduced_axes) ) non_reduced_axes = filter( lambda x: reduced_axes_set.contains(x) == False, @@ -91,15 +99,21 @@ def __init__(self, self.kernel_arg_translator = kernel_arg_translator self.anchor_iter_var_names = anchor_iter_var_names + # TODO: Only applicable to matrix multiplication cases of b-m-n mode def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): symbolic_shape = self.input_properties[0].symbolic_shape + print('input symbolic_shape of reshape is: ', symbolic_shape) def get_or_create_dim_var_name(dim_expr): arg_var_name = mut_kernel_arg_id_registry.get_dim_expr_var_name(dim_expr) return self.kernel_arg_translator.get_use_name(arg_var_name) + + rank = 3 
if len(symbolic_shape) > 3 else len(symbolic_shape) + rank_bias = len(symbolic_shape) - 3 if len(symbolic_shape) > 3 else 0 + def get_dim_var_name(i): - dim_expr = symbolic_shape[i] + dim_expr = symbolic_shape[i + rank_bias] return get_or_create_dim_var_name(dim_expr) - rank = len(symbolic_shape) + stride_dims_list = map( lambda num_dims: map(lambda i: get_dim_var_name(num_dims + i + 1), range(rank - 1 - num_dims)), range(rank) @@ -165,4 +179,5 @@ def __call__(self, ) def _get_class(self, op_name): + print('translating op: ', op_name) return self.op_name2class[op_name] diff --git a/tests/ap/test_matmul_epilogue.py b/tests/ap/test_matmul_epilogue.py new file mode 100644 index 0000000..ccd2f80 --- /dev/null +++ b/tests/ap/test_matmul_epilogue.py @@ -0,0 +1,537 @@ +import abstract_drr +import access_topo_drr +import topo_drr_pass +import umprime +import matmul_epilogue_pass_of_remove_functions +import op_convertion_drr_pass +import matmul_binary_tpl +import ir_tools +import index_program_translator_util +import op_compute_translator_util +import program_translator_util +import kernel_arg_id_util +import low_level_ir_code_gen_ctx_util +import kernel_arg_translator_util +import pir + + +class MatmulEpilogueFusion(abstract_drr.DrrPass): + def source_pattern(self, o, t): + in_num = self.number_of_inputs() + out_num = self.number_of_outputs() + o.matmul_op = o.ap_native_op("pd_op.matmul") + o.matmul_op( + [t.input0, t.input1], + [t.mm_out] + ) + o.trivial_op = o.ap_trivial_fusion_op() + o.trivial_op( + [t.mm_out, *map(lambda index: getattr(t, f"input{index+2}"), range(in_num - 2))], + map(lambda index: getattr(t, f"output{index}"), range(out_num)), + ) + + + def result_pattern(self, o, t): + in_num = self.number_of_inputs() + out_num = self.number_of_outputs() + o.fustion_op = o.ap_pattern_fusion_op(self.code_gen) + o.fustion_op( + map(lambda index: getattr(t, f"input{index}"), range(in_num)), + map(lambda index: getattr(t, f"output{index}"), range(out_num)), + ) + + def 
constraint(self, o, t): + program = ir_tools.copy_fused_ops_to_program(o.trivial_op, tensor_match_ctx=t) + print("before-umprime: ", program) + # umprime passes + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("umprime")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("before-access_topo_pass", program) + init_pass_manager = ir_tools.create_pass_manager() + init_down_spider = topo_drr_pass.InitDownSpiderAccessTopoPass("mm_out") + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_down_spider) + ) + outputs_name_list = map(lambda i: f"output{i}", range(self.number_of_outputs())) + inputs_name_list = map(lambda i: f"input{i+2}", range(self.number_of_inputs() - 2)) if self.number_of_inputs() > 2 else [] + print('inputs_name_list: ', ', '.join(inputs_name_list)) + init_fake_data_for_yield_input = topo_drr_pass.FakeDataForYieldAccessTopoPass( + outputs_name_list + ) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_fake_data_for_yield_input) + ) + init_pass_manager.run(program) + print("after-init-access_topo_pass", program) + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("default")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("after-apply-access_topo_pass", program) + pass_manager = ir_tools.create_pass_manager() + map(lambda dst_name: pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + matmul_epilogue_pass_of_remove_functions.RemoveDataOpPairPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), + inputs_name_list + ) + map(lambda dst_name: pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + matmul_epilogue_pass_of_remove_functions.RemoveDataOp2SumOp2DataOpPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), + inputs_name_list + ) + + map(lambda 
dst_name: pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + matmul_epilogue_pass_of_remove_functions.RemoveDataOpPairPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), + outputs_name_list + ) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("after-remove-input-output-access_topo_pass", program) + return program.empty() + + def _insert_load_from_global(self, program, input_names): + init_pass_manager = ir_tools.create_pass_manager() + def AddPass(input_name): + ir_pass = topo_drr_pass.InitNaiveLoadFromGlobalAccessTopoPass(input_name) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(ir_pass) + ) + map(AddPass, input_names) + init_pass_manager.run(program) + + def _insert_store_to_global(self, program, output_names): + init_pass_manager = ir_tools.create_pass_manager() + ir_pass = topo_drr_pass.FakeDataStoreToGlobalForYieldAccessTopoPass(output_names) + init_pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(ir_pass)) + init_pass_manager.run(program) + + def _make_kernel_arg_translator(self): + return matmul_binary_tpl.make_kernel_arg_translator() + + def _apply_topo_access_passes(self, mut_program, anchor_data_op_name): + init_pass_manager = ir_tools.create_pass_manager() + init_down_spider = topo_drr_pass.InitDownSpiderAccessTopoPass(anchor_data_op_name) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_down_spider) + ) + init_pass_manager.run(mut_program) + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("default")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + + def _simplify_index_program(self, mut_program): + pass_manager = ir_tools.create_pass_manager() + drr_pass = topo_drr_pass.ConvertUpSpiderStoreDataOpToYieldOpPass() + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + drr_pass = 
topo_drr_pass.ConvertDownSpiderStoreDataOpToYieldOpPass() + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + return mut_program + + def _make_index_func_unique_id2index_program( + self, compute_program, anchor_data_op_name, input_names, output_names): + full_index_program = compute_program.clone() + self._apply_topo_access_passes(full_index_program, anchor_data_op_name) + print('full_index_program: ', full_index_program) + def MatchAndCopyInputIndex(dst_input_name): + pass_manager = ir_tools.create_pass_manager() + removed_programs = MutableList() + rm_elementwise_drr_pass = matmul_epilogue_pass_of_remove_functions.RemoveElementInputIndexPass( + src_data_op_name=anchor_data_op_name, + dst_load_from_global_op_name=dst_input_name + ) + rm_elementwise_ir_pass = ir_tools.create_access_topo_drr_one_step_pass( + rm_elementwise_drr_pass, + matched_pattern_mut_list=removed_programs + ) + pass_manager.add_pass(rm_elementwise_ir_pass) + rm_broadcast_drr_pass = matmul_epilogue_pass_of_remove_functions.RemoveBroadcastInputIndexPass( + src_data_op_name=anchor_data_op_name, + dst_load_from_global_op_name=dst_input_name + ) + rm_broadcast_ir_pass = ir_tools.create_access_topo_drr_one_step_pass( + rm_broadcast_drr_pass, + matched_pattern_mut_list=removed_programs + ) + pass_manager.add_pass(rm_broadcast_ir_pass) + pass_manager.run(full_index_program) + def Converter(program): + return [dst_input_name, self._simplify_index_program(program)] + return map(Converter, removed_programs) + input_and_index_programs = flat_map(MatchAndCopyInputIndex, input_names) + def MatchAndCopyOutputIndex(dst_output_name): + pass_manager = ir_tools.create_pass_manager() + removed_programs = MutableList() + drr_pass = matmul_epilogue_pass_of_remove_functions.RemoveOutputIndexPass( + src_data_op_name=anchor_data_op_name, + dst_store_to_global_op_name=dst_output_name + ) + ir_pass = 
ir_tools.create_access_topo_drr_one_step_pass(
+                drr_pass,
+                matched_pattern_mut_list=removed_programs
+            )
+            pass_manager.add_pass(ir_pass)
+            pass_manager.run(full_index_program)
+
+            def Converter(program):
+                return [dst_output_name, self._simplify_index_program(program)]
+            return map(Converter, removed_programs)
+        output_and_index_programs = flat_map(MatchAndCopyOutputIndex, output_names)
+        return OrderedDict([*input_and_index_programs, *output_and_index_programs])
+
+    # Rewrite `mut_program` in place with ReplaceWithLoadFromRegisterPass,
+    # configured with `load_ir_value_name` and `register_var_name`, run as a
+    # one-step access-topo DRR pass followed by dead-code elimination.
+    # Returns the same `mut_program` object.
+    def _replace_with_load_from_register(
+        self, mut_program, load_ir_value_name, register_var_name):
+        pass_manager = ir_tools.create_pass_manager()
+        drr_pass = topo_drr_pass.ReplaceWithLoadFromRegisterPass(
+            name=load_ir_value_name,
+            register_var_name=register_var_name
+        )
+        pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass))
+        pass_manager.add_pass(ir_tools.create_dce_pass())
+        pass_manager.run(mut_program)
+        return mut_program
+
+    # Store-side counterpart of _replace_with_load_from_register: runs
+    # ReplaceWithStoreToRegisterPass plus dead-code elimination on
+    # `mut_program` and returns the same object.
+    def _replace_with_store_to_register(
+        self, mut_program, store_ir_value_name, register_var_name):
+        pass_manager = ir_tools.create_pass_manager()
+        drr_pass = topo_drr_pass.ReplaceWithStoreToRegisterPass(
+            name=store_ir_value_name,
+            register_var_name=register_var_name
+        )
+        pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass))
+        pass_manager.add_pass(ir_tools.create_dce_pass())
+        pass_manager.run(mut_program)
+        return mut_program
+
+    # Build the ProgramTranslator for the fused epilogue ops matched as
+    # `o.trivial_op`. `ctx` is accepted but not used in this method.
+    # Steps: copy the fused ops into a mutable program, run the "umprime"
+    # passes, insert global loads/stores, derive the index programs, then
+    # redirect "mm_out" and every output to register variables.
+    def _get_program_translator(self, ctx, o, t):
+        # Tensor names are derived from the trait counts: output0..output{N-1},
+        # the outputs after output0, local registers out0..out{N-1}, and the
+        # epilogue inputs input2..input{M-1}.
+        outputs_name_list = map(lambda i: f"output{i}", range(self.number_of_outputs()))
+        other_outputs_name_list = map(lambda i: f"output{i+1}", range(self.number_of_outputs()-1))
+        local_outputs_name_list = map(lambda i: f"out{i}", range(self.number_of_outputs()))
+        inputs_name_list = map(lambda i: f"input{i+2}", range(self.number_of_inputs() - 2)) if self.number_of_inputs() > 2 else []
+        mut_program = ir_tools.copy_fused_ops_to_program(
+            o.trivial_op, tensor_match_ctx=t
+        )
+        print("before-umprime: ", mut_program)
+        # umprime passes
+        pass_manager = ir_tools.create_pass_manager()
+        pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("umprime"))
+        pass_manager.add_pass(ir_tools.create_dce_pass())
+        pass_manager.run(mut_program)
+        self._insert_load_from_global(
+            mut_program,
+            input_names=["mm_out"]
+        )
+        self._insert_load_from_global(
+            mut_program,
+            input_names=inputs_name_list
+        )
+        self._insert_store_to_global(
+            mut_program,
+            output_names=outputs_name_list
+        )
+        kernel_arg_translator = self._make_kernel_arg_translator()
+        # "mm_out" is the anchor; only the outputs after output0 are passed
+        # as output_names here.
+        index_func_unique_id2index_program = self._make_index_func_unique_id2index_program(
+            mut_program,
+            anchor_data_op_name="mm_out",
+            input_names=inputs_name_list,
+            output_names=other_outputs_name_list,
+        )
+        print("index_func_unique_id2index_program:\n", index_func_unique_id2index_program)
+        index_program_translator_map = index_program_translator_util.IndexProgramTranslatorMap(
+            index_func_unique_id2index_program=index_func_unique_id2index_program,
+            kernel_arg_translator=kernel_arg_translator,
+            anchor_iter_var_names=matmul_binary_tpl.get_anchor_iter_var_names()
+        )
+        self._replace_with_load_from_register(
+            mut_program,
+            load_ir_value_name="mm_out",
+            register_var_name="x"
+        )
+        out_pair = zip(outputs_name_list, local_outputs_name_list)
+        # NOTE(review): map() is called only for its side effect and
+        # outputs_name_list is consumed a second time above; this presumably
+        # relies on the runtime evaluating map eagerly into a list (under
+        # CPython 3 this call would be a lazy no-op) - confirm.
+        map(lambda pair: self._replace_with_store_to_register(
+                mut_program,
+                store_ir_value_name=pair[0],
+                register_var_name=pair[1]
+            ), out_pair
+        )
+        op_compute_translator_maker = op_compute_translator_util.OpComputeTranslatorFactory()
+        program_translator = program_translator_util.ProgramTranslator(
+            program_property=mut_program.copy_to_const_program_data(),
+            kernel_arg_translator=kernel_arg_translator,
+            index_program_translator_map=index_program_translator_map,
+            op_translator_maker=op_compute_translator_maker
+        )
+
+        return program_translator
+
+    # Entry point for code generation: builds the program translator and a
+    # kernel-arg name registry, instantiates MatmulBinaryTemplate and returns
+    # the result of its compile(), passing the kernel-arg ids of input0,
+    # input1, output0 and of each symbolic dim of input0/input1.
+    def code_gen(self, ctx, o, t):
+        program_translator = self._get_program_translator(ctx, o, t)
+        mut_kernel_arg_id_registry = kernel_arg_id_util.KernelArgIdNameRegistry(
+            code_gen_ctx=ctx,
+            tensor_match_ctx=t,
+            name_prefix=""
+        )
+
+        template_module = matmul_binary_tpl.MatmulBinaryTemplate(
+            program_translator=program_translator,
+            mut_kernel_arg_id_registry=mut_kernel_arg_id_registry,
+        )
+
+        def get_symbolic_shape_args_list(sym_dim):
+            return ctx.dim_expr_kernel_arg_id(sym_dim)
+        input0_shape_kargs = map(get_symbolic_shape_args_list, t.input0.symbolic_shape_to_list())
+        input1_shape_kargs = map(get_symbolic_shape_args_list, t.input1.symbolic_shape_to_list())
+        return template_module.compile(
+            input0_karg=ctx.in_tensor_data_ptr_kernel_arg_id(t.input0),
+            input1_karg=ctx.in_tensor_data_ptr_kernel_arg_id(t.input1),
+            output_karg=ctx.out_tensor_data_ptr_kernel_arg_id(t.output0),
+            input0_shape_kargs=input0_shape_kargs,
+            input1_shape_kargs=input1_shape_kargs,
+        )
+
+# Trait classes: NumberOfInputsTrait{N}.number_of_inputs() returns N and
+# NumberOfOutputsTrait{N}.number_of_outputs() returns N. get_mixin_class()
+# below combines one of each with a base class.
+class NumberOfInputsTrait0():
+    def number_of_inputs(self):
+        return 0
+
+class NumberOfInputsTrait1():
+    def number_of_inputs(self):
+        return 1
+
+class NumberOfInputsTrait2():
+    def number_of_inputs(self):
+        return 2
+
+class NumberOfInputsTrait3():
+    def number_of_inputs(self):
+        return 3
+
+class NumberOfInputsTrait4():
+    def number_of_inputs(self):
+        return 4
+
+class NumberOfInputsTrait5():
+    def number_of_inputs(self):
+        return 5
+
+class NumberOfInputsTrait6():
+    def number_of_inputs(self):
+        return 6
+
+class NumberOfInputsTrait7():
+    def number_of_inputs(self):
+        return 7
+
+class NumberOfInputsTrait8():
+    def number_of_inputs(self):
+        return 8
+
+class NumberOfInputsTrait9():
+    def number_of_inputs(self):
+        return 9
+
+class NumberOfInputsTrait10():
+    def number_of_inputs(self):
+        return 10
+
+class NumberOfInputsTrait11():
+    def number_of_inputs(self):
+        return 11
+
+class NumberOfInputsTrait12():
+    def number_of_inputs(self):
+        return 12
+
+class NumberOfInputsTrait13():
+    def number_of_inputs(self):
+        return 13
+
+class NumberOfInputsTrait14():
+    def number_of_inputs(self):
+        return 14
+
+class NumberOfInputsTrait15():
+    def number_of_inputs(self):
+        return 15
+
+class NumberOfInputsTrait16():
+    def number_of_inputs(self):
+        return 16
+
+class NumberOfInputsTrait17():
+    def number_of_inputs(self):
+        return 17
+
+
+class NumberOfOutputsTrait0():
+    def number_of_outputs(self):
+        return 0
+
+class NumberOfOutputsTrait1():
+    def number_of_outputs(self):
+        return 1
+
+class NumberOfOutputsTrait2():
+    def number_of_outputs(self):
+        return 2
+
+class NumberOfOutputsTrait3():
+    def number_of_outputs(self):
+        return 3
+
+class NumberOfOutputsTrait4():
+    def number_of_outputs(self):
+        return 4
+
+class NumberOfOutputsTrait5():
+    def number_of_outputs(self):
+        return 5
+
+class NumberOfOutputsTrait6():
+    def number_of_outputs(self):
+        return 6
+
+class NumberOfOutputsTrait7():
+    def number_of_outputs(self):
+        return 7
+
+class NumberOfOutputsTrait8():
+    def number_of_outputs(self):
+        return 8
+
+class NumberOfOutputsTrait9():
+    def number_of_outputs(self):
+        return 9
+
+class NumberOfOutputsTrait10():
+    def number_of_outputs(self):
+        return 10
+
+class NumberOfOutputsTrait11():
+    def number_of_outputs(self):
+        return 11
+
+class NumberOfOutputsTrait12():
+    def number_of_outputs(self):
+        return 12
+
+class NumberOfOutputsTrait13():
+    def number_of_outputs(self):
+        return 13
+
+class NumberOfOutputsTrait14():
+    def number_of_outputs(self):
+        return 14
+
+class NumberOfOutputsTrait15():
+    def number_of_outputs(self):
+        return 15
+
+class NumberOfOutputsTrait16():
+    def number_of_outputs(self):
+        return 16
+
+class NumberOfOutputsTrait17():
+    def number_of_outputs(self):
+        return 17
+
+class NumberOfOutputsTrait18():
+    def number_of_outputs(self):
+        return 18
+
+class NumberOfOutputsTrait19():
+    def number_of_outputs(self):
+        return 19
+
+class NumberOfOutputsTrait20():
+    def number_of_outputs(self):
+        return 20
+
+class NumberOfOutputsTrait21():
+    def number_of_outputs(self):
+        return 21
+
+class NumberOfOutputsTrait22():
+    def number_of_outputs(self):
+        return 22
+# Create a class named "MatmulEpilogueFusion{number_of_inputs}_{number_of_outputs}"
+# whose bases are `base_class` plus the input/output trait classes selected
+# by the two counts. Supported counts: 1..17 inputs, 1..22 outputs.
+def get_mixin_class(base_class, number_of_inputs, number_of_outputs):
+    num_inputs_to_input_trait_class = [
+        # Index i must hold NumberOfInputsTrait{i}; slot 0 is an unused
+        # placeholder. Each trait appears exactly once so that counts >= 4
+        # are not shifted down by one.
+        None,
+        NumberOfInputsTrait1,
+        NumberOfInputsTrait2,
+        NumberOfInputsTrait3,
+        NumberOfInputsTrait4,
+        NumberOfInputsTrait5,
+        NumberOfInputsTrait6,
+        NumberOfInputsTrait7,
+        NumberOfInputsTrait8,
+        NumberOfInputsTrait9,
+        NumberOfInputsTrait10,
+        NumberOfInputsTrait11,
+        NumberOfInputsTrait12,
+        NumberOfInputsTrait13,
+        NumberOfInputsTrait14,
+        NumberOfInputsTrait15,
+        NumberOfInputsTrait16,
+        NumberOfInputsTrait17,
+    ]
+    # Index i holds NumberOfOutputsTrait{i}; slot 0 is an unused placeholder.
+    num_outputs_to_output_trait_class = [
+        None,
+        NumberOfOutputsTrait1,
+        NumberOfOutputsTrait2,
+        NumberOfOutputsTrait3,
+        NumberOfOutputsTrait4,
+        NumberOfOutputsTrait5,
+        NumberOfOutputsTrait6,
+        NumberOfOutputsTrait7,
+        NumberOfOutputsTrait8,
+        NumberOfOutputsTrait9,
+        NumberOfOutputsTrait10,
+        NumberOfOutputsTrait11,
+        NumberOfOutputsTrait12,
+        NumberOfOutputsTrait13,
+        NumberOfOutputsTrait14,
+        NumberOfOutputsTrait15,
+        NumberOfOutputsTrait16,
+        NumberOfOutputsTrait17,
+        NumberOfOutputsTrait18,
+        NumberOfOutputsTrait19,
+        NumberOfOutputsTrait20,
+        NumberOfOutputsTrait21,
+        NumberOfOutputsTrait22,
+    ]
+    return type(
+        f"MatmulEpilogueFusion{number_of_inputs}_{number_of_outputs}",
+        [
+            base_class,
+            num_inputs_to_input_trait_class[number_of_inputs],
+            num_outputs_to_output_trait_class[number_of_outputs],
+        ],
+        BuiltinSerializableAttrMap()
+    )
+
+# abstract_drr.register_drr_pass("matmul_binary_outs_fusion", nice=0)(get_mixin_class(MatmulEpilogueFusion, 3, 2))
+
+# Register one DRR pass per (inputs, outputs) combination, named
+# "matmul_binary_in{I}_out{O}_fusion", for I in 2..max_num_inputs+1 and
+# O in 1..max_num_outputs. Always returns None.
+# NOTE(review): both map() calls are used only for their side effects, so
+# this presumably relies on the runtime evaluating map eagerly - confirm.
+def register_class(base_class, max_num_inputs, max_num_outputs):
+    def register_drr_class(num_inputs, num_outputs):
+        abstract_drr.register_drr_pass(f"matmul_binary_in{num_inputs}_out{num_outputs}_fusion", nice=0)(
+            get_mixin_class(base_class, num_inputs, num_outputs)
+        )
+
+    def register_num_inputs_drr_classes(num_inputs):
+
+        def register_num_outputs_drr_classes(num_outputs):
+            return register_drr_class(num_inputs+2, num_outputs+1)
+
+        map(register_num_outputs_drr_classes, range(max_num_outputs))
+        return None
+
+    map(register_num_inputs_drr_classes, range(max_num_inputs))
+    return None
+
+register_class(base_class=MatmulEpilogueFusion, max_num_inputs=10, max_num_outputs=10)
diff --git a/tests/ap/topo_drr_pass.py b/tests/ap/topo_drr_pass.py
index 530892f..f6daebf 100644
--- a/tests/ap/topo_drr_pass.py
+++ b/tests/ap/topo_drr_pass.py
@@ -162,6 +162,32 @@ def result_pattern(self, o, t):
             []
         )
 
+class ConvertDownSpiderStoreDataOpToYieldOpPass(access_topo_drr.DrrPass):
+
+    def source_pattern(self, o, t):
+        o.data_mm_op = o.ap_native_op("pd_op.data")
+        o.data_mm_op(
+            [],
+            [t.input1]
+        )
+        o.down_spider_op = o.ap_native_op("ap_op.down_spider")
+        o.down_spider_op(
+            [t.input1],
+            [t.tmp1]
+        )
+        o.store_to_global = o.ap_native_op("ap_op.store_to_global")
+        o.store_to_global(
+            [t.input0, t.tmp1],
+            []
+        )
+
+    def result_pattern(self, o, t):
+        o.yield_op = o.ap_native_op("cf.yield")
+        o.yield_op(
+            [t.input0],
+            []
+        )
+
 class InitDownSpiderAccessTopoPass(access_topo_drr.DrrPass):
 
     def __init__(self, data_input_name):
@@ -354,7 +380,7 @@ def result_pattern(self, o, t):
         pass
 
 
-@access_topo_drr.register_drr_pass("down_spider_add", tag="default")
+@access_topo_drr.register_drr_pass("left_down_spider_add", tag="default")
 class DownSpiderAddAccessTopoPass(access_topo_drr.DrrPass):
 
     def source_pattern(self, o, t):
@@ -381,6 +407,32 @@ def result_pattern(self, o, t):
             []
         )
 
+@access_topo_drr.register_drr_pass("right_down_spider_add", tag="default")
+class DownSpiderAddAccessTopoPass(access_topo_drr.DrrPass):
+
+    def source_pattern(self, o, t):
+        o.spider = o.ap_native_op("ap_op.down_spider")
+        o.spider(
+            [t.input0],
+            [t.tmp0]
+        )
+        o.add = o.ap_native_op("pd_op.add")
+        o.add(
+            [t.tmp1, t.tmp0],
+            [t.output]
+        )
+
+    def result_pattern(self, o, t):
+        o.down_spider = o.ap_native_op("ap_op.down_spider")
+        o.down_spider(
+            [t.input0],
+            [t.output]
+        )
+        o.up_spider = o.ap_native_op("ap_op.up_spider")
+        o.up_spider(
+            [t.tmp1, t.input0],
+            []
+        )
@access_topo_drr.register_drr_pass("expand_up_spider", tag="default")
 class ExpandUpSpiderAccessTopoPass(access_topo_drr.DrrPass):

diff --git a/tests/ap/umprime.py b/tests/ap/umprime.py
new file mode 100644
index 0000000..a55476c
--- /dev/null
+++ b/tests/ap/umprime.py
@@ -0,0 +1,66 @@
+import access_topo_drr
+import pir
+
+# DRR pass "pd_op_static_relu" (tag "umprime"): matches
+# pd_op.maximum(input, pd_op.full()) where the full op's value equals zero
+# and replaces it with pd_op.relu(input).
+# NOTE(review): DataValue is used below but is not imported in this file;
+# presumably it is provided by the runtime - confirm.
+# NOTE(review): the class name PdOpReluAccessTopoPass is reused by the next
+# class in this file, which rebinds the module-level name; both classes are
+# still registered by their decorators under distinct pass names.
+@access_topo_drr.register_drr_pass("pd_op_static_relu", tag="umprime")
+class PdOpReluAccessTopoPass(access_topo_drr.DrrPass):
+    def __init__(self):
+        # Constant that constraint() compares the matched full op's value to.
+        self.zero = pir.a_f64(DataValue.float64("0"))
+
+    def source_pattern(self, o, t):
+        o.full_op = o.ap_native_op("pd_op.full")
+        o.full_op(
+            [],
+            [t.intermediate]
+        )
+        o.maximum_op = o.ap_native_op("pd_op.maximum")
+        o.maximum_op(
+            [t.input, t.intermediate],
+            [t.output]
+        )
+    # Accept the match only when the full op's value equals self.zero.
+    def constraint(self, o, t):
+        return o.full_op.value == self.zero
+
+    def result_pattern(self, o, t):
+        o.result_op = o.ap_native_op("pd_op.relu")
+        o.result_op(
+            [t.input],
+            [t.output]
+        )
+
+
+# DRR pass "pd_op_dynamic_relu" (tag "umprime"): matches a pd_op.full result
+# expanded by pd_op.expand to the shape produced by cinn_op.generate_shape
+# (taken from input0) and fed to pd_op.maximum(input1, ...); replaces the
+# chain with pd_op.relu(input1) when the full op's value equals zero.
+@access_topo_drr.register_drr_pass("pd_op_dynamic_relu", tag="umprime")
+class PdOpReluAccessTopoPass(access_topo_drr.DrrPass):
+    def __init__(self):
+        # Constant that constraint() compares the matched full op's value to.
+        self.zero = pir.a_f64(DataValue.float64("0"))
+
+    def source_pattern(self, o, t):
+        o.full_op = o.ap_native_op("pd_op.full")
+        o.full_op(
+            [],
+            [t.intermediate0]
+        )
+        o.generate_shape_op = o.ap_native_op("cinn_op.generate_shape")
+        o.generate_shape_op(
+            [t.input0],
+            [t.intermediate1]
+        )
+        o.expand_op = o.ap_native_op("pd_op.expand")
+        o.expand_op(
+            [t.intermediate0, t.intermediate1],
+            [t.intermediate2]
+        )
+        o.maximum_op = o.ap_native_op("pd_op.maximum")
+        o.maximum_op(
+            [t.input1, t.intermediate2],
+            [t.output]
+        )
+    # Accept the match only when the full op's value equals self.zero.
+    def constraint(self, o, t):
+        return o.full_op.value == self.zero
+
+    def result_pattern(self, o, t):
+        o.result_op = o.ap_native_op("pd_op.relu")
+        o.result_op(
+            [t.input1],
+            [t.output]
+        )
+