From 56478e68f8878639095db1a84cb7fc4d551928ab Mon Sep 17 00:00:00 2001 From: hxzd5568 <3257591325@qq.com> Date: Wed, 16 Apr 2025 22:02:54 +0800 Subject: [PATCH] Add a pattern for handling the matmul_epilogue_with_mm_out --- tests/ap/test_matmul_epilogue.py | 13 +- tests/ap/test_matmul_epilogue_with_mm_out.py | 545 +++++++++++++++++++ tests/ap/topo_drr_pass.py | 60 ++ 3 files changed, 610 insertions(+), 8 deletions(-) create mode 100644 tests/ap/test_matmul_epilogue_with_mm_out.py diff --git a/tests/ap/test_matmul_epilogue.py b/tests/ap/test_matmul_epilogue.py index 9c005dd..cacd1af 100644 --- a/tests/ap/test_matmul_epilogue.py +++ b/tests/ap/test_matmul_epilogue.py @@ -54,11 +54,10 @@ def constraint(self, o, t): init_pass_manager.add_pass( ir_tools.create_access_topo_drr_one_step_pass(init_down_spider) ) - outputs_name_list = map(lambda i: f"output{i}", range(self.number_of_outputs())) inputs_name_list = map(lambda i: f"input{i+2}", range(self.number_of_inputs() - 2)) if self.number_of_inputs() > 2 else [] print('inputs_name_list: ', ', '.join(inputs_name_list)) init_fake_data_for_yield_input = topo_drr_pass.FakeDataForYieldAccessTopoPass( - outputs_name_list + map(lambda i: f"output{i}", range(self.number_of_outputs())) ) init_pass_manager.add_pass( ir_tools.create_access_topo_drr_one_step_pass(init_fake_data_for_yield_input) @@ -85,7 +84,7 @@ def constraint(self, o, t): map(lambda dst_name: pass_manager.add_pass( ir_tools.create_access_topo_drr_one_step_pass( matmul_epilogue_pass.RemoveDataOpPairPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), - outputs_name_list + map(lambda i: f"output{i}", range(self.number_of_outputs())) ) pass_manager.add_pass(ir_tools.create_dce_pass()) pass_manager.run(program) @@ -102,7 +101,7 @@ def AddPass(input_name): map(AddPass, input_names) init_pass_manager.run(program) - def _insert_store_to_global(self, program, output_names): + def _replace_yeild_with_store_to_global(self, program, output_names): init_pass_manager = ir_tools.create_pass_manager() ir_pass = topo_drr_pass.FakeDataStoreToGlobalForYieldAccessTopoPass(output_names) init_pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(ir_pass)) @@ -211,9 +210,7 @@ def _replace_with_store_to_register( return mut_program def _get_program_translator(self, ctx, o, t): - outputs_name_list = map(lambda i: f"output{i}", range(self.number_of_outputs())) other_outputs_name_list = map(lambda i: f"output{i+1}", range(self.number_of_outputs()-1)) - local_outputs_name_list = map(lambda i: f"out{i}", range(self.number_of_outputs())) inputs_name_list = map(lambda i: f"input{i+2}", range(self.number_of_inputs() - 2)) if self.number_of_inputs() > 2 else [] mut_program = ir_tools.copy_fused_ops_to_program( o.trivial_op, tensor_match_ctx=t @@ -231,9 +228,9 @@ def _get_program_translator(self, ctx, o, t): mut_program, input_names=inputs_name_list ) - self._insert_store_to_global( + self._replace_yeild_with_store_to_global( mut_program, - output_names=outputs_name_list + output_names=map(lambda i: f"output{i}", range(self.number_of_outputs())) ) kernel_arg_translator = self._make_kernel_arg_translator() index_func_unique_id2index_program = self._make_index_func_unique_id2index_program( diff --git a/tests/ap/test_matmul_epilogue_with_mm_out.py b/tests/ap/test_matmul_epilogue_with_mm_out.py new file mode 100644 index 0000000..6ce179b --- /dev/null +++ b/tests/ap/test_matmul_epilogue_with_mm_out.py @@ -0,0 +1,545 @@ +import abstract_drr +import access_topo_drr +import topo_drr_pass +import umprime +import matmul_epilogue_pass +import op_convertion_drr_pass +import matmul_binary_tpl +import ir_tools +import index_program_translator_util +import op_compute_translator_util +import program_translator_util +import kernel_arg_id_util +import low_level_ir_code_gen_ctx_util +import kernel_arg_translator_util +import pir + + +class MatmulEpilogueFusion(abstract_drr.DrrPass): + def source_pattern(self, o, t): + in_num = self.number_of_inputs() + out_num = self.number_of_outputs() + o.matmul_op = o.ap_native_op("pd_op.matmul") + o.matmul_op( + [t.input0, t.input1], + [t.mm_out] + ) + o.trivial_op = o.ap_trivial_fusion_op() + o.trivial_op( + [t.mm_out, *map(lambda index: getattr(t, f"input{index+2}"), range(in_num - 2))], + map(lambda index: getattr(t, f"output{index}"), range(out_num)), + ) + + + def result_pattern(self, o, t): + in_num = self.number_of_inputs() + out_num = self.number_of_outputs() + o.fustion_op = o.ap_pattern_fusion_op(self.code_gen) + o.fustion_op( + map(lambda index: getattr(t, f"input{index}"), range(in_num)), + map(lambda index: getattr(t, f"output{index}"), range(out_num)) + ) + + def constraint(self, o, t): + program = ir_tools.copy_fused_ops_to_program(o.trivial_op, tensor_match_ctx=t) + print("before-umprime: ", program) + # umprime passes + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("umprime")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("before-access_topo_pass", program) + init_pass_manager = ir_tools.create_pass_manager() + init_down_spider = topo_drr_pass.InitDownSpiderAccessTopoPass("mm_out") + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_down_spider) + ) + inputs_name_list = map(lambda i: f"input{i+2}", range(self.number_of_inputs() - 2)) if self.number_of_inputs() > 2 else [] + print('inputs_name_list: ', ', '.join(inputs_name_list)) + init_fake_data_for_yield_input = topo_drr_pass.FakeDataForYieldAccessTopoPass( + map(lambda i: f"output{i}", range(self.number_of_outputs())) + ) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_fake_data_for_yield_input) + ) + init_pass_manager.run(program) + print("after-init-access_topo_pass", program) + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("default")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("after-apply-access_topo_pass", program) + pass_manager = ir_tools.create_pass_manager() + map(lambda dst_name: pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + matmul_epilogue_pass.RemoveDataOpPairPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), + inputs_name_list + ) + map(lambda dst_name: pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + matmul_epilogue_pass.RemoveDataOp2SumOp2DataOpPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), + inputs_name_list + ) + + map(lambda dst_name: pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + matmul_epilogue_pass.RemoveDataOpPairPass(src_data_op_name="mm_out", dst_data_op_name=dst_name))), + map(lambda i: f"output{i}", range(self.number_of_outputs())) + ) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("after-remove-input-output-access_topo_pass", program) + return program.empty() + + def _insert_load_from_global(self, program, input_names): + init_pass_manager = ir_tools.create_pass_manager() + def AddPass(input_name): + ir_pass = topo_drr_pass.InitNaiveLoadFromGlobalAccessTopoPass(input_name) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(ir_pass) + ) + map(AddPass, input_names) + init_pass_manager.run(program) + + def _replace_yeild_with_store_to_global(self, program, output_names): + init_pass_manager = ir_tools.create_pass_manager() + ir_pass = topo_drr_pass.FakeDataStoreToGlobalForYieldAccessTopoPass(output_names) + init_pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(ir_pass)) + init_pass_manager.run(program) + + def _insert_store_to_global_mm(self, program): + init_pass_manager = ir_tools.create_pass_manager() + ir_pass = topo_drr_pass.InsertStoreToGlobalForMMoutAccessTopoPass() + init_pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(ir_pass)) + init_pass_manager.run(program) + + def _make_kernel_arg_translator(self): + return matmul_binary_tpl.make_kernel_arg_translator() + + def _apply_topo_access_passes(self, mut_program, anchor_data_op_name): + init_pass_manager = ir_tools.create_pass_manager() + init_down_spider = topo_drr_pass.InitDownSpiderAccessTopoPass(anchor_data_op_name) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_down_spider) + ) + init_pass_manager.run(mut_program) + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("default")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + + def _simplify_index_program(self, mut_program): + pass_manager = ir_tools.create_pass_manager() + drr_pass = topo_drr_pass.ConvertUpSpiderStoreDataOpToYieldOpPass() + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + drr_pass = topo_drr_pass.ConvertDownSpiderStoreDataOpToYieldOpPass() + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + return mut_program + + def _make_index_func_unique_id2index_program( + self, compute_program, anchor_data_op_name, input_names, output_names): + full_index_program = compute_program.clone() + self._apply_topo_access_passes(full_index_program, anchor_data_op_name) + print('full_index_program: ', full_index_program) + def MatchAndCopyInputIndex(dst_input_name): + pass_manager = ir_tools.create_pass_manager() + removed_programs = MutableList() + rm_elementwise_drr_pass = matmul_epilogue_pass.RemoveElementInputIndexPass( + src_data_op_name=anchor_data_op_name, + dst_load_from_global_op_name=dst_input_name + ) + rm_elementwise_ir_pass = ir_tools.create_access_topo_drr_one_step_pass( + rm_elementwise_drr_pass, + matched_pattern_mut_list=removed_programs + ) + pass_manager.add_pass(rm_elementwise_ir_pass) + rm_broadcast_drr_pass = matmul_epilogue_pass.RemoveBroadcastInputIndexPass( + src_data_op_name=anchor_data_op_name, + dst_load_from_global_op_name=dst_input_name + ) + rm_broadcast_ir_pass = ir_tools.create_access_topo_drr_one_step_pass( + rm_broadcast_drr_pass, + matched_pattern_mut_list=removed_programs + ) + pass_manager.add_pass(rm_broadcast_ir_pass) + pass_manager.run(full_index_program) + def Converter(program): + return [dst_input_name, self._simplify_index_program(program)] + return map(Converter, removed_programs) + input_and_index_programs = flat_map(MatchAndCopyInputIndex, input_names) + def MatchAndCopyOutputIndex(dst_output_name): + print('full_index_program output: ', full_index_program) + pass_manager = ir_tools.create_pass_manager() + removed_programs = MutableList() + drr_pass = matmul_epilogue_pass.RemoveOutputIndexPass( + src_data_op_name=anchor_data_op_name, + dst_store_to_global_op_name=dst_output_name + ) + ir_pass = ir_tools.create_access_topo_drr_one_step_pass( + drr_pass, + matched_pattern_mut_list=removed_programs + ) + pass_manager.add_pass(ir_pass) + pass_manager.run(full_index_program) + + def Converter(program): + return [dst_output_name, self._simplify_index_program(program)] + print('len removed of output: ', len(removed_programs)) + return map(Converter, removed_programs) + output_and_index_programs = flat_map(MatchAndCopyOutputIndex, output_names) + return OrderedDict([*input_and_index_programs, *output_and_index_programs]) + + def _replace_with_load_from_register( + self, mut_program, load_ir_value_name, register_var_name): + pass_manager = ir_tools.create_pass_manager() + drr_pass = topo_drr_pass.ReplaceWithLoadFromRegisterPass( + name=load_ir_value_name, + register_var_name=register_var_name + ) + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + return mut_program + + def _replace_with_store_to_register( + self, mut_program, store_ir_value_name, register_var_name): + pass_manager = ir_tools.create_pass_manager() + drr_pass = topo_drr_pass.ReplaceWithStoreToRegisterPass( + name=store_ir_value_name, + register_var_name=register_var_name + ) + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + return mut_program + + def _get_program_translator(self, ctx, o, t): + other_outputs_name_list = [*map(lambda i: f"output{i+1}", range(self.number_of_outputs()-1)), "mm_out"] + inputs_name_list = map(lambda i: f"input{i+2}", range(self.number_of_inputs() - 2)) if self.number_of_inputs() > 2 else [] + mut_program = ir_tools.copy_fused_ops_to_program( + o.trivial_op, tensor_match_ctx=t + ) + print("before-umprime: ", mut_program) + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("umprime")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + self._insert_load_from_global( + mut_program, + input_names=["mm_out"] + ) + self._insert_load_from_global( + mut_program, + input_names=inputs_name_list + ) + self._replace_yeild_with_store_to_global( + mut_program, + output_names=map(lambda i: f"output{i}", range(self.number_of_outputs())) + ) + self._insert_store_to_global_mm( + mut_program, + ) + kernel_arg_translator = self._make_kernel_arg_translator() + index_func_unique_id2index_program = self._make_index_func_unique_id2index_program( + mut_program, + anchor_data_op_name="mm_out", + input_names=inputs_name_list, + output_names=other_outputs_name_list, + ) + print("index_func_unique_id2index_program:\n", index_func_unique_id2index_program) + index_program_translator_map = index_program_translator_util.IndexProgramTranslatorMap( + index_func_unique_id2index_program=index_func_unique_id2index_program, + kernel_arg_translator=kernel_arg_translator, + anchor_iter_var_names=matmul_binary_tpl.get_anchor_iter_var_names() + ) + self._replace_with_load_from_register( + mut_program, + load_ir_value_name="mm_out", + register_var_name="x" + ) + self._replace_with_store_to_register(mut_program, "output0", "out") + print("mut_program:", mut_program) + op_compute_translator_maker = op_compute_translator_util.OpComputeTranslatorFactory() + print('berfore translator') + program_translator = program_translator_util.ProgramTranslator( + program_property=mut_program.copy_to_const_program_data(), + kernel_arg_translator=kernel_arg_translator, + index_program_translator_map=index_program_translator_map, + op_translator_maker=op_compute_translator_maker + ) + print('after translator') + + return program_translator + + def code_gen(self, ctx, o, t): + program_translator = self._get_program_translator(ctx, o, t) + mut_kernel_arg_id_registry = kernel_arg_id_util.KernelArgIdNameRegistry( + code_gen_ctx=ctx, + tensor_match_ctx=t, + name_prefix="" + ) + print('after registry') + + template_module = matmul_binary_tpl.MatmulBinaryTemplate( + program_translator=program_translator, + mut_kernel_arg_id_registry=mut_kernel_arg_id_registry, + ) + print('after module') + + def get_symbolic_shape_args_list(sym_dim): + return ctx.dim_expr_kernel_arg_id(sym_dim) + input0_shape_kargs = map(get_symbolic_shape_args_list, t.input0.symbolic_shape_to_list()) + input1_shape_kargs = map(get_symbolic_shape_args_list, t.input1.symbolic_shape_to_list()) + print('before compile') + return template_module.compile( + input0_karg=ctx.in_tensor_data_ptr_kernel_arg_id(t.input0), + input1_karg=ctx.in_tensor_data_ptr_kernel_arg_id(t.input1), + output_karg=ctx.out_tensor_data_ptr_kernel_arg_id(t.output0), + input0_shape_kargs=input0_shape_kargs, + input1_shape_kargs=input1_shape_kargs, + ) + +class NumberOfInputsTrait0(): + def number_of_inputs(self): + return 0 + +class NumberOfInputsTrait1(): + def number_of_inputs(self): + return 1 + +class NumberOfInputsTrait2(): + def number_of_inputs(self): + return 2 + +class NumberOfInputsTrait3(): + def number_of_inputs(self): + return 3 + +class NumberOfInputsTrait4(): + def number_of_inputs(self): + return 4 + +class NumberOfInputsTrait5(): + def number_of_inputs(self): + return 5 + +class NumberOfInputsTrait6(): + def number_of_inputs(self): + return 6 + +class NumberOfInputsTrait7(): + def number_of_inputs(self): + return 7 + +class NumberOfInputsTrait8(): + def number_of_inputs(self): + return 8 + +class NumberOfInputsTrait9(): + def number_of_inputs(self): + return 9 + +class NumberOfInputsTrait10(): + def number_of_inputs(self): + return 10 + +class NumberOfInputsTrait11(): + def number_of_inputs(self): + return 11 + +class NumberOfInputsTrait12(): + def number_of_inputs(self): + return 12 + +class NumberOfInputsTrait13(): + def number_of_inputs(self): + return 13 + +class NumberOfInputsTrait14(): + def number_of_inputs(self): + return 14 + +class NumberOfInputsTrait15(): + def number_of_inputs(self): + return 15 + +class NumberOfInputsTrait16(): + def number_of_inputs(self): + return 16 + +class NumberOfInputsTrait17(): + def number_of_inputs(self): + return 17 + + +class NumberOfOutputsTrait0(): + def number_of_outputs(self): + return 0 + +class NumberOfOutputsTrait1(): + def number_of_outputs(self): + return 1 + +class NumberOfOutputsTrait2(): + def number_of_outputs(self): + return 2 + +class NumberOfOutputsTrait3(): + def number_of_outputs(self): + return 3 + +class NumberOfOutputsTrait4(): + def number_of_outputs(self): + return 4 + +class NumberOfOutputsTrait5(): + def number_of_outputs(self): + return 5 + +class NumberOfOutputsTrait6(): + def number_of_outputs(self): + return 6 + +class NumberOfOutputsTrait7(): + def number_of_outputs(self): + return 7 + +class NumberOfOutputsTrait8(): + def number_of_outputs(self): + return 8 + +class NumberOfOutputsTrait9(): + def number_of_outputs(self): + return 9 + +class NumberOfOutputsTrait10(): + def number_of_outputs(self): + return 10 + +class NumberOfOutputsTrait11(): + def number_of_outputs(self): + return 11 + +class NumberOfOutputsTrait12(): + def number_of_outputs(self): + return 12 + +class NumberOfOutputsTrait13(): + def number_of_outputs(self): + return 13 + +class NumberOfOutputsTrait14(): + def number_of_outputs(self): + return 14 + +class NumberOfOutputsTrait15(): + def number_of_outputs(self): + return 15 + +class NumberOfOutputsTrait16(): + def number_of_outputs(self): + return 16 + +class NumberOfOutputsTrait17(): + def number_of_outputs(self): + return 17 + +class NumberOfOutputsTrait18(): + def number_of_outputs(self): + return 18 + +class NumberOfOutputsTrait19(): + def number_of_outputs(self): + return 19 + +class NumberOfOutputsTrait20(): + def number_of_outputs(self): + return 20 + +class NumberOfOutputsTrait21(): + def number_of_outputs(self): + return 21 + +class NumberOfOutputsTrait22(): + def number_of_outputs(self): + return 22 +def get_mixin_class(base_class, number_of_inputs, number_of_outputs): + num_inputs_to_input_trait_class = [ + None, + NumberOfInputsTrait1, + NumberOfInputsTrait2, + NumberOfInputsTrait3, + NumberOfInputsTrait3, + NumberOfInputsTrait4, + NumberOfInputsTrait5, + NumberOfInputsTrait6, + NumberOfInputsTrait7, + NumberOfInputsTrait8, + NumberOfInputsTrait9, + NumberOfInputsTrait10, + NumberOfInputsTrait11, + NumberOfInputsTrait12, + NumberOfInputsTrait13, + NumberOfInputsTrait14, + NumberOfInputsTrait15, + NumberOfInputsTrait16, + NumberOfInputsTrait17, + ] + num_outputs_to_output_trait_class = [ + None, + NumberOfOutputsTrait1, + NumberOfOutputsTrait2, + NumberOfOutputsTrait3, + NumberOfOutputsTrait4, + NumberOfOutputsTrait5, + NumberOfOutputsTrait6, + NumberOfOutputsTrait7, + NumberOfOutputsTrait8, + NumberOfOutputsTrait9, + NumberOfOutputsTrait10, + NumberOfOutputsTrait11, + NumberOfOutputsTrait12, + NumberOfOutputsTrait13, + NumberOfOutputsTrait14, + NumberOfOutputsTrait15, + NumberOfOutputsTrait16, + NumberOfOutputsTrait17, + NumberOfOutputsTrait18, + NumberOfOutputsTrait19, + NumberOfOutputsTrait20, + NumberOfOutputsTrait21, + NumberOfOutputsTrait22, + ] + return type( + f"MatmulEpilogueFusion{number_of_inputs}_{number_of_outputs}", + [ + base_class, + num_inputs_to_input_trait_class[number_of_inputs], + num_outputs_to_output_trait_class[number_of_outputs], + ], + BuiltinSerializableAttrMap() + ) + +# abstract_drr.register_drr_pass("matmul_binary_outs_fusion", nice=0)(get_mixin_class(MatmulEpilogueFusion, 3, 2)) + +def register_class(base_class, max_num_inputs, max_num_outputs): + def register_drr_class(num_inputs, num_outputs): + abstract_drr.register_drr_pass(f"matmul_binary_in{num_inputs}_out{num_outputs}_fusion", nice=0)( + get_mixin_class(base_class, num_inputs, num_outputs) + ) + + def register_num_inputs_drr_classes(num_inputs): + + def register_num_outputs_drr_classes(num_outputs): + return register_drr_class(num_inputs+2, num_outputs+1) + + map(register_num_outputs_drr_classes, range(max_num_outputs)) + print('done max outputs') + return None + + map(register_num_inputs_drr_classes, range(max_num_inputs)) + return None + +register_class(base_class=MatmulEpilogueFusion, max_num_inputs=10, max_num_outputs=10) diff --git a/tests/ap/topo_drr_pass.py b/tests/ap/topo_drr_pass.py index 9d7a768..878a581 100644 --- a/tests/ap/topo_drr_pass.py +++ b/tests/ap/topo_drr_pass.py @@ -136,6 +136,66 @@ def store_to_global_op_for_output(self, o, t, i): [] ) +class InsertStoreToGlobalForMMoutAccessTopoPass(access_topo_drr.DrrPass): + + def __init__(self): + self.undefined_place = pir.a_place(pir.UndefinedPlace()) + self.data_input_name_attr = pir.a_str('mm_out') + + def source_pattern(self, o, t): + o.load_from_global = o.ap_native_op("ap_op.load_from_global") + o.load_from_global( + [t.mm_out], + [t.mm_local_out] + ) + + def constraint(self, o, t): + return o.load_from_global.index_func_unique_id == self.data_input_name_attr + + def result_pattern(self, o, t): + t.declare_internal_native_ir_value("input") + self.result_pattern_data_op(o, t) + store_to_global_op_name = "store_to_global_mm_op" + setattr(o, store_to_global_op_name, o.ap_native_op("ap_op.store_to_global")) + store_to_global_op = getattr(o, store_to_global_op_name) + store_to_global_op.index_func_unique_id = lambda o, t: self.data_input_name_attr + store_to_global_op( + [t.mm_out, t.mm_local_out], + [] + ) + + def result_pattern_data_op(self, o, t): + t.declare_internal_native_ir_value(f"mm_out") + data_op_unique_name = f"data_op_for_mm_output" + setattr(o, data_op_unique_name, o.ap_native_op("pd_op.data")) + data_op = getattr(o, data_op_unique_name) + data_op.name = lambda o, t: self.get_name(o, t) + data_op.shape = lambda o, t: self.get_shape(o, t) + data_op.dtype = lambda o, t: self.get_dtype(o, t) + data_op.place = lambda o, t: self.get_place(o, t) + data_op( + [], + [t.mm_out] + ) + + def get_name(self, o, t): + return pir.a_str("mm_glb_out") + + def get_shape(self, o, t): + ir_tensor = getattr(t, "mm_out") + def GetDims(dtype, dims, data_layout): + return dims + return pir.a_intarray(ir_tensor.type.match(t_dtensor=GetDims)) + + def get_dtype(self, o, t): + ir_tensor = getattr(t, "mm_out") + def GetDtype(dtype, dims, data_layout): + return dtype + return pir.a_dtype(ir_tensor.type.match(t_dtensor=GetDtype)) + + def get_place(self, o, t): + return self.undefined_place + class ConvertUpSpiderStoreDataOpToYieldOpPass(access_topo_drr.DrrPass): def source_pattern(self, o, t):