From 594f6ea4fcdd2c697ba85ba419c0a00b4fdfed85 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 20 Feb 2025 13:49:18 +0800 Subject: [PATCH 1/5] Add matmul_unary_fusion. --- tests/ap/matmul_binary_tpl.py.json | 3387 ----------------- ...l_binary_tpl.py => matmul_variadic_tpl.py} | 16 +- tests/ap/paddle-tests/test_matmul_unary.py | 83 + tests/ap/test_matmul_binary.py | 11 +- tests/ap/test_matmul_unary.py | 188 + 5 files changed, 285 insertions(+), 3400 deletions(-) delete mode 100644 tests/ap/matmul_binary_tpl.py.json rename tests/ap/{matmul_binary_tpl.py => matmul_variadic_tpl.py} (95%) create mode 100644 tests/ap/paddle-tests/test_matmul_unary.py create mode 100644 tests/ap/test_matmul_unary.py diff --git a/tests/ap/matmul_binary_tpl.py.json b/tests/ap/matmul_binary_tpl.py.json deleted file mode 100644 index e7238fb..0000000 --- a/tests/ap/matmul_binary_tpl.py.json +++ /dev/null @@ -1,3387 +0,0 @@ -[ - "__builtin_let__", - [ - [ - "ap_tpl_codegen", - [ - "import", - { - "str": "ap_tpl_codegen" - } - ] - ], - [ - "low_level_ir_code_gen_ctx_util", - [ - "import", - { - "str": "low_level_ir_code_gen_ctx_util" - } - ] - ], - [ - "kernel_arg_translator_util", - [ - "import", - { - "str": "kernel_arg_translator_util" - } - ] - ], - [ - "make_kernel_arg_translator", - [ - "__builtin_identity__", - [ - "lambda", - [], - [ - "__builtin_let__", - [ - [ - "___0", - [ - "__builtin_getattr__", - "kernel_arg_translator_util", - { - "str": "KernelArgTranslator" - } - ] - ], - [ - "___1", - [ - "__builtin_list__", - { - "str": "param_struct_name" - }, - { - "str": "args" - } - ] - ], - [ - "___2", - [ - "__builtin_list__", - "___1" - ] - ], - [ - "___3", - [ - "__builtin_list__" - ] - ], - [ - "___4", - [ - "__builtin_PackedArgs__", - "___3", - "___2" - ] - ], - [ - "___5", - [ - "___0", - "___4" - ] - ], - [ - "___6", - [ - "__builtin_return__", - "___5" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "get_anchor_iter_var_names", - [ - "__builtin_identity__", - [ - "lambda", - [], - [ - "__builtin_let__", - [ - [ - "___7", - [ - "__builtin_list__", - { - "str": "coord.j" - }, - { - "str": "coord.k" - } - ] - ], - [ - "___8", - [ - "__builtin_return__", - "___7" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___9", - [ - "__builtin_list__" - ] - ], - [ - "__init__", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "program_translator", - "mut_kernel_arg_id_registry" - ], - [ - "__builtin_let__", - [ - [ - "___10", - [ - "__builtin_setattr__", - "self", - { - "str": "program_translator" - } - ] - ], - [ - "___11", - [ - "___10", - { - "str": "program_translator" - }, - "program_translator" - ] - ], - [ - "___12", - [ - "__builtin_setattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___13", - [ - "___12", - { - "str": "mut_kernel_arg_id_registry" - }, - "mut_kernel_arg_id_registry" - ] - ], - [ - "___14", - [ - "make_kernel_arg_translator" - ] - ], - [ - "___15", - [ - "__builtin_setattr__", - "self", - { - "str": "kernel_arg_translator" - } - ] - ], - [ - "___16", - [ - "___15", - { - "str": "kernel_arg_translator" - }, - "___14" - ] - ], - [ - "___17", - [ - "__builtin_getattr__", - "PointerType", - { - "str": "const_float_ptr" - } - ] - ], - [ - "___18", - [ - "__builtin_list__", - "___17", - { - "str": "const float*" - } - ] - ], - [ - "___19", - [ - "__builtin_getattr__", - "PointerType", - { - "str": "const_float16_ptr" - } - ] - ], - [ - "___20", - [ - "__builtin_list__", - "___19", - { - "str": "const half*" - } - ] - ], - [ - "___21", - [ - "__builtin_getattr__", - "PointerType", - { - "str": "float_ptr" - } - ] - ], - [ - "___22", - [ - "__builtin_list__", - "___21", - { - "str": "float*" - } - ] - ], - [ - "___23", - [ - "__builtin_getattr__", - "PointerType", - { - "str": "float16_ptr" - } - ] - ], - [ - "___24", - [ - "__builtin_list__", - "___23", - { - "str": "half*" - } - ] - ], - [ - "___25", - [ - "__builtin_getattr__", - "DataType", - { - "str": "float" - } - ] - ], - [ - "___26", - [ - "__builtin_list__", - "___25", - { - "str": "float" - } - ] - ], - [ - "___27", - [ - "__builtin_getattr__", - "DataType", - { - "str": "float16" - } - ] - ], - [ - "___28", - [ - "__builtin_list__", - "___27", - { - "str": "half" - } - ] - ], - [ - "___29", - [ - "__builtin_getattr__", - "DataType", - { - "str": "int64_t" - } - ] - ], - [ - "___30", - [ - "__builtin_list__", - "___29", - { - "str": "int64_t" - } - ] - ], - [ - "___31", - [ - "__builtin_list__", - "___18", - "___20", - "___22", - "___24", - "___26", - "___28", - "___30" - ] - ], - [ - "___32", - [ - "OrderedDict", - "___31" - ] - ], - [ - "___33", - [ - "__builtin_setattr__", - "self", - { - "str": "dtype2type_name" - } - ] - ], - [ - "___34", - [ - "___33", - { - "str": "dtype2type_name" - }, - "___32" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___35", - [ - "__builtin_getattr__", - "__init__", - { - "str": "__function__" - } - ] - ], - [ - "___36", - [ - "__builtin_list__", - { - "str": "__init__" - }, - "___35" - ] - ], - [ - "_register_name", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "pair" - ], - [ - "__builtin_let__", - [ - [ - "___37", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "registry", - [ - "__builtin_identity__", - "___37" - ] - ], - [ - "___38", - [ - "__builtin_getattr__", - "registry", - { - "str": "get_or_create_kernel_arg_id_manul_var_name" - } - ] - ], - [ - "___39", - [ - "__builtin_getitem__", - "pair", - 0 - ] - ], - [ - "___40", - [ - "__builtin_list__", - { - "str": "kernel_arg_id" - }, - "___39" - ] - ], - [ - "___41", - [ - "__builtin_getitem__", - "pair", - 1 - ] - ], - [ - "___42", - [ - "__builtin_list__", - { - "str": "cpp_var_name" - }, - "___41" - ] - ], - [ - "___43", - [ - "__builtin_list__", - "___40", - "___42" - ] - ], - [ - "___44", - [ - "__builtin_list__" - ] - ], - [ - "___45", - [ - "__builtin_PackedArgs__", - "___44", - "___43" - ] - ], - [ - "___46", - [ - "___38", - "___45" - ] - ], - [ - "___47", - [ - "__builtin_identity__", - "___46" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___48", - [ - "__builtin_getattr__", - "_register_name", - { - "str": "__function__" - } - ] - ], - [ - "___49", - [ - "__builtin_list__", - { - "str": "_register_name" - }, - "___48" - ] - ], - [ - "compile", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "input_karg", - "weight_karg", - "output_karg", - "m_karg", - "n_karg", - "k_karg" - ], - [ - "__builtin_let__", - [ - [ - "___50", - [ - "__builtin_getattr__", - "self", - { - "str": "_register_name" - } - ] - ], - [ - "___51", - [ - "__builtin_list__", - "input_karg", - { - "str": "input" - } - ] - ], - [ - "___52", - [ - "__builtin_list__", - "weight_karg", - { - "str": "weight" - } - ] - ], - [ - "___53", - [ - "__builtin_list__", - "output_karg", - { - "str": "output" - } - ] - ], - [ - "___54", - [ - "__builtin_list__", - "m_karg", - { - "str": "m" - } - ] - ], - [ - "___55", - [ - "__builtin_list__", - "n_karg", - { - "str": "n" - } - ] - ], - [ - "___56", - [ - "__builtin_list__", - "k_karg", - { - "str": "k" - } - ] - ], - [ - "___57", - [ - "__builtin_list__", - "___51", - "___52", - "___53", - "___54", - "___55", - "___56" - ] - ], - [ - "___58", - [ - "map", - "___50", - "___57" - ] - ], - [ - "___59", - [ - "__builtin_identity__", - "___58" - ] - ], - [ - "___60", - [ - "__builtin_getattr__", - "low_level_ir_code_gen_ctx_util", - { - "str": "CudaLikeIrCodeGenCtx" - } - ] - ], - [ - "___61", - [ - "___60" - ] - ], - [ - "mut_lir_code_gen_ctx", - [ - "__builtin_identity__", - "___61" - ] - ], - [ - "___63", - [ - "__builtin_getattr__", - "self", - { - "str": "program_translator" - } - ] - ], - [ - "___62", - [ - "__builtin_getattr__", - "___63", - { - "str": "translate" - } - ] - ], - [ - "___64", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___65", - [ - "__builtin_list__", - { - "str": "mut_kernel_arg_id_registry" - }, - "___64" - ] - ], - [ - "___66", - [ - "__builtin_list__", - { - "str": "mut_lir_code_gen_ctx" - }, - "mut_lir_code_gen_ctx" - ] - ], - [ - "___67", - [ - "__builtin_list__", - "___65", - "___66" - ] - ], - [ - "___68", - [ - "__builtin_list__" - ] - ], - [ - "___69", - [ - "__builtin_PackedArgs__", - "___68", - "___67" - ] - ], - [ - "___70", - [ - "___62", - "___69" - ] - ], - [ - "___71", - [ - "__builtin_identity__", - "___70" - ] - ], - [ - "___72", - [ - "__builtin_getattr__", - "mut_lir_code_gen_ctx", - { - "str": "get_stmts_joined_str" - } - ] - ], - [ - "___73", - [ - "___72" - ] - ], - [ - "trivial_code_str", - [ - "__builtin_identity__", - "___73" - ] - ], - [ - "___74", - [ - "print", - { - "str": "matmul_binary_epilogue_code:\n" - }, - "trivial_code_str" - ] - ], - [ - "___75", - [ - "__builtin_identity__", - "___74" - ] - ], - [ - "___76", - [ - "__builtin_getattr__", - "self", - { - "str": "make_project" - } - ] - ], - [ - "___77", - [ - "___76", - "trivial_code_str", - "input_karg", - "weight_karg", - "output_karg", - "m_karg", - "n_karg", - "k_karg" - ] - ], - [ - "project_module", - [ - "__builtin_identity__", - "___77" - ] - ], - [ - "___78", - [ - "__builtin_list__", - { - "str": "module" - }, - "project_module" - ] - ], - [ - "___79", - [ - "__builtin_list__", - { - "str": "kernel_dispatch_func" - }, - "KernelDispatch" - ] - ], - [ - "___80", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_runtime_getters" - } - ] - ], - [ - "___81", - [ - "___80" - ] - ], - [ - "___82", - [ - "__builtin_list__", - { - "str": "kernel_args_getters" - }, - "___81" - ] - ], - [ - "___83", - [ - "__builtin_list__", - "___82" - ] - ], - [ - "___84", - [ - "__builtin_list__" - ] - ], - [ - "___85", - [ - "__builtin_PackedArgs__", - "___84", - "___83" - ] - ], - [ - "___86", - [ - "BuiltinSerializableAttrMap", - "___85" - ] - ], - [ - "___87", - [ - "__builtin_list__", - { - "str": "kernel_dispatch_const_data" - }, - "___86" - ] - ], - [ - "___88", - [ - "__builtin_list__", - "___78", - "___79", - "___87" - ] - ], - [ - "___89", - [ - "__builtin_list__" - ] - ], - [ - "___90", - [ - "__builtin_PackedArgs__", - "___89", - "___88" - ] - ], - [ - "___91", - [ - "CodeGenResult", - "___90" - ] - ], - [ - "___92", - [ - "__builtin_return__", - "___91" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___93", - [ - "__builtin_getattr__", - "compile", - { - "str": "__function__" - } - ] - ], - [ - "___94", - [ - "__builtin_list__", - { - "str": "compile" - }, - "___93" - ] - ], - [ - "get_kernel_arg_runtime_getters", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self" - ], - [ - "__builtin_let__", - [ - [ - "___97", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___96", - [ - "__builtin_getattr__", - "___97", - { - "str": "all_kernel_arg_id2unique_name" - } - ] - ], - [ - "___95", - [ - "__builtin_getattr__", - "___96", - { - "str": "items" - } - ] - ], - [ - "___98", - [ - "___95" - ] - ], - [ - "all_kernel_arg_id_and_unique_names", - [ - "__builtin_identity__", - "___98" - ] - ], - [ - "___101", - [ - "map", - [ - "lambda", - [ - "pair" - ], - [ - "__builtin_let__", - [ - [ - "___100", - [ - "__builtin_getitem__", - "pair", - 0 - ] - ], - [ - "___99", - [ - "__builtin_getattr__", - "___100", - { - "str": "runtime_getter" - } - ] - ] - ], - [ - "__builtin_identity__", - "___99" - ] - ] - ], - "all_kernel_arg_id_and_unique_names" - ] - ], - [ - "___102", - [ - "__builtin_return__", - "___101" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___103", - [ - "__builtin_getattr__", - "get_kernel_arg_runtime_getters", - { - "str": "__function__" - } - ] - ], - [ - "___104", - [ - "__builtin_list__", - { - "str": "get_kernel_arg_runtime_getters" - }, - "___103" - ] - ], - [ - "get_kernel_arg_types", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self" - ], - [ - "__builtin_let__", - [ - [ - "___107", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___106", - [ - "__builtin_getattr__", - "___107", - { - "str": "all_kernel_arg_id2unique_name" - } - ] - ], - [ - "___105", - [ - "__builtin_getattr__", - "___106", - { - "str": "items" - } - ] - ], - [ - "___108", - [ - "___105" - ] - ], - [ - "all_kernel_arg_id_and_unique_names", - [ - "__builtin_identity__", - "___108" - ] - ], - [ - "___111", - [ - "map", - [ - "lambda", - [ - "pair" - ], - [ - "__builtin_let__", - [ - [ - "___110", - [ - "__builtin_getitem__", - "pair", - 0 - ] - ], - [ - "___109", - [ - "__builtin_getattr__", - "___110", - { - "str": "type" - } - ] - ] - ], - [ - "__builtin_identity__", - "___109" - ] - ] - ], - "all_kernel_arg_id_and_unique_names" - ] - ], - [ - "___112", - [ - "__builtin_return__", - "___111" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___113", - [ - "__builtin_getattr__", - "get_kernel_arg_types", - { - "str": "__function__" - } - ] - ], - [ - "___114", - [ - "__builtin_list__", - { - "str": "get_kernel_arg_types" - }, - "___113" - ] - ], - [ - "get_kernel_arg_id_var_name", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "kernel_arg_id" - ], - [ - "__builtin_let__", - [ - [ - "___116", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___115", - [ - "__builtin_getattr__", - "___116", - { - "str": "all_kernel_arg_id2unique_name" - } - ] - ], - [ - "all_kernel_arg_id2unique_name", - [ - "__builtin_identity__", - "___115" - ] - ], - [ - "___117", - [ - "__builtin_getitem__", - "all_kernel_arg_id2unique_name", - "kernel_arg_id" - ] - ], - [ - "___118", - [ - "__builtin_return__", - "___117" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___119", - [ - "__builtin_getattr__", - "get_kernel_arg_id_var_name", - { - "str": "__function__" - } - ] - ], - [ - "___120", - [ - "__builtin_list__", - { - "str": "get_kernel_arg_id_var_name" - }, - "___119" - ] - ], - [ - "get_kernel_arg_list_str", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self" - ], - [ - "__builtin_let__", - [ - [ - "declare_epilogue_arguments_field", - [ - "__builtin_identity__", - [ - "lambda", - [ - "pair" - ], - [ - "__builtin_let__", - [ - [ - "___121", - [ - "__builtin_getitem__", - "pair", - 0 - ] - ], - [ - "kernel_arg_id", - [ - "__builtin_identity__", - "___121" - ] - ], - [ - "___122", - [ - "__builtin_getitem__", - "pair", - 1 - ] - ], - [ - "var_name", - [ - "__builtin_identity__", - "___122" - ] - ], - [ - "___124", - [ - "__builtin_getattr__", - "self", - { - "str": "kernel_arg_translator" - } - ] - ], - [ - "___123", - [ - "__builtin_getattr__", - "___124", - { - "str": "get_param_struct_field_name" - } - ] - ], - [ - "___125", - [ - "___123", - "var_name" - ] - ], - [ - "field_name", - [ - "__builtin_identity__", - "___125" - ] - ], - [ - "___126", - [ - "__builtin_getattr__", - "kernel_arg_id", - { - "str": "type" - } - ] - ], - [ - "dtype", - [ - "__builtin_identity__", - "___126" - ] - ], - [ - "___127", - [ - "__builtin_getattr__", - "self", - { - "str": "dtype2type_name" - } - ] - ], - [ - "___128", - [ - "__builtin_getitem__", - "___127", - "dtype" - ] - ], - [ - "type_name", - [ - "__builtin_identity__", - "___128" - ] - ], - [ - "___129", - [ - "__builtin_ToString__", - "type_name" - ] - ], - [ - "___130", - [ - "__builtin_ToString__", - { - "str": " " - } - ] - ], - [ - "___131", - [ - "__builtin_Add__", - "___129", - "___130" - ] - ], - [ - "___132", - [ - "__builtin_ToString__", - "field_name" - ] - ], - [ - "___133", - [ - "__builtin_Add__", - "___131", - "___132" - ] - ], - [ - "___134", - [ - "__builtin_return__", - "___133" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___137", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___136", - [ - "__builtin_getattr__", - "___137", - { - "str": "all_kernel_arg_id2unique_name" - } - ] - ], - [ - "___135", - [ - "__builtin_getattr__", - "___136", - { - "str": "items" - } - ] - ], - [ - "___138", - [ - "___135" - ] - ], - [ - "all_kernel_arg_id_and_names", - [ - "__builtin_identity__", - "___138" - ] - ], - [ - "___139", - [ - "__builtin_getattr__", - { - "str": ", " - }, - { - "str": "join" - } - ] - ], - [ - "___140", - [ - "map", - "declare_epilogue_arguments_field", - "all_kernel_arg_id_and_names" - ] - ], - [ - "___141", - [ - "___139", - "___140" - ] - ], - [ - "___142", - [ - "__builtin_return__", - "___141" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___143", - [ - "__builtin_getattr__", - "get_kernel_arg_list_str", - { - "str": "__function__" - } - ] - ], - [ - "___144", - [ - "__builtin_list__", - { - "str": "get_kernel_arg_list_str" - }, - "___143" - ] - ], - [ - "get_epilogue_arguments_fields_str", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self" - ], - [ - "__builtin_let__", - [ - [ - "declare_epilogue_arguments_field", - [ - "__builtin_identity__", - [ - "lambda", - [ - "pair" - ], - [ - "__builtin_let__", - [ - [ - "___145", - [ - "__builtin_getitem__", - "pair", - 0 - ] - ], - [ - "kernel_arg_id", - [ - "__builtin_identity__", - "___145" - ] - ], - [ - "___146", - [ - "__builtin_getitem__", - "pair", - 1 - ] - ], - [ - "var_name", - [ - "__builtin_identity__", - "___146" - ] - ], - [ - "___148", - [ - "__builtin_getattr__", - "self", - { - "str": "kernel_arg_translator" - } - ] - ], - [ - "___147", - [ - "__builtin_getattr__", - "___148", - { - "str": "get_param_struct_field_name" - } - ] - ], - [ - "___149", - [ - "___147", - "var_name" - ] - ], - [ - "field_name", - [ - "__builtin_identity__", - "___149" - ] - ], - [ - "___150", - [ - "__builtin_getattr__", - "kernel_arg_id", - { - "str": "type" - } - ] - ], - [ - "dtype", - [ - "__builtin_identity__", - "___150" - ] - ], - [ - "___151", - [ - "__builtin_getattr__", - "self", - { - "str": "dtype2type_name" - } - ] - ], - [ - "___152", - [ - "__builtin_getitem__", - "___151", - "dtype" - ] - ], - [ - "type_name", - [ - "__builtin_identity__", - "___152" - ] - ], - [ - "___153", - [ - "__builtin_ToString__", - { - "str": " " - } - ] - ], - [ - "___154", - [ - "__builtin_ToString__", - "type_name" - ] - ], - [ - "___155", - [ - "__builtin_Add__", - "___153", - "___154" - ] - ], - [ - "___156", - [ - "__builtin_ToString__", - { - "str": " " - } - ] - ], - [ - "___157", - [ - "__builtin_Add__", - "___155", - "___156" - ] - ], - [ - "___158", - [ - "__builtin_ToString__", - "field_name" - ] - ], - [ - "___159", - [ - "__builtin_Add__", - "___157", - "___158" - ] - ], - [ - "___160", - [ - "__builtin_ToString__", - { - "str": ";" - } - ] - ], - [ - "___161", - [ - "__builtin_Add__", - "___159", - "___160" - ] - ], - [ - "___162", - [ - "__builtin_return__", - "___161" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___165", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___164", - [ - "__builtin_getattr__", - "___165", - { - "str": "generated_kernel_arg_id2unique_name" - } - ] - ], - [ - "___163", - [ - "__builtin_getattr__", - "___164", - { - "str": "items" - } - ] - ], - [ - "___166", - [ - "___163" - ] - ], - [ - "generated_kernel_arg_id_and_names", - [ - "__builtin_identity__", - "___166" - ] - ], - [ - "___167", - [ - "__builtin_getattr__", - { - "str": "\n" - }, - { - "str": "join" - } - ] - ], - [ - "___168", - [ - "map", - "declare_epilogue_arguments_field", - "generated_kernel_arg_id_and_names" - ] - ], - [ - "___169", - [ - "___167", - "___168" - ] - ], - [ - "___170", - [ - "__builtin_return__", - "___169" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___171", - [ - "__builtin_getattr__", - "get_epilogue_arguments_fields_str", - { - "str": "__function__" - } - ] - ], - [ - "___172", - [ - "__builtin_list__", - { - "str": "get_epilogue_arguments_fields_str" - }, - "___171" - ] - ], - [ - "get_epilogue_arguments_init_str", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "param_obj_name" - ], - [ - "__builtin_let__", - [ - [ - "declare_epilogue_arguments_assign", - [ - "__builtin_identity__", - [ - "lambda", - [ - "pair" - ], - [ - "__builtin_let__", - [ - [ - "___173", - [ - "__builtin_getitem__", - "pair", - 0 - ] - ], - [ - "kernel_arg_id", - [ - "__builtin_identity__", - "___173" - ] - ], - [ - "___174", - [ - "__builtin_getitem__", - "pair", - 1 - ] - ], - [ - "var_name", - [ - "__builtin_identity__", - "___174" - ] - ], - [ - "___176", - [ - "__builtin_getattr__", - "self", - { - "str": "kernel_arg_translator" - } - ] - ], - [ - "___175", - [ - "__builtin_getattr__", - "___176", - { - "str": "get_param_struct_field_name" - } - ] - ], - [ - "___177", - [ - "___175", - "var_name" - ] - ], - [ - "field_name", - [ - "__builtin_identity__", - "___177" - ] - ], - [ - "___178", - [ - "__builtin_ToString__", - { - "str": " " - } - ] - ], - [ - "___179", - [ - "__builtin_ToString__", - "param_obj_name" - ] - ], - [ - "___180", - [ - "__builtin_Add__", - "___178", - "___179" - ] - ], - [ - "___181", - [ - "__builtin_ToString__", - { - "str": "." - } - ] - ], - [ - "___182", - [ - "__builtin_Add__", - "___180", - "___181" - ] - ], - [ - "___183", - [ - "__builtin_ToString__", - "field_name" - ] - ], - [ - "___184", - [ - "__builtin_Add__", - "___182", - "___183" - ] - ], - [ - "___185", - [ - "__builtin_ToString__", - { - "str": " = " - } - ] - ], - [ - "___186", - [ - "__builtin_Add__", - "___184", - "___185" - ] - ], - [ - "___187", - [ - "__builtin_ToString__", - "var_name" - ] - ], - [ - "___188", - [ - "__builtin_Add__", - "___186", - "___187" - ] - ], - [ - "___189", - [ - "__builtin_ToString__", - { - "str": ";" - } - ] - ], - [ - "___190", - [ - "__builtin_Add__", - "___188", - "___189" - ] - ], - [ - "___191", - [ - "__builtin_return__", - "___190" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___194", - [ - "__builtin_getattr__", - "self", - { - "str": "mut_kernel_arg_id_registry" - } - ] - ], - [ - "___193", - [ - "__builtin_getattr__", - "___194", - { - "str": "generated_kernel_arg_id2unique_name" - } - ] - ], - [ - "___192", - [ - "__builtin_getattr__", - "___193", - { - "str": "items" - } - ] - ], - [ - "___195", - [ - "___192" - ] - ], - [ - "generated_kernel_arg_id_and_names", - [ - "__builtin_identity__", - "___195" - ] - ], - [ - "___196", - [ - "__builtin_getattr__", - { - "str": "\n" - }, - { - "str": "join" - } - ] - ], - [ - "___197", - [ - "map", - "declare_epilogue_arguments_assign", - "generated_kernel_arg_id_and_names" - ] - ], - [ - "___198", - [ - "___196", - "___197" - ] - ], - [ - "___199", - [ - "__builtin_return__", - "___198" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___200", - [ - "__builtin_getattr__", - "get_epilogue_arguments_init_str", - { - "str": "__function__" - } - ] - ], - [ - "___201", - [ - "__builtin_list__", - { - "str": "get_epilogue_arguments_init_str" - }, - "___200" - ] - ], - [ - "make_project", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "trivial_code_str", - "input_karg", - "weight_karg", - "output_karg", - "m_karg", - "n_karg", - "k_karg" - ], - [ - "__builtin_let__", - [ - [ - "code_template", - [ - "__builtin_identity__", - { - "str": "\n// auto generated codes\n#include \n#include \n\n#include \"native_matmul.cuh\"\n\nnamespace ap {\n\ntemplate \nstruct AddFunctor {\n struct Arguments {\nEPILOGUE_ARGUMENTS_FIELDS\n };\n\n // Note: need to support vectorized operation\n __forceinline__ __host__ __device__\n T operator()(T x, const Arguments& args, const MatrixCoord& coord) const {\n T out;\n AP_GENERATED_BINARY_EPILOGUE_STRING;\n return out;\n }\n};\n\n}\n\nextern \"C\" {\n\nvoid MatmulBinaryKernel(void* stream_ptr, AP_KERNEL_ARGS_DECLARE) {\n ap::GemmEpilogueParams params;\n\n params.batch_count = 1;\n params.m = $m;\n params.n = $n;\n params.k = $k;\n\n params.input = $input;\n params.weight = $weight;\n params.bias = nullptr;\n params.output = $output;\n\n cudaStream_t* cuda_stream_ptr = reinterpret_cast(stream_ptr);\n params.stream = *cuda_stream_ptr;\n\n typename ap::AddFunctor::Arguments epilogue_args;\n\nEPILOGUE_ARGUMENTS_INIT\n\n ap::NativeMatmulAdd>(params, epilogue_args);\n}\n}\n\n " - } - ] - ], - [ - "___202", - [ - "__builtin_getattr__", - "self", - { - "str": "dtype2type_name" - } - ] - ], - [ - "___204", - [ - "__builtin_getattr__", - "output_karg", - { - "str": "type" - } - ] - ], - [ - "___203", - [ - "__builtin_getattr__", - "___204", - { - "str": "data_type" - } - ] - ], - [ - "___205", - [ - "__builtin_getitem__", - "___202", - "___203" - ] - ], - [ - "output_dtype", - [ - "__builtin_identity__", - "___205" - ] - ], - [ - "___216", - [ - "__builtin_getattr__", - "code_template", - { - "str": "replace" - } - ] - ], - [ - "___217", - [ - "___216", - { - "str": "AP_GENERATED_BINARY_EPILOGUE_STRING" - }, - "trivial_code_str" - ] - ], - [ - "___215", - [ - "__builtin_getattr__", - "___217", - { - "str": "replace" - } - ] - ], - [ - "___218", - [ - "___215", - { - "str": "AP_GENERATED_ELEMENT_DTYPE" - }, - "output_dtype" - ] - ], - [ - "___214", - [ - "__builtin_getattr__", - "___218", - { - "str": "replace" - } - ] - ], - [ - "___219", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_list_str" - } - ] - ], - [ - "___220", - [ - "___219" - ] - ], - [ - "___221", - [ - "___214", - { - "str": "AP_KERNEL_ARGS_DECLARE" - }, - "___220" - ] - ], - [ - "___213", - [ - "__builtin_getattr__", - "___221", - { - "str": "replace" - } - ] - ], - [ - "___222", - [ - "__builtin_getattr__", - "self", - { - "str": "get_epilogue_arguments_fields_str" - } - ] - ], - [ - "___223", - [ - "___222" - ] - ], - [ - "___224", - [ - "___213", - { - "str": "EPILOGUE_ARGUMENTS_FIELDS" - }, - "___223" - ] - ], - [ - "___212", - [ - "__builtin_getattr__", - "___224", - { - "str": "replace" - } - ] - ], - [ - "___225", - [ - "__builtin_getattr__", - "self", - { - "str": "get_epilogue_arguments_init_str" - } - ] - ], - [ - "___226", - [ - "___225", - { - "str": "epilogue_args" - } - ] - ], - [ - "___227", - [ - "___212", - { - "str": "EPILOGUE_ARGUMENTS_INIT" - }, - "___226" - ] - ], - [ - "___211", - [ - "__builtin_getattr__", - "___227", - { - "str": "replace" - } - ] - ], - [ - "___228", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_id_var_name" - } - ] - ], - [ - "___229", - [ - "___228", - "input_karg" - ] - ], - [ - "___230", - [ - "___211", - { - "str": "$input" - }, - "___229" - ] - ], - [ - "___210", - [ - "__builtin_getattr__", - "___230", - { - "str": "replace" - } - ] - ], - [ - "___231", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_id_var_name" - } - ] - ], - [ - "___232", - [ - "___231", - "weight_karg" - ] - ], - [ - "___233", - [ - "___210", - { - "str": "$weight" - }, - "___232" - ] - ], - [ - "___209", - [ - "__builtin_getattr__", - "___233", - { - "str": "replace" - } - ] - ], - [ - "___234", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_id_var_name" - } - ] - ], - [ - "___235", - [ - "___234", - "output_karg" - ] - ], - [ - "___236", - [ - "___209", - { - "str": "$output" - }, - "___235" - ] - ], - [ - "___208", - [ - "__builtin_getattr__", - "___236", - { - "str": "replace" - } - ] - ], - [ - "___237", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_id_var_name" - } - ] - ], - [ - "___238", - [ - "___237", - "m_karg" - ] - ], - [ - "___239", - [ - "___208", - { - "str": "$m" - }, - "___238" - ] - ], - [ - "___207", - [ - "__builtin_getattr__", - "___239", - { - "str": "replace" - } - ] - ], - [ - "___240", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_id_var_name" - } - ] - ], - [ - "___241", - [ - "___240", - "n_karg" - ] - ], - [ - "___242", - [ - "___207", - { - "str": "$n" - }, - "___241" - ] - ], - [ - "___206", - [ - "__builtin_getattr__", - "___242", - { - "str": "replace" - } - ] - ], - [ - "___243", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_id_var_name" - } - ] - ], - [ - "___244", - [ - "___243", - "k_karg" - ] - ], - [ - "___245", - [ - "___206", - { - "str": "$k" - }, - "___244" - ] - ], - [ - "code", - [ - "__builtin_identity__", - "___245" - ] - ], - [ - "source_dir", - [ - "__builtin_identity__", - { - "str": "/workspace/Athena/tests/ap/matmul" - } - ] - ], - [ - "cutlass_dir", - [ - "__builtin_identity__", - { - "str": "/workspace/Athena/tests/ap/matmul/cutlass" - } - ] - ], - [ - "compile_cmd", - [ - "__builtin_identity__", - { - "str": "nvcc -std=c++17 -O3 -Xcompiler=-fPIC -arch=sm_80 --expt-relaxed-constexpr" - } - ] - ], - [ - "___246", - [ - "__builtin_Add__", - "compile_cmd", - { - "str": " -I " - } - ] - ], - [ - "___247", - [ - "__builtin_Add__", - "___246", - "cutlass_dir" - ] - ], - [ - "___248", - [ - "__builtin_Add__", - "___247", - { - "str": "/include" - } - ] - ], - [ - "compile_cmd", - [ - "__builtin_identity__", - "___248" - ] - ], - [ - "___249", - [ - "__builtin_Add__", - "compile_cmd", - { - "str": " -I " - } - ] - ], - [ - "___250", - [ - "__builtin_Add__", - "___249", - "cutlass_dir" - ] - ], - [ - "___251", - [ - "__builtin_Add__", - "___250", - { - "str": "/tools/util/include" - } - ] - ], - [ - "compile_cmd", - [ - "__builtin_identity__", - "___251" - ] - ], - [ - "___252", - [ - "__builtin_Add__", - "compile_cmd", - { - "str": " -I " - } - ] - ], - [ - "___253", - [ - "__builtin_Add__", - "___252", - "source_dir" - ] - ], - [ - "compile_cmd", - [ - "__builtin_identity__", - "___253" - ] - ], - [ - "___254", - [ - "__builtin_Add__", - "compile_cmd", - { - "str": " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0 " - } - ] - ], - [ - "compile_cmd", - [ - "__builtin_identity__", - "___254" - ] - ], - [ - "___255", - [ - "__builtin_Add__", - "compile_cmd", - { - "str": " --shared matmul_binary_kernel.cu -o libmatmul_binary_kernel.so" - } - ] - ], - [ - "compile_cmd", - [ - "__builtin_identity__", - "___255" - ] - ], - [ - "___256", - [ - "__builtin_getattr__", - "DataType", - { - "str": "void" - } - ] - ], - [ - "___257", - [ - "__builtin_getattr__", - "PointerType", - { - "str": "void_ptr" - } - ] - ], - [ - "___258", - [ - "__builtin_getattr__", - "self", - { - "str": "get_kernel_arg_types" - } - ] - ], - [ - "___259", - [ - "___258" - ] - ], - [ - "___260", - [ - "__builtin_starred__", - "___259" - ] - ], - [ - "___261", - [ - "__builtin_list__", - "___257", - "___260" - ] - ], - [ - "___262", - [ - "FuncDeclare", - "___256", - { - "str": "MatmulBinaryKernel" - }, - "___261" - ] - ], - [ - "___263", - [ - "__builtin_getattr__", - "Project", - { - "str": "Directory" - } - ] - ], - [ - "___264", - [ - "__builtin_getattr__", - "Project", - { - "str": "FileContent" - } - ] - ], - [ - "___265", - [ - "___264", - "code" - ] - ], - [ - "___266", - [ - "__builtin_list__", - { - "str": "matmul_binary_kernel.cu" - }, - "___265" - ] - ], - [ - "___267", - [ - "__builtin_getattr__", - "Project", - { - "str": "FileContent" - } - ] - ], - [ - "___268", - [ - "___267", - "compile_cmd" - ] - ], - [ - "___269", - [ - "__builtin_list__", - { - "str": "make.sh" - }, - "___268" - ] - ], - [ - "___270", - [ - "___263", - "___266", - "___269" - ] - ], - [ - "___271", - [ - "__builtin_list__", - { - "str": "nested_files" - }, - "___270" - ] - ], - [ - "___272", - [ - "__builtin_list__", - { - "str": "compile_cmd" - }, - { - "str": "sh make.sh" - } - ] - ], - [ - "___273", - [ - "__builtin_list__", - { - "str": "so_relative_path" - }, - { - "str": "libmatmul_binary_kernel.so" - } - ] - ], - [ - "___274", - [ - "__builtin_list__", - "___271", - "___272", - "___273" - ] - ], - [ - "___275", - [ - "__builtin_list__" - ] - ], - [ - "___276", - [ - "__builtin_PackedArgs__", - "___275", - "___274" - ] - ], - [ - "___277", - [ - "Project", - "___276" - ] - ], - [ - "___278", - [ - "CodeModule", - "___262", - "___277" - ] - ], - [ - "___279", - [ - "__builtin_return__", - "___278" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___280", - [ - "__builtin_getattr__", - "make_project", - { - "str": "__function__" - } - ] - ], - [ - "___281", - [ - "__builtin_list__", - { - "str": "make_project" - }, - "___280" - ] - ], - [ - "___282", - [ - "__builtin_list__" - ] - ], - [ - "___283", - [ - "__builtin_list__", - "___36", - "___49", - "___94", - "___104", - "___114", - "___120", - "___144", - "___172", - "___201", - "___281" - ] - ], - [ - "___284", - [ - "__builtin_PackedArgs__", - "___282", - "___283" - ] - ], - [ - "___285", - [ - "BuiltinSerializableAttrMap", - "___284" - ] - ], - [ - "___286", - [ - "type", - { - "str": "MatmulBinaryTemplate" - }, - "___9", - "___285" - ] - ], - [ - "MatmulBinaryTemplate", - [ - "__builtin_identity__", - "___286" - ] - ], - [ - "KernelDispatch", - [ - "__builtin_identity__", - [ - "lambda", - [ - "ctx" - ], - [ - "__builtin_let__", - [ - [ - "___287", - [ - "__builtin_getattr__", - "ctx", - { - "str": "get_so_function" - } - ] - ], - [ - "___288", - [ - "___287", - { - "str": "MatmulBinaryKernel" - } - ] - ], - [ - "so_func", - [ - "__builtin_identity__", - "___288" - ] - ], - [ - "___290", - [ - "__builtin_getattr__", - "ctx", - { - "str": "device_ctx" - } - ] - ], - [ - "___289", - [ - "__builtin_getattr__", - "___290", - { - "str": "get_stream_addr_as_void_ptr" - } - ] - ], - [ - "___291", - [ - "___289" - ] - ], - [ - "stream_ptr", - [ - "__builtin_identity__", - "___291" - ] - ], - [ - "___293", - [ - "__builtin_getattr__", - "ctx", - { - "str": "kernel_dispatch_const_data" - } - ] - ], - [ - "___292", - [ - "__builtin_getattr__", - "___293", - { - "str": "kernel_args_getters" - } - ] - ], - [ - "getters", - [ - "__builtin_identity__", - "___292" - ] - ], - [ - "___295", - [ - "map", - [ - "lambda", - [ - "getter" - ], - [ - "__builtin_let__", - [ - [ - "___294", - [ - "getter", - "ctx" - ] - ] - ], - [ - "__builtin_identity__", - "___294" - ] - ] - ], - "getters" - ] - ], - [ - "___296", - [ - "__builtin_starred__", - "___295" - ] - ], - [ - "___297", - [ - "__builtin_list__", - "stream_ptr", - "___296" - ] - ], - [ - "args", - [ - "__builtin_identity__", - "___297" - ] - ], - [ - "___298", - [ - "apply", - "so_func", - "args" - ] - ], - [ - "___299", - [ - "__builtin_identity__", - "___298" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ] - ], - [ - "__builtin_identity__", - "KernelDispatch" - ] -] \ No newline at end of file diff --git a/tests/ap/matmul_binary_tpl.py b/tests/ap/matmul_variadic_tpl.py similarity index 95% rename from tests/ap/matmul_binary_tpl.py rename to tests/ap/matmul_variadic_tpl.py index 5cd31c1..dc122c2 100644 --- a/tests/ap/matmul_binary_tpl.py +++ b/tests/ap/matmul_variadic_tpl.py @@ -10,7 +10,7 @@ def get_anchor_iter_var_names(): return ["coord.batch", "coord.row", "coord.column"] -class MatmulBinaryTemplate: +class MatmulVariadicTemplate: def __init__( self, program_translator, @@ -69,7 +69,7 @@ def compile( mut_lir_code_gen_ctx=mut_lir_code_gen_ctx, ) trivial_code_str = mut_lir_code_gen_ctx.get_stmts_joined_str(indent=" ") - print("-- matmul_binary_epilogue_code:\n", trivial_code_str) + print("-- matmul_variadic_epilogue_code:\n", trivial_code_str) project_module = self.make_project( trivial_code_str, input0_karg, @@ -208,7 +208,7 @@ def make_project( extern "C" { -void MatmulBinaryKernel(void* stream_ptr, AP_KERNEL_ARGS_DECLARE) { +void MatmulVariadicKernel(void* stream_ptr, AP_KERNEL_ARGS_DECLARE) { std::vector $input0_shape; AP_INPUT0_SHAPE_INIT @@ -273,28 +273,28 @@ def make_project( ) compile_cmd = ( compile_cmd - + " --shared matmul_binary_kernel.cu -o libmatmul_binary_kernel.so" + + " --shared matmul_variadic_kernel.cu -o libmatmul_variadic_kernel.so" ) return CodeModule( FuncDeclare( DataType.void, - "MatmulBinaryKernel", + "MatmulVariadicKernel", [PointerType.void_ptr, *self.get_kernel_arg_types()], ), Project( nested_files=Project.Directory( - ["matmul_binary_kernel.cu", Project.FileContent(code)], + ["matmul_variadic_kernel.cu", Project.FileContent(code)], ["make.sh", Project.FileContent(compile_cmd)], ), compile_cmd="sh make.sh", - so_relative_path="libmatmul_binary_kernel.so", + so_relative_path="libmatmul_variadic_kernel.so", ), ) def KernelDispatch(ctx): - so_func = ctx.get_so_function("MatmulBinaryKernel") + so_func = ctx.get_so_function("MatmulVariadicKernel") stream_ptr = ctx.device_ctx.get_stream_addr_as_void_ptr() getters = ctx.kernel_dispatch_const_data.kernel_args_getters args = [stream_ptr, *map(lambda getter: getter(ctx), getters)] diff --git a/tests/ap/paddle-tests/test_matmul_unary.py b/tests/ap/paddle-tests/test_matmul_unary.py new file mode 100644 index 0000000..5f7bf2d --- /dev/null +++ b/tests/ap/paddle-tests/test_matmul_unary.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from os.path import dirname + +sys.path.append(dirname(__file__)) + +import unittest + +import utils + +import paddle +from paddle.static import InputSpec + + +def trivial_matrix_unary(x, y): + out = paddle.matmul(x, y) + return paddle.scale(out, scale=0.1) + + +class CINNSubGraphNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fn = trivial_matrix_unary + + def forward(self, x, y): + out = self.fn(x, y) + return out + + +class TestAPMatmulUnary(unittest.TestCase): + """ + Test Pir API + @to_static + CINN. + """ + + def setUp(self): + paddle.seed(2022) + self.prepare_data() + + def prepare_data(self): + self.dtype = "float16" + + self.x_shape = [4, 65536, 128] + self.x = paddle.randn(self.x_shape, dtype=self.dtype) + self.x.stop_gradient = False + + self.y_shape = [128, 32] + self.y = paddle.randn(self.y_shape, dtype=self.dtype) + self.y.stop_gradient = False + + def eval_symbolic(self, use_cinn, profile): + net = CINNSubGraphNet() + input_spec = [ + InputSpec(shape=self.x_shape, dtype=self.dtype), + InputSpec(shape=self.y_shape, dtype=self.dtype), + ] + net = utils.apply_to_static(net, use_cinn, input_spec) + net.eval() + out = utils.run_with_profile(profile, net, self.x, self.y) + return out + + def test_eval_symbolic(self): + profile = False + cinn_out = self.eval_symbolic(use_cinn=True, profile=profile) + dy_out = self.eval_symbolic(use_cinn=False, profile=profile) + if not profile: + utils.check_result(self.dtype, cinn_out.numpy(), dy_out.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/ap/test_matmul_binary.py b/tests/ap/test_matmul_binary.py index 4b48ee8..ae94b5e 100644 --- a/tests/ap/test_matmul_binary.py +++ b/tests/ap/test_matmul_binary.py @@ -2,7 +2,7 @@ import access_topo_drr import topo_drr_pass import op_convertion_drr_pass -import matmul_binary_tpl +import matmul_variadic_tpl import ir_tools import index_program_translator_util import op_compute_translator_util @@ -250,7 +250,7 @@ def _insert_store_to_global(self, program, output_names): init_pass_manager.run(program) def _make_kernel_arg_translator(self): - return matmul_binary_tpl.make_kernel_arg_translator() + return matmul_variadic_tpl.make_kernel_arg_translator() def _apply_topo_access_passes(self, mut_program, anchor_data_op_name): init_pass_manager = ir_tools.create_pass_manager() @@ -349,6 +349,7 @@ def _get_program_translator(self, ctx, o, t): mut_program = ir_tools.copy_fused_ops_to_program( o.trivial_op, tensor_match_ctx=t ) + print("origin-program_translator", mut_program) self._insert_load_from_global( mut_program, input_names=["mm_out", "input2"] @@ -368,7 +369,7 @@ def _get_program_translator(self, ctx, o, t): index_program_translator_map = index_program_translator_util.IndexProgramTranslatorMap( index_func_unique_id2index_program=index_func_unique_id2index_program, kernel_arg_translator=kernel_arg_translator, - anchor_iter_var_names=matmul_binary_tpl.get_anchor_iter_var_names() + anchor_iter_var_names=matmul_variadic_tpl.get_anchor_iter_var_names() ) self._replace_with_load_from_register( mut_program, @@ -380,7 +381,7 @@ def _get_program_translator(self, ctx, o, t): store_ir_value_name="output", register_var_name="out" ) - print("mut_program:", mut_program) + print("after-insert-load-store-program_translator", mut_program) op_compute_translator_maker = op_compute_translator_util.OpComputeTranslatorFactory() program_translator = program_translator_util.ProgramTranslator( program_property=mut_program.copy_to_const_program_data(), @@ -397,7 +398,7 @@ def code_gen(self, ctx, o, t): tensor_match_ctx=t, name_prefix="" ) - template_module = matmul_binary_tpl.MatmulBinaryTemplate( + template_module = matmul_variadic_tpl.MatmulVariadicTemplate( program_translator=program_translator, mut_kernel_arg_id_registry=mut_kernel_arg_id_registry, ) diff --git a/tests/ap/test_matmul_unary.py b/tests/ap/test_matmul_unary.py new file mode 100644 index 0000000..6b99fa5 --- /dev/null +++ b/tests/ap/test_matmul_unary.py @@ -0,0 +1,188 @@ +import abstract_drr +import access_topo_drr +import topo_drr_pass +import op_convertion_drr_pass +import low_level_ir_code_gen_ctx_util +import matmul_variadic_tpl +import ir_tools +import op_compute_translator_util +import program_translator_util +import kernel_arg_id_util +import kernel_arg_translator_util +import pir + + +class RemoveDataOp2DownSpiderOp2YieldOpPass(access_topo_drr.DrrPass): + def __init__(self, data_op_name): + self.data_op_name = pir.a_str(data_op_name) + + def source_pattern(self, o, t): + o.data_op = o.ap_native_op("pd_op.data") + o.data_op.name = self.data_op_name + o.data_op( + [], + [t.data_op_out] + ) + o.down_spider_op = o.ap_native_op("ap_op.down_spider") + o.down_spider_op( + [t.data_op_out], + [t.down_spider_op_out]) + o.yield_op = o.ap_native_op("cf.yield") + o.yield_op( + [t.down_spider_op_out], + [] + ) + + def result_pattern(self, o, t): + pass + + +@abstract_drr.register_drr_pass("matmul_unary_fusion", nice=0) +class MatmulUnaryFusion(abstract_drr.DrrPass): + def source_pattern(self, o, t): + o.matmul_op = o.ap_native_op("pd_op.matmul") + o.matmul_op( + [t.input0, t.input1], + [t.mm_out] + ) + o.trivial_op = o.ap_trivial_fusion_op() + o.trivial_op( + [t.mm_out], + [t.output] + ) + + def result_pattern(self, o, t): + o.fustion_op = o.ap_pattern_fusion_op(self.code_gen) + o.fustion_op([t.input0, t.input1], [t.output]) + + def constraint(self, o, t): + program = ir_tools.copy_fused_ops_to_program(o.trivial_op, tensor_match_ctx=t) + print("before-access_topo_pass", program) + init_pass_manager = ir_tools.create_pass_manager() + init_down_spider = topo_drr_pass.InitDownSpiderAccessTopoPass("mm_out") + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(init_down_spider) + ) + init_pass_manager.run(program) + print("after-init-access_topo_pass", program) + pass_manager = ir_tools.create_pass_manager() + pass_manager.add_pass(ir_tools.create_access_topo_drr_pass("default")) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(program) + print("after-apply-access_topo_pass", program) + pass_manager = ir_tools.create_pass_manager() + remove_data_op2down_spider_op2yield_op_pass = ( + RemoveDataOp2DownSpiderOp2YieldOpPass( + data_op_name="mm_out", + ) + ) + pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass( + remove_data_op2down_spider_op2yield_op_pass + ) + ) + pass_manager.run(program) + print("after-remove-input-output-access_topo_pass", program) + return program.empty() + + def _insert_load_from_global(self, program, input_names): + init_pass_manager = ir_tools.create_pass_manager() + + def AddPass(input_name): + ir_pass = topo_drr_pass.InitNaiveLoadFromGlobalAccessTopoPass(input_name) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(ir_pass) + ) + + map(AddPass, input_names) + init_pass_manager.run(program) + + def _insert_store_to_global(self, program, output_names): + init_pass_manager = ir_tools.create_pass_manager() + ir_pass = topo_drr_pass.FakeDataStoreToGlobalForYieldAccessTopoPass( + output_names + ) + init_pass_manager.add_pass( + ir_tools.create_access_topo_drr_one_step_pass(ir_pass) + ) + init_pass_manager.run(program) + + def _make_kernel_arg_translator(self): + return matmul_variadic_tpl.make_kernel_arg_translator() + + def _replace_with_load_from_register( + self, mut_program, load_ir_value_name, register_var_name + ): + pass_manager = ir_tools.create_pass_manager() + drr_pass = topo_drr_pass.ReplaceWithLoadFromRegisterPass( + name=load_ir_value_name, register_var_name=register_var_name + ) + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + return mut_program + + def _replace_with_store_to_register( + self, mut_program, store_ir_value_name, register_var_name + ): + pass_manager = ir_tools.create_pass_manager() + drr_pass = topo_drr_pass.ReplaceWithStoreToRegisterPass( + name=store_ir_value_name, register_var_name=register_var_name + ) + pass_manager.add_pass(ir_tools.create_access_topo_drr_one_step_pass(drr_pass)) + pass_manager.add_pass(ir_tools.create_dce_pass()) + pass_manager.run(mut_program) + return mut_program + + def _get_program_translator(self, ctx, o, t): + mut_program = ir_tools.copy_fused_ops_to_program( + o.trivial_op, tensor_match_ctx=t + ) + print("origin-program_translator", mut_program) + self._insert_load_from_global(mut_program, input_names=["mm_out"]) + self._insert_store_to_global(mut_program, output_names=["output"]) + kernel_arg_translator = self._make_kernel_arg_translator() + self._replace_with_load_from_register( + mut_program, load_ir_value_name="mm_out", register_var_name="x" + ) + self._replace_with_store_to_register( + mut_program, store_ir_value_name="output", register_var_name="out" + ) + print("after-insert-load-store-program_translator", mut_program) + op_compute_translator_maker = ( + op_compute_translator_util.OpComputeTranslatorFactory() + ) + program_translator = program_translator_util.ProgramTranslator( + program_property=mut_program.copy_to_const_program_data(), + kernel_arg_translator=kernel_arg_translator, + index_program_translator_map=None, + op_translator_maker=op_compute_translator_maker, + ) + return program_translator + + def code_gen(self, ctx, o, t): + program_translator = self._get_program_translator(ctx, o, t) + mut_kernel_arg_id_registry = kernel_arg_id_util.KernelArgIdNameRegistry( + code_gen_ctx=ctx, tensor_match_ctx=t, name_prefix="" + ) + template_module = matmul_variadic_tpl.MatmulVariadicTemplate( + program_translator=program_translator, + mut_kernel_arg_id_registry=mut_kernel_arg_id_registry, + ) + + def get_symbolic_shape_args_list(sym_dim): + return ctx.dim_expr_kernel_arg_id(sym_dim) + + input0_shape_kargs = map( + get_symbolic_shape_args_list, t.input0.symbolic_shape_to_list() + ) + input1_shape_kargs = map( + get_symbolic_shape_args_list, t.input1.symbolic_shape_to_list() + ) + return template_module.compile( + input0_karg=ctx.in_tensor_data_ptr_kernel_arg_id(t.input0), + input1_karg=ctx.in_tensor_data_ptr_kernel_arg_id(t.input1), + output_karg=ctx.out_tensor_data_ptr_kernel_arg_id(t.output), + input0_shape_kargs=input0_shape_kargs, + input1_shape_kargs=input1_shape_kargs, + ) From 28b6cd3173386bd87d2709cefb872bc7d5aaa8c2 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 20 Feb 2025 14:22:30 +0800 Subject: [PATCH 2/5] Optimize the CodeGen for scale. --- tests/ap/code_gen_value_util.py | 32 ++++++++++++------ tests/ap/op_compute_translator_util.py | 47 ++++++++++++++++---------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/tests/ap/code_gen_value_util.py b/tests/ap/code_gen_value_util.py index 545757c..46569e6 100644 --- a/tests/ap/code_gen_value_util.py +++ b/tests/ap/code_gen_value_util.py @@ -1,14 +1,24 @@ +class TensorCodeGenValue: + def __init__(self, pir_type, var_name): + self.pir_type = pir_type + self.var_name = var_name + self.const_value = None -class CodeGenValue: - def __init__(self, pir_type, var_name): - self.pir_type = pir_type - self.var_name = var_name - self.const_value = None + def get_dtype(self): + def convert_to_dtype(pir_dtype, shape, data_layout): + return pir_dtype.convert_to_dtype() - def get_dtype(self): - def convert_to_dtype(pir_dtype, shape, data_layout): - return pir_dtype.convert_to_dtype() - return self.pir_type.match(t_dtensor=convert_to_dtype) + return self.pir_type.match(t_dtensor=convert_to_dtype) - def is_dense_tensor_type(self): - return self.pir_type.get_type_name() == "t_dtensor" + def is_dense_tensor_type(self): + return self.pir_type.get_type_name() == "t_dtensor" + + +class AttrCodeGenValue: + def __init__(self, data_type, var_name): + self.data_type = data_type + self.var_name = var_name + self.const_value = None + + def get_dtype(self): + return self.data_type diff --git a/tests/ap/op_compute_translator_util.py b/tests/ap/op_compute_translator_util.py index 3ecc50d..6fe0370 100644 --- a/tests/ap/op_compute_translator_util.py +++ b/tests/ap/op_compute_translator_util.py @@ -20,7 +20,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): def get_out_cg_val(self, i): register_var_name_attr = self.op_property.attributes.register_var_name register_var_name = register_var_name_attr.match(a_str=lambda x:x) - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, register_var_name ) @@ -55,7 +55,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -102,7 +102,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): def get_out_cg_val(self, i): name = self.op_property.attributes.name.match(a_str=lambda x:x) - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, name ) @@ -128,7 +128,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -163,7 +163,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -188,7 +188,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -213,7 +213,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -233,22 +233,33 @@ def __init__(self, self.index_program_translator_map = index_program_translator_map def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): - scale = self.op_property.attributes.scale.match(a_f32=lambda x:x) - bias = self.op_property.attributes.bias.match(a_f32=lambda x:x) + scale_value = self.op_property.attributes.scale.match(a_f32=lambda x:x) + scale_var_name = f"op{self.op_property.op_index}_scale" + scale = self.get_tmp_cg_val(scale_var_name) + mut_lir_code_gen_ctx.let(scale, f"{scale_value}") + + bias_value = self.op_property.attributes.bias.match(a_f32=lambda x:x) + bias_var_name = f"op{self.op_property.op_index}_bias" + bias = self.get_tmp_cg_val(bias_var_name) + mut_lir_code_gen_ctx.let(bias, f"{bias_value}") + bias_after_scale = self.op_property.attributes.bias_after_scale.match(a_bool=lambda x:x) in_name = inputs[0].var_name - true_str = f"{scale} * {in_name} + {bias}" - false_str = f"{scale} * ({in_name} + {bias})" + true_str = f"{scale_var_name} * {in_name} + {bias_var_name}" + false_str = f"{scale_var_name} * ({in_name} + {bias_var_name})" out = self.get_out_cg_val(0) mut_lir_code_gen_ctx.let(out, true_str if bias_after_scale else false_str) return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) + def get_tmp_cg_val(self, name): + return code_gen_value_util.AttrCodeGenValue(DataType.float, f"{name}") + class PdOpSubstractCodeGen: def __init__(self, @@ -271,7 +282,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -298,7 +309,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -325,7 +336,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -352,7 +363,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -379,7 +390,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -436,7 +447,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) From df14aeb77dafa159bc6981497aca393ee626b3ba Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Tue, 25 Feb 2025 10:39:50 +0800 Subject: [PATCH 3/5] Update make_axpr.sh and copyright. --- tests/ap/make_axpr.sh | 8 +++----- tests/ap/paddle-tests/test_matmul_unary.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/ap/make_axpr.sh b/tests/ap/make_axpr.sh index 0cd0476..bc91f19 100644 --- a/tests/ap/make_axpr.sh +++ b/tests/ap/make_axpr.sh @@ -3,12 +3,9 @@ TEST_FILENAME=${1:-"test_trivial_reduce"} #TEST_FILENAME=${1:-"test_matmul_unary"} -TEST_TPL_FILENAME=`echo ${TEST_FILENAME/test_/}` - echo "-- Write 'import ${TEST_FILENAME}' to __main__.py" echo "import ${TEST_FILENAME}" > __main__.py - FILENAMES_ARRAY=( "index_code_gen_value_util" "index_drr_pass_util" @@ -26,8 +23,9 @@ FILENAMES_ARRAY=( "access_topo_drr" "abstract_drr" "ap_tpl_codegen" - "${TEST_FILENAME}" - "${TEST_TPL_FILENAME}_tpl" + "matmul_variadic_tpl" + "test_matmul_unary" + "test_matmul_binary" ) for filename in "${FILENAMES_ARRAY[@]}" do diff --git a/tests/ap/paddle-tests/test_matmul_unary.py b/tests/ap/paddle-tests/test_matmul_unary.py index 5f7bf2d..bade8cc 100644 --- a/tests/ap/paddle-tests/test_matmul_unary.py +++ b/tests/ap/paddle-tests/test_matmul_unary.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 30ff7c601640d74549d87bc4fa4bf29c71081c8c Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 26 Feb 2025 20:47:14 +0800 Subject: [PATCH 4/5] Rename test_matmul_unary/binary to matmul_unary/binary_pattern. --- tests/ap/__main__.py | 5 +- tests/ap/make_axpr.sh | 10 +- ...mul_binary.py => matmul_binary_pattern.py} | 0 ...atmul_unary.py => matmul_unary_pattern.py} | 0 tests/ap/op_compute_translator_util.py | 8 +- tests/ap/paddle-tests/test_matmul_binary.py | 30 +- tests/ap/test_matmul_binary.py.json | 8709 ----------------- tests/ap/test_matmul_binary.sh | 2 +- 8 files changed, 29 insertions(+), 8735 deletions(-) rename tests/ap/{test_matmul_binary.py => matmul_binary_pattern.py} (100%) rename tests/ap/{test_matmul_unary.py => matmul_unary_pattern.py} (100%) delete mode 100644 tests/ap/test_matmul_binary.py.json diff --git a/tests/ap/__main__.py b/tests/ap/__main__.py index b22c965..c3000f4 100644 --- a/tests/ap/__main__.py +++ b/tests/ap/__main__.py @@ -1,3 +1,2 @@ -# import test_trivial_reduce -# import test_binary_trivial_reduce -import test_matmul_binary +import matmul_unary_pattern +import matmul_binary_pattern diff --git a/tests/ap/make_axpr.sh b/tests/ap/make_axpr.sh index bc91f19..2551800 100644 --- a/tests/ap/make_axpr.sh +++ b/tests/ap/make_axpr.sh @@ -1,11 +1,5 @@ #!/bin/bash -TEST_FILENAME=${1:-"test_trivial_reduce"} -#TEST_FILENAME=${1:-"test_matmul_unary"} - -echo "-- Write 'import ${TEST_FILENAME}' to __main__.py" -echo "import ${TEST_FILENAME}" > __main__.py - FILENAMES_ARRAY=( "index_code_gen_value_util" "index_drr_pass_util" @@ -24,8 +18,8 @@ FILENAMES_ARRAY=( "abstract_drr" "ap_tpl_codegen" "matmul_variadic_tpl" - "test_matmul_unary" - "test_matmul_binary" + "matmul_unary_pattern" + "matmul_binary_pattern" ) for filename in "${FILENAMES_ARRAY[@]}" do diff --git a/tests/ap/test_matmul_binary.py b/tests/ap/matmul_binary_pattern.py similarity index 100% rename from tests/ap/test_matmul_binary.py rename to tests/ap/matmul_binary_pattern.py diff --git a/tests/ap/test_matmul_unary.py b/tests/ap/matmul_unary_pattern.py similarity index 100% rename from tests/ap/test_matmul_unary.py rename to tests/ap/matmul_unary_pattern.py diff --git a/tests/ap/op_compute_translator_util.py b/tests/ap/op_compute_translator_util.py index 31c7d88..3a26142 100644 --- a/tests/ap/op_compute_translator_util.py +++ b/tests/ap/op_compute_translator_util.py @@ -238,7 +238,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -260,11 +260,11 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): exponent = inputs[1].var_name var_name = inputs[0].var_name out = self.get_out_cg_val(0) - mut_lir_code_gen_ctx.let(out, f"pow({var_name},{exponent})") + mut_lir_code_gen_ctx.let(out, f"pow({var_name}, {exponent})") return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -289,7 +289,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) diff --git a/tests/ap/paddle-tests/test_matmul_binary.py b/tests/ap/paddle-tests/test_matmul_binary.py index 8e38990..1efc99f 100644 --- a/tests/ap/paddle-tests/test_matmul_binary.py +++ b/tests/ap/paddle-tests/test_matmul_binary.py @@ -18,21 +18,22 @@ sys.path.append(dirname(__file__)) import unittest - import utils import paddle from paddle.static import InputSpec -def trivial_matrix_binary(x, y, b): +def matmul_add_relu(x, y, b): out = paddle.matmul(x, y) return paddle.nn.functional.relu(out + b) -def trivial_matrix_binary_gelu_true(x, y, b): + +def matmul_add_gelu_true(x, y, b): out = paddle.matmul(x, y) return paddle.nn.functional.gelu(out + b, True) + class CINNSubGraphNet(paddle.nn.Layer): def __init__(self, fn): super().__init__() @@ -42,7 +43,8 @@ def forward(self, x, y, b): out = self.fn(x, y, b) return out -class TestAPMatmulBinaryTriangleShape(unittest.TestCase): + +class TestAPMatmulBinary(unittest.TestCase): """ Test Pir API + @to_static + CINN. """ @@ -66,8 +68,7 @@ def prepare_data(self): self.b = paddle.randn(self.b_shape, dtype=self.dtype) self.b.stop_gradient = False - def eval_symbolic(self, use_cinn, profile): - net = CINNSubGraphNet(trivial_matrix_binary_gelu_true) + def eval_symbolic(self, net, use_cinn, profile): input_spec = [ InputSpec(shape=self.x_shape, dtype=self.dtype), InputSpec(shape=self.y_shape, dtype=self.dtype), @@ -78,12 +79,21 @@ def eval_symbolic(self, use_cinn, profile): out = utils.run_with_profile(profile, net, self.x, self.y, self.b) return out - def test_eval_symbolic(self): + def test_matmul_add_relu(self): + profile = False + net = CINNSubGraphNet(matmul_add_relu) + cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile) + dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) + if not profile: + utils.check_result(self.dtype, cinn_out.numpy(), dy2st_out.numpy()) + + def test_matmul_add_gelu(self): profile = False - cinn_out = self.eval_symbolic(use_cinn=True, profile=profile) - dy_out = self.eval_symbolic(use_cinn=False, profile=profile) + net = CINNSubGraphNet(matmul_add_gelu_true) + cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile) + dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) if not profile: - utils.check_result(self.dtype, cinn_out.numpy(), dy_out.numpy()) + utils.check_result(self.dtype, cinn_out.numpy(), dy2st_out.numpy()) if __name__ == "__main__": diff --git a/tests/ap/test_matmul_binary.py.json b/tests/ap/test_matmul_binary.py.json deleted file mode 100644 index d737618..0000000 --- a/tests/ap/test_matmul_binary.py.json +++ /dev/null @@ -1,8709 +0,0 @@ -[ - "__builtin_let__", - [ - [ - "abstract_drr", - [ - "import", - { - "str": "abstract_drr" - } - ] - ], - [ - "access_topo_drr", - [ - "import", - { - "str": "access_topo_drr" - } - ] - ], - [ - "topo_drr_pass", - [ - "import", - { - "str": "topo_drr_pass" - } - ] - ], - [ - "op_convertion_drr_pass", - [ - "import", - { - "str": "op_convertion_drr_pass" - } - ] - ], - [ - "matmul_binary_tpl", - [ - "import", - { - "str": "matmul_binary_tpl" - } - ] - ], - [ - "ir_tools", - [ - "import", - { - "str": "ir_tools" - } - ] - ], - [ - "index_program_translator_util", - [ - "import", - { - "str": "index_program_translator_util" - } - ] - ], - [ - "op_compute_translator_util", - [ - "import", - { - "str": "op_compute_translator_util" - } - ] - ], - [ - "program_translator_util", - [ - "import", - { - "str": "program_translator_util" - } - ] - ], - [ - "kernel_arg_id_util", - [ - "import", - { - "str": "kernel_arg_id_util" - } - ] - ], - [ - "low_level_ir_code_gen_ctx_util", - [ - "import", - { - "str": "low_level_ir_code_gen_ctx_util" - } - ] - ], - [ - "kernel_arg_translator_util", - [ - "import", - { - "str": "kernel_arg_translator_util" - } - ] - ], - [ - "pir", - [ - "import", - { - "str": "pir" - } - ] - ], - [ - "___0", - [ - "__builtin_getattr__", - "access_topo_drr", - { - "str": "DrrPass" - } - ] - ], - [ - "___1", - [ - "__builtin_list__", - "___0" - ] - ], - [ - "__init__", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "src_data_op_name", - "dst_data_op_name" - ], - [ - "__builtin_let__", - [ - [ - "___2", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___3", - [ - "___2", - "src_data_op_name" - ] - ], - [ - "___4", - [ - "__builtin_setattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___5", - [ - "___4", - { - "str": "src_data_op_name" - }, - "___3" - ] - ], - [ - "___6", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___7", - [ - "___6", - "dst_data_op_name" - ] - ], - [ - "___8", - [ - "__builtin_setattr__", - "self", - { - "str": "dst_data_op_name" - } - ] - ], - [ - "___9", - [ - "___8", - { - "str": "dst_data_op_name" - }, - "___7" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___10", - [ - "__builtin_getattr__", - "__init__", - { - "str": "__function__" - } - ] - ], - [ - "___11", - [ - "__builtin_list__", - { - "str": "__init__" - }, - "___10" - ] - ], - [ - "source_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___12", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___13", - [ - "___12", - { - "str": "pd_op.data" - } - ] - ], - [ - "___14", - [ - "__builtin_setattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___15", - [ - "___14", - { - "str": "src_data_op" - }, - "___13" - ] - ], - [ - "___16", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___17", - [ - "__builtin_list__" - ] - ], - [ - "___18", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___19", - [ - "__builtin_list__", - "___18" - ] - ], - [ - "___20", - [ - "___16", - "___17", - "___19" - ] - ], - [ - "___21", - [ - "__builtin_identity__", - "___20" - ] - ], - [ - "___22", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___23", - [ - "___22", - { - "str": "pd_op.data" - } - ] - ], - [ - "___24", - [ - "__builtin_setattr__", - "o", - { - "str": "dst_data_op" - } - ] - ], - [ - "___25", - [ - "___24", - { - "str": "dst_data_op" - }, - "___23" - ] - ], - [ - "___26", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_data_op" - } - ] - ], - [ - "___27", - [ - "__builtin_list__" - ] - ], - [ - "___28", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___29", - [ - "__builtin_list__", - "___28" - ] - ], - [ - "___30", - [ - "___26", - "___27", - "___29" - ] - ], - [ - "___31", - [ - "__builtin_identity__", - "___30" - ] - ], - [ - "___32", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___33", - [ - "___32", - { - "str": "ap_op.up_spider" - } - ] - ], - [ - "___34", - [ - "__builtin_setattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___35", - [ - "___34", - { - "str": "up_spider_op" - }, - "___33" - ] - ], - [ - "___36", - [ - "__builtin_getattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___37", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___38", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___39", - [ - "__builtin_list__", - "___37", - "___38" - ] - ], - [ - "___40", - [ - "__builtin_list__" - ] - ], - [ - "___41", - [ - "___36", - "___39", - "___40" - ] - ], - [ - "___42", - [ - "__builtin_identity__", - "___41" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___43", - [ - "__builtin_getattr__", - "source_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___44", - [ - "__builtin_list__", - { - "str": "source_pattern" - }, - "___43" - ] - ], - [ - "constraint", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___46", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___45", - [ - "__builtin_getattr__", - "___46", - { - "str": "name" - } - ] - ], - [ - "___48", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_data_op" - } - ] - ], - [ - "___47", - [ - "__builtin_getattr__", - "___48", - { - "str": "name" - } - ] - ], - [ - "___49", - [ - "__builtin_list__", - "___45", - "___47" - ] - ], - [ - "___50", - [ - "__builtin_getattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___51", - [ - "__builtin_getattr__", - "self", - { - "str": "dst_data_op_name" - } - ] - ], - [ - "___52", - [ - "__builtin_list__", - "___50", - "___51" - ] - ], - [ - "___53", - [ - "__builtin_EQ__", - "___49", - "___52" - ] - ], - [ - "___54", - [ - "__builtin_return__", - "___53" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___55", - [ - "__builtin_getattr__", - "constraint", - { - "str": "__function__" - } - ] - ], - [ - "___56", - [ - "__builtin_list__", - { - "str": "constraint" - }, - "___55" - ] - ], - [ - "result_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - null - ] - ] - ], - [ - "___57", - [ - "__builtin_getattr__", - "result_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___58", - [ - "__builtin_list__", - { - "str": "result_pattern" - }, - "___57" - ] - ], - [ - "___59", - [ - "__builtin_list__" - ] - ], - [ - "___60", - [ - "__builtin_list__", - "___11", - "___44", - "___56", - "___58" - ] - ], - [ - "___61", - [ - "__builtin_PackedArgs__", - "___59", - "___60" - ] - ], - [ - "___62", - [ - "BuiltinSerializableAttrMap", - "___61" - ] - ], - [ - "___63", - [ - "type", - { - "str": "RemoveDataOpPairPass" - }, - "___1", - "___62" - ] - ], - [ - "RemoveDataOpPairPass", - [ - "__builtin_identity__", - "___63" - ] - ], - [ - "___64", - [ - "__builtin_getattr__", - "access_topo_drr", - { - "str": "DrrPass" - } - ] - ], - [ - "___65", - [ - "__builtin_list__", - "___64" - ] - ], - [ - "__init__", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "src_data_op_name", - "dst_data_op_name" - ], - [ - "__builtin_let__", - [ - [ - "___66", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___67", - [ - "___66", - "src_data_op_name" - ] - ], - [ - "___68", - [ - "__builtin_setattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___69", - [ - "___68", - { - "str": "src_data_op_name" - }, - "___67" - ] - ], - [ - "___70", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___71", - [ - "___70", - "dst_data_op_name" - ] - ], - [ - "___72", - [ - "__builtin_setattr__", - "self", - { - "str": "dst_data_op_name" - } - ] - ], - [ - "___73", - [ - "___72", - { - "str": "dst_data_op_name" - }, - "___71" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___74", - [ - "__builtin_getattr__", - "__init__", - { - "str": "__function__" - } - ] - ], - [ - "___75", - [ - "__builtin_list__", - { - "str": "__init__" - }, - "___74" - ] - ], - [ - "source_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___76", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___77", - [ - "___76", - { - "str": "pd_op.data" - } - ] - ], - [ - "___78", - [ - "__builtin_setattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___79", - [ - "___78", - { - "str": "src_data_op" - }, - "___77" - ] - ], - [ - "___80", - [ - "__builtin_getattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___81", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___82", - [ - "__builtin_setattr__", - "___81", - { - "str": "name" - } - ] - ], - [ - "___83", - [ - "___82", - { - "str": "name" - }, - "___80" - ] - ], - [ - "___84", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___85", - [ - "__builtin_list__" - ] - ], - [ - "___86", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___87", - [ - "__builtin_list__", - "___86" - ] - ], - [ - "___88", - [ - "___84", - "___85", - "___87" - ] - ], - [ - "___89", - [ - "__builtin_identity__", - "___88" - ] - ], - [ - "___90", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___91", - [ - "___90", - { - "str": "pd_op.full_int_array" - } - ] - ], - [ - "___92", - [ - "__builtin_setattr__", - "o", - { - "str": "full_int_array_op" - } - ] - ], - [ - "___93", - [ - "___92", - { - "str": "full_int_array_op" - }, - "___91" - ] - ], - [ - "___94", - [ - "__builtin_getattr__", - "o", - { - "str": "full_int_array_op" - } - ] - ], - [ - "___95", - [ - "__builtin_list__" - ] - ], - [ - "___96", - [ - "__builtin_getattr__", - "t", - { - "str": "axis" - } - ] - ], - [ - "___97", - [ - "__builtin_list__", - "___96" - ] - ], - [ - "___98", - [ - "___94", - "___95", - "___97" - ] - ], - [ - "___99", - [ - "__builtin_identity__", - "___98" - ] - ], - [ - "___100", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___101", - [ - "___100", - { - "str": "pd_op.sum" - } - ] - ], - [ - "___102", - [ - "__builtin_setattr__", - "o", - { - "str": "sum_op" - } - ] - ], - [ - "___103", - [ - "___102", - { - "str": "sum_op" - }, - "___101" - ] - ], - [ - "___104", - [ - "__builtin_getattr__", - "o", - { - "str": "sum_op" - } - ] - ], - [ - "___105", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___106", - [ - "__builtin_getattr__", - "t", - { - "str": "axis" - } - ] - ], - [ - "___107", - [ - "__builtin_list__", - "___105", - "___106" - ] - ], - [ - "___108", - [ - "__builtin_getattr__", - "t", - { - "str": "sum_out" - } - ] - ], - [ - "___109", - [ - "__builtin_list__", - "___108" - ] - ], - [ - "___110", - [ - "___104", - "___107", - "___109" - ] - ], - [ - "___111", - [ - "__builtin_identity__", - "___110" - ] - ], - [ - "___112", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___113", - [ - "___112", - { - "str": "pd_op.data" - } - ] - ], - [ - "___114", - [ - "__builtin_setattr__", - "o", - { - "str": "dst_data_op" - } - ] - ], - [ - "___115", - [ - "___114", - { - "str": "dst_data_op" - }, - "___113" - ] - ], - [ - "___116", - [ - "__builtin_getattr__", - "self", - { - "str": "dst_data_op_name" - } - ] - ], - [ - "___117", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_data_op" - } - ] - ], - [ - "___118", - [ - "__builtin_setattr__", - "___117", - { - "str": "name" - } - ] - ], - [ - "___119", - [ - "___118", - { - "str": "name" - }, - "___116" - ] - ], - [ - "___120", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_data_op" - } - ] - ], - [ - "___121", - [ - "__builtin_list__" - ] - ], - [ - "___122", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___123", - [ - "__builtin_list__", - "___122" - ] - ], - [ - "___124", - [ - "___120", - "___121", - "___123" - ] - ], - [ - "___125", - [ - "__builtin_identity__", - "___124" - ] - ], - [ - "___126", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___127", - [ - "___126", - { - "str": "ap_op.up_spider" - } - ] - ], - [ - "___128", - [ - "__builtin_setattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___129", - [ - "___128", - { - "str": "up_spider_op" - }, - "___127" - ] - ], - [ - "___130", - [ - "__builtin_getattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___131", - [ - "__builtin_getattr__", - "t", - { - "str": "sum_out" - } - ] - ], - [ - "___132", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___133", - [ - "__builtin_list__", - "___131", - "___132" - ] - ], - [ - "___134", - [ - "__builtin_list__" - ] - ], - [ - "___135", - [ - "___130", - "___133", - "___134" - ] - ], - [ - "___136", - [ - "__builtin_identity__", - "___135" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___137", - [ - "__builtin_getattr__", - "source_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___138", - [ - "__builtin_list__", - { - "str": "source_pattern" - }, - "___137" - ] - ], - [ - "result_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - null - ] - ] - ], - [ - "___139", - [ - "__builtin_getattr__", - "result_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___140", - [ - "__builtin_list__", - { - "str": "result_pattern" - }, - "___139" - ] - ], - [ - "___141", - [ - "__builtin_list__" - ] - ], - [ - "___142", - [ - "__builtin_list__", - "___75", - "___138", - "___140" - ] - ], - [ - "___143", - [ - "__builtin_PackedArgs__", - "___141", - "___142" - ] - ], - [ - "___144", - [ - "BuiltinSerializableAttrMap", - "___143" - ] - ], - [ - "___145", - [ - "type", - { - "str": "RemoveDataOp2SumOp2DataOpPass" - }, - "___65", - "___144" - ] - ], - [ - "RemoveDataOp2SumOp2DataOpPass", - [ - "__builtin_identity__", - "___145" - ] - ], - [ - "___146", - [ - "__builtin_getattr__", - "access_topo_drr", - { - "str": "DrrPass" - } - ] - ], - [ - "___147", - [ - "__builtin_list__", - "___146" - ] - ], - [ - "__init__", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "src_data_op_name", - "dst_load_from_global_op_name" - ], - [ - "__builtin_let__", - [ - [ - "___148", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___149", - [ - "___148", - "src_data_op_name" - ] - ], - [ - "___150", - [ - "__builtin_setattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___151", - [ - "___150", - { - "str": "src_data_op_name" - }, - "___149" - ] - ], - [ - "___152", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___153", - [ - "___152", - "dst_load_from_global_op_name" - ] - ], - [ - "___154", - [ - "__builtin_setattr__", - "self", - { - "str": "dst_load_from_global_op_name" - } - ] - ], - [ - "___155", - [ - "___154", - { - "str": "dst_load_from_global_op_name" - }, - "___153" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___156", - [ - "__builtin_getattr__", - "__init__", - { - "str": "__function__" - } - ] - ], - [ - "___157", - [ - "__builtin_list__", - { - "str": "__init__" - }, - "___156" - ] - ], - [ - "source_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___158", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___159", - [ - "___158", - { - "str": "pd_op.data" - } - ] - ], - [ - "___160", - [ - "__builtin_setattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___161", - [ - "___160", - { - "str": "src_data_op" - }, - "___159" - ] - ], - [ - "___162", - [ - "__builtin_getattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___163", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___164", - [ - "__builtin_setattr__", - "___163", - { - "str": "name" - } - ] - ], - [ - "___165", - [ - "___164", - { - "str": "name" - }, - "___162" - ] - ], - [ - "___166", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___167", - [ - "__builtin_list__" - ] - ], - [ - "___168", - [ - "__builtin_getattr__", - "t", - { - "str": "src_input" - } - ] - ], - [ - "___169", - [ - "__builtin_list__", - "___168" - ] - ], - [ - "___170", - [ - "___166", - "___167", - "___169" - ] - ], - [ - "___171", - [ - "__builtin_identity__", - "___170" - ] - ], - [ - "___172", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___173", - [ - "___172", - { - "str": "ap_op.load_from_global" - } - ] - ], - [ - "___174", - [ - "__builtin_setattr__", - "o", - { - "str": "dst_load_from_global_op" - } - ] - ], - [ - "___175", - [ - "___174", - { - "str": "dst_load_from_global_op" - }, - "___173" - ] - ], - [ - "___176", - [ - "__builtin_getattr__", - "self", - { - "str": "dst_load_from_global_op_name" - } - ] - ], - [ - "___177", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_load_from_global_op" - } - ] - ], - [ - "___178", - [ - "__builtin_setattr__", - "___177", - { - "str": "index_func_unique_id" - } - ] - ], - [ - "___179", - [ - "___178", - { - "str": "index_func_unique_id" - }, - "___176" - ] - ], - [ - "___180", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_load_from_global_op" - } - ] - ], - [ - "___181", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_input" - } - ] - ], - [ - "___182", - [ - "__builtin_list__", - "___181" - ] - ], - [ - "___183", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_load_from_global_output" - } - ] - ], - [ - "___184", - [ - "__builtin_list__", - "___183" - ] - ], - [ - "___185", - [ - "___180", - "___182", - "___184" - ] - ], - [ - "___186", - [ - "__builtin_identity__", - "___185" - ] - ], - [ - "___187", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___188", - [ - "___187", - { - "str": "ap_op.up_spider" - } - ] - ], - [ - "___189", - [ - "__builtin_setattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___190", - [ - "___189", - { - "str": "up_spider_op" - }, - "___188" - ] - ], - [ - "___191", - [ - "__builtin_getattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___192", - [ - "__builtin_getattr__", - "t", - { - "str": "src_input" - } - ] - ], - [ - "___193", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_load_from_global_output" - } - ] - ], - [ - "___194", - [ - "__builtin_list__", - "___192", - "___193" - ] - ], - [ - "___195", - [ - "__builtin_list__" - ] - ], - [ - "___196", - [ - "___191", - "___194", - "___195" - ] - ], - [ - "___197", - [ - "__builtin_identity__", - "___196" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___198", - [ - "__builtin_getattr__", - "source_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___199", - [ - "__builtin_list__", - { - "str": "source_pattern" - }, - "___198" - ] - ], - [ - "result_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - null - ] - ] - ], - [ - "___200", - [ - "__builtin_getattr__", - "result_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___201", - [ - "__builtin_list__", - { - "str": "result_pattern" - }, - "___200" - ] - ], - [ - "___202", - [ - "__builtin_list__" - ] - ], - [ - "___203", - [ - "__builtin_list__", - "___157", - "___199", - "___201" - ] - ], - [ - "___204", - [ - "__builtin_PackedArgs__", - "___202", - "___203" - ] - ], - [ - "___205", - [ - "BuiltinSerializableAttrMap", - "___204" - ] - ], - [ - "___206", - [ - "type", - { - "str": "RemoveElementInputIndexPass" - }, - "___147", - "___205" - ] - ], - [ - "RemoveElementInputIndexPass", - [ - "__builtin_identity__", - "___206" - ] - ], - [ - "___207", - [ - "__builtin_getattr__", - "access_topo_drr", - { - "str": "DrrPass" - } - ] - ], - [ - "___208", - [ - "__builtin_list__", - "___207" - ] - ], - [ - "__init__", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "src_data_op_name", - "dst_load_from_global_op_name" - ], - [ - "__builtin_let__", - [ - [ - "___209", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___210", - [ - "___209", - "src_data_op_name" - ] - ], - [ - "___211", - [ - "__builtin_setattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___212", - [ - "___211", - { - "str": "src_data_op_name" - }, - "___210" - ] - ], - [ - "___213", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___214", - [ - "___213", - "dst_load_from_global_op_name" - ] - ], - [ - "___215", - [ - "__builtin_setattr__", - "self", - { - "str": "dst_load_from_global_op_name" - } - ] - ], - [ - "___216", - [ - "___215", - { - "str": "dst_load_from_global_op_name" - }, - "___214" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___217", - [ - "__builtin_getattr__", - "__init__", - { - "str": "__function__" - } - ] - ], - [ - "___218", - [ - "__builtin_list__", - { - "str": "__init__" - }, - "___217" - ] - ], - [ - "source_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___219", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___220", - [ - "___219", - { - "str": "pd_op.data" - } - ] - ], - [ - "___221", - [ - "__builtin_setattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___222", - [ - "___221", - { - "str": "src_data_op" - }, - "___220" - ] - ], - [ - "___223", - [ - "__builtin_getattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___224", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___225", - [ - "__builtin_setattr__", - "___224", - { - "str": "name" - } - ] - ], - [ - "___226", - [ - "___225", - { - "str": "name" - }, - "___223" - ] - ], - [ - "___227", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___228", - [ - "__builtin_list__" - ] - ], - [ - "___229", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___230", - [ - "__builtin_list__", - "___229" - ] - ], - [ - "___231", - [ - "___227", - "___228", - "___230" - ] - ], - [ - "___232", - [ - "__builtin_identity__", - "___231" - ] - ], - [ - "___233", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___234", - [ - "___233", - { - "str": "pd_op.full_int_array" - } - ] - ], - [ - "___235", - [ - "__builtin_setattr__", - "o", - { - "str": "full_int_array_op" - } - ] - ], - [ - "___236", - [ - "___235", - { - "str": "full_int_array_op" - }, - "___234" - ] - ], - [ - "___237", - [ - "__builtin_getattr__", - "o", - { - "str": "full_int_array_op" - } - ] - ], - [ - "___238", - [ - "__builtin_list__" - ] - ], - [ - "___239", - [ - "__builtin_getattr__", - "t", - { - "str": "axis" - } - ] - ], - [ - "___240", - [ - "__builtin_list__", - "___239" - ] - ], - [ - "___241", - [ - "___237", - "___238", - "___240" - ] - ], - [ - "___242", - [ - "__builtin_identity__", - "___241" - ] - ], - [ - "___243", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___244", - [ - "___243", - { - "str": "pd_op.sum" - } - ] - ], - [ - "___245", - [ - "__builtin_setattr__", - "o", - { - "str": "sum_op" - } - ] - ], - [ - "___246", - [ - "___245", - { - "str": "sum_op" - }, - "___244" - ] - ], - [ - "___247", - [ - "__builtin_getattr__", - "o", - { - "str": "sum_op" - } - ] - ], - [ - "___248", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___249", - [ - "__builtin_getattr__", - "t", - { - "str": "axis" - } - ] - ], - [ - "___250", - [ - "__builtin_list__", - "___248", - "___249" - ] - ], - [ - "___251", - [ - "__builtin_getattr__", - "t", - { - "str": "sum_out" - } - ] - ], - [ - "___252", - [ - "__builtin_list__", - "___251" - ] - ], - [ - "___253", - [ - "___247", - "___250", - "___252" - ] - ], - [ - "___254", - [ - "__builtin_identity__", - "___253" - ] - ], - [ - "___255", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___256", - [ - "___255", - { - "str": "ap_op.load_from_global" - } - ] - ], - [ - "___257", - [ - "__builtin_setattr__", - "o", - { - "str": "dst_load_from_global_op" - } - ] - ], - [ - "___258", - [ - "___257", - { - "str": "dst_load_from_global_op" - }, - "___256" - ] - ], - [ - "___259", - [ - "__builtin_getattr__", - "self", - { - "str": "dst_load_from_global_op_name" - } - ] - ], - [ - "___260", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_load_from_global_op" - } - ] - ], - [ - "___261", - [ - "__builtin_setattr__", - "___260", - { - "str": "index_func_unique_id" - } - ] - ], - [ - "___262", - [ - "___261", - { - "str": "index_func_unique_id" - }, - "___259" - ] - ], - [ - "___263", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_load_from_global_op" - } - ] - ], - [ - "___264", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_input" - } - ] - ], - [ - "___265", - [ - "__builtin_list__", - "___264" - ] - ], - [ - "___266", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_load_from_global_output" - } - ] - ], - [ - "___267", - [ - "__builtin_list__", - "___266" - ] - ], - [ - "___268", - [ - "___263", - "___265", - "___267" - ] - ], - [ - "___269", - [ - "__builtin_identity__", - "___268" - ] - ], - [ - "___270", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___271", - [ - "___270", - { - "str": "ap_op.up_spider" - } - ] - ], - [ - "___272", - [ - "__builtin_setattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___273", - [ - "___272", - { - "str": "up_spider_op" - }, - "___271" - ] - ], - [ - "___274", - [ - "__builtin_getattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___275", - [ - "__builtin_getattr__", - "t", - { - "str": "sum_out" - } - ] - ], - [ - "___276", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_load_from_global_output" - } - ] - ], - [ - "___277", - [ - "__builtin_list__", - "___275", - "___276" - ] - ], - [ - "___278", - [ - "__builtin_list__" - ] - ], - [ - "___279", - [ - "___274", - "___277", - "___278" - ] - ], - [ - "___280", - [ - "__builtin_identity__", - "___279" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___281", - [ - "__builtin_getattr__", - "source_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___282", - [ - "__builtin_list__", - { - "str": "source_pattern" - }, - "___281" - ] - ], - [ - "result_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - null - ] - ] - ], - [ - "___283", - [ - "__builtin_getattr__", - "result_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___284", - [ - "__builtin_list__", - { - "str": "result_pattern" - }, - "___283" - ] - ], - [ - "___285", - [ - "__builtin_list__" - ] - ], - [ - "___286", - [ - "__builtin_list__", - "___218", - "___282", - "___284" - ] - ], - [ - "___287", - [ - "__builtin_PackedArgs__", - "___285", - "___286" - ] - ], - [ - "___288", - [ - "BuiltinSerializableAttrMap", - "___287" - ] - ], - [ - "___289", - [ - "type", - { - "str": "RemoveBroadcastInputIndexPass" - }, - "___208", - "___288" - ] - ], - [ - "RemoveBroadcastInputIndexPass", - [ - "__builtin_identity__", - "___289" - ] - ], - [ - "___290", - [ - "__builtin_getattr__", - "access_topo_drr", - { - "str": "DrrPass" - } - ] - ], - [ - "___291", - [ - "__builtin_list__", - "___290" - ] - ], - [ - "__init__", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "src_data_op_name", - "dst_store_to_global_op_name" - ], - [ - "__builtin_let__", - [ - [ - "___292", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___293", - [ - "___292", - "src_data_op_name" - ] - ], - [ - "___294", - [ - "__builtin_setattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___295", - [ - "___294", - { - "str": "src_data_op_name" - }, - "___293" - ] - ], - [ - "___296", - [ - "__builtin_getattr__", - "pir", - { - "str": "a_str" - } - ] - ], - [ - "___297", - [ - "___296", - "dst_store_to_global_op_name" - ] - ], - [ - "___298", - [ - "__builtin_setattr__", - "self", - { - "str": "dst_store_to_global_op_name" - } - ] - ], - [ - "___299", - [ - "___298", - { - "str": "dst_store_to_global_op_name" - }, - "___297" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___300", - [ - "__builtin_getattr__", - "__init__", - { - "str": "__function__" - } - ] - ], - [ - "___301", - [ - "__builtin_list__", - { - "str": "__init__" - }, - "___300" - ] - ], - [ - "source_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___302", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___303", - [ - "___302", - { - "str": "pd_op.data" - } - ] - ], - [ - "___304", - [ - "__builtin_setattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___305", - [ - "___304", - { - "str": "src_data_op" - }, - "___303" - ] - ], - [ - "___306", - [ - "__builtin_getattr__", - "self", - { - "str": "src_data_op_name" - } - ] - ], - [ - "___307", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___308", - [ - "__builtin_setattr__", - "___307", - { - "str": "name" - } - ] - ], - [ - "___309", - [ - "___308", - { - "str": "name" - }, - "___306" - ] - ], - [ - "___310", - [ - "__builtin_getattr__", - "o", - { - "str": "src_data_op" - } - ] - ], - [ - "___311", - [ - "__builtin_list__" - ] - ], - [ - "___312", - [ - "__builtin_getattr__", - "t", - { - "str": "src_input" - } - ] - ], - [ - "___313", - [ - "__builtin_list__", - "___312" - ] - ], - [ - "___314", - [ - "___310", - "___311", - "___313" - ] - ], - [ - "___315", - [ - "__builtin_identity__", - "___314" - ] - ], - [ - "___316", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___317", - [ - "___316", - { - "str": "ap_op.store_to_global" - } - ] - ], - [ - "___318", - [ - "__builtin_setattr__", - "o", - { - "str": "dst_store_to_global_op" - } - ] - ], - [ - "___319", - [ - "___318", - { - "str": "dst_store_to_global_op" - }, - "___317" - ] - ], - [ - "___320", - [ - "__builtin_getattr__", - "self", - { - "str": "dst_store_to_global_op_name" - } - ] - ], - [ - "___321", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_store_to_global_op" - } - ] - ], - [ - "___322", - [ - "__builtin_setattr__", - "___321", - { - "str": "index_func_unique_id" - } - ] - ], - [ - "___323", - [ - "___322", - { - "str": "index_func_unique_id" - }, - "___320" - ] - ], - [ - "___324", - [ - "__builtin_getattr__", - "o", - { - "str": "dst_store_to_global_op" - } - ] - ], - [ - "___325", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_output" - } - ] - ], - [ - "___326", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_output_val" - } - ] - ], - [ - "___327", - [ - "__builtin_list__", - "___325", - "___326" - ] - ], - [ - "___328", - [ - "__builtin_list__" - ] - ], - [ - "___329", - [ - "___324", - "___327", - "___328" - ] - ], - [ - "___330", - [ - "__builtin_identity__", - "___329" - ] - ], - [ - "___331", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___332", - [ - "___331", - { - "str": "ap_op.up_spider" - } - ] - ], - [ - "___333", - [ - "__builtin_setattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___334", - [ - "___333", - { - "str": "up_spider_op" - }, - "___332" - ] - ], - [ - "___335", - [ - "__builtin_getattr__", - "o", - { - "str": "up_spider_op" - } - ] - ], - [ - "___336", - [ - "__builtin_getattr__", - "t", - { - "str": "src_input" - } - ] - ], - [ - "___337", - [ - "__builtin_getattr__", - "t", - { - "str": "dst_output_val" - } - ] - ], - [ - "___338", - [ - "__builtin_list__", - "___336", - "___337" - ] - ], - [ - "___339", - [ - "__builtin_list__" - ] - ], - [ - "___340", - [ - "___335", - "___338", - "___339" - ] - ], - [ - "___341", - [ - "__builtin_identity__", - "___340" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___342", - [ - "__builtin_getattr__", - "source_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___343", - [ - "__builtin_list__", - { - "str": "source_pattern" - }, - "___342" - ] - ], - [ - "result_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - null - ] - ] - ], - [ - "___344", - [ - "__builtin_getattr__", - "result_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___345", - [ - "__builtin_list__", - { - "str": "result_pattern" - }, - "___344" - ] - ], - [ - "___346", - [ - "__builtin_list__" - ] - ], - [ - "___347", - [ - "__builtin_list__", - "___301", - "___343", - "___345" - ] - ], - [ - "___348", - [ - "__builtin_PackedArgs__", - "___346", - "___347" - ] - ], - [ - "___349", - [ - "BuiltinSerializableAttrMap", - "___348" - ] - ], - [ - "___350", - [ - "type", - { - "str": "RemoveOutputIndexPass" - }, - "___291", - "___349" - ] - ], - [ - "RemoveOutputIndexPass", - [ - "__builtin_identity__", - "___350" - ] - ], - [ - "___351", - [ - "__builtin_getattr__", - "abstract_drr", - { - "str": "DrrPass" - } - ] - ], - [ - "___352", - [ - "__builtin_list__", - "___351" - ] - ], - [ - "source_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___353", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_native_op" - } - ] - ], - [ - "___354", - [ - "___353", - { - "str": "pd_op.matmul" - } - ] - ], - [ - "___355", - [ - "__builtin_setattr__", - "o", - { - "str": "matmul_op" - } - ] - ], - [ - "___356", - [ - "___355", - { - "str": "matmul_op" - }, - "___354" - ] - ], - [ - "___357", - [ - "__builtin_getattr__", - "o", - { - "str": "matmul_op" - } - ] - ], - [ - "___358", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___359", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___360", - [ - "__builtin_list__", - "___358", - "___359" - ] - ], - [ - "___361", - [ - "__builtin_getattr__", - "t", - { - "str": "mm_out" - } - ] - ], - [ - "___362", - [ - "__builtin_list__", - "___361" - ] - ], - [ - "___363", - [ - "___357", - "___360", - "___362" - ] - ], - [ - "___364", - [ - "__builtin_identity__", - "___363" - ] - ], - [ - "___365", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_trivial_fusion_op" - } - ] - ], - [ - "___366", - [ - "___365" - ] - ], - [ - "___367", - [ - "__builtin_setattr__", - "o", - { - "str": "trivial_op" - } - ] - ], - [ - "___368", - [ - "___367", - { - "str": "trivial_op" - }, - "___366" - ] - ], - [ - "___369", - [ - "__builtin_getattr__", - "o", - { - "str": "trivial_op" - } - ] - ], - [ - "___370", - [ - "__builtin_getattr__", - "t", - { - "str": "mm_out" - } - ] - ], - [ - "___371", - [ - "__builtin_getattr__", - "t", - { - "str": "input2" - } - ] - ], - [ - "___372", - [ - "__builtin_list__", - "___370", - "___371" - ] - ], - [ - "___373", - [ - "__builtin_getattr__", - "t", - { - "str": "output" - } - ] - ], - [ - "___374", - [ - "__builtin_list__", - "___373" - ] - ], - [ - "___375", - [ - "___369", - "___372", - "___374" - ] - ], - [ - "___376", - [ - "__builtin_identity__", - "___375" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___377", - [ - "__builtin_getattr__", - "source_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___378", - [ - "__builtin_list__", - { - "str": "source_pattern" - }, - "___377" - ] - ], - [ - "result_pattern", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___379", - [ - "__builtin_getattr__", - "o", - { - "str": "ap_pattern_fusion_op" - } - ] - ], - [ - "___380", - [ - "__builtin_getattr__", - "self", - { - "str": "code_gen" - } - ] - ], - [ - "___381", - [ - "___379", - "___380" - ] - ], - [ - "___382", - [ - "__builtin_setattr__", - "o", - { - "str": "fustion_op" - } - ] - ], - [ - "___383", - [ - "___382", - { - "str": "fustion_op" - }, - "___381" - ] - ], - [ - "___384", - [ - "__builtin_getattr__", - "o", - { - "str": "fustion_op" - } - ] - ], - [ - "___385", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___386", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___387", - [ - "__builtin_getattr__", - "t", - { - "str": "input2" - } - ] - ], - [ - "___388", - [ - "__builtin_list__", - "___385", - "___386", - "___387" - ] - ], - [ - "___389", - [ - "__builtin_getattr__", - "t", - { - "str": "output" - } - ] - ], - [ - "___390", - [ - "__builtin_list__", - "___389" - ] - ], - [ - "___391", - [ - "___384", - "___388", - "___390" - ] - ], - [ - "___392", - [ - "__builtin_identity__", - "___391" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___393", - [ - "__builtin_getattr__", - "result_pattern", - { - "str": "__function__" - } - ] - ], - [ - "___394", - [ - "__builtin_list__", - { - "str": "result_pattern" - }, - "___393" - ] - ], - [ - "constraint", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___395", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "copy_fused_ops_to_program" - } - ] - ], - [ - "___396", - [ - "__builtin_getattr__", - "o", - { - "str": "trivial_op" - } - ] - ], - [ - "___397", - [ - "__builtin_list__", - { - "str": "tensor_match_ctx" - }, - "t" - ] - ], - [ - "___398", - [ - "__builtin_list__", - "___397" - ] - ], - [ - "___399", - [ - "__builtin_list__", - "___396" - ] - ], - [ - "___400", - [ - "__builtin_PackedArgs__", - "___399", - "___398" - ] - ], - [ - "___401", - [ - "___395", - "___400" - ] - ], - [ - "program", - [ - "__builtin_identity__", - "___401" - ] - ], - [ - "___402", - [ - "print", - { - "str": "before-access_topo_pass" - }, - "program" - ] - ], - [ - "___403", - [ - "__builtin_identity__", - "___402" - ] - ], - [ - "___404", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___405", - [ - "___404" - ] - ], - [ - "init_pass_manager", - [ - "__builtin_identity__", - "___405" - ] - ], - [ - "___406", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "InitDownSpiderAccessTopoPass" - } - ] - ], - [ - "___407", - [ - "___406", - { - "str": "mm_out" - } - ] - ], - [ - "init_down_spider", - [ - "__builtin_identity__", - "___407" - ] - ], - [ - "___408", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___409", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___410", - [ - "___409", - "init_down_spider" - ] - ], - [ - "___411", - [ - "___408", - "___410" - ] - ], - [ - "___412", - [ - "__builtin_identity__", - "___411" - ] - ], - [ - "___413", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "FakeDataForYieldAccessTopoPass" - } - ] - ], - [ - "___414", - [ - "__builtin_list__", - { - "str": "output" - } - ] - ], - [ - "___415", - [ - "___413", - "___414" - ] - ], - [ - "init_fake_data_for_yield_input", - [ - "__builtin_identity__", - "___415" - ] - ], - [ - "___416", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___417", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___418", - [ - "___417", - "init_fake_data_for_yield_input" - ] - ], - [ - "___419", - [ - "___416", - "___418" - ] - ], - [ - "___420", - [ - "__builtin_identity__", - "___419" - ] - ], - [ - "___421", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "run" - } - ] - ], - [ - "___422", - [ - "___421", - "program" - ] - ], - [ - "___423", - [ - "__builtin_identity__", - "___422" - ] - ], - [ - "___424", - [ - "print", - { - "str": "after-init-access_topo_pass" - }, - "program" - ] - ], - [ - "___425", - [ - "__builtin_identity__", - "___424" - ] - ], - [ - "___426", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___427", - [ - "___426" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___427" - ] - ], - [ - "___428", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___429", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_pass" - } - ] - ], - [ - "___430", - [ - "___429", - { - "str": "default" - } - ] - ], - [ - "___431", - [ - "___428", - "___430" - ] - ], - [ - "___432", - [ - "__builtin_identity__", - "___431" - ] - ], - [ - "___433", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___434", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_dce_pass" - } - ] - ], - [ - "___435", - [ - "___434" - ] - ], - [ - "___436", - [ - "___433", - "___435" - ] - ], - [ - "___437", - [ - "__builtin_identity__", - "___436" - ] - ], - [ - "___438", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___439", - [ - "___438", - "program" - ] - ], - [ - "___440", - [ - "__builtin_identity__", - "___439" - ] - ], - [ - "___441", - [ - "print", - { - "str": "after-apply-access_topo_pass" - }, - "program" - ] - ], - [ - "___442", - [ - "__builtin_identity__", - "___441" - ] - ], - [ - "___443", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___444", - [ - "___443" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___444" - ] - ], - [ - "___445", - [ - "__builtin_list__", - { - "str": "src_data_op_name" - }, - { - "str": "mm_out" - } - ] - ], - [ - "___446", - [ - "__builtin_list__", - { - "str": "dst_data_op_name" - }, - { - "str": "input2" - } - ] - ], - [ - "___447", - [ - "__builtin_list__", - "___445", - "___446" - ] - ], - [ - "___448", - [ - "__builtin_list__" - ] - ], - [ - "___449", - [ - "__builtin_PackedArgs__", - "___448", - "___447" - ] - ], - [ - "___450", - [ - "RemoveDataOpPairPass", - "___449" - ] - ], - [ - "remove_data_op_pair_pass", - [ - "__builtin_identity__", - "___450" - ] - ], - [ - "___451", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___452", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___453", - [ - "___452", - "remove_data_op_pair_pass" - ] - ], - [ - "___454", - [ - "___451", - "___453" - ] - ], - [ - "___455", - [ - "__builtin_identity__", - "___454" - ] - ], - [ - "___456", - [ - "__builtin_list__", - { - "str": "src_data_op_name" - }, - { - "str": "mm_out" - } - ] - ], - [ - "___457", - [ - "__builtin_list__", - { - "str": "dst_data_op_name" - }, - { - "str": "input2" - } - ] - ], - [ - "___458", - [ - "__builtin_list__", - "___456", - "___457" - ] - ], - [ - "___459", - [ - "__builtin_list__" - ] - ], - [ - "___460", - [ - "__builtin_PackedArgs__", - "___459", - "___458" - ] - ], - [ - "___461", - [ - "RemoveDataOp2SumOp2DataOpPass", - "___460" - ] - ], - [ - "remove_data_op2sum_op2data_op_pass", - [ - "__builtin_identity__", - "___461" - ] - ], - [ - "___462", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___463", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___464", - [ - "___463", - "remove_data_op2sum_op2data_op_pass" - ] - ], - [ - "___465", - [ - "___462", - "___464" - ] - ], - [ - "___466", - [ - "__builtin_identity__", - "___465" - ] - ], - [ - "___467", - [ - "__builtin_list__", - { - "str": "src_data_op_name" - }, - { - "str": "mm_out" - } - ] - ], - [ - "___468", - [ - "__builtin_list__", - { - "str": "dst_data_op_name" - }, - { - "str": "output" - } - ] - ], - [ - "___469", - [ - "__builtin_list__", - "___467", - "___468" - ] - ], - [ - "___470", - [ - "__builtin_list__" - ] - ], - [ - "___471", - [ - "__builtin_PackedArgs__", - "___470", - "___469" - ] - ], - [ - "___472", - [ - "RemoveDataOpPairPass", - "___471" - ] - ], - [ - "remove_output_pass", - [ - "__builtin_identity__", - "___472" - ] - ], - [ - "___473", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___474", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___475", - [ - "___474", - "remove_output_pass" - ] - ], - [ - "___476", - [ - "___473", - "___475" - ] - ], - [ - "___477", - [ - "__builtin_identity__", - "___476" - ] - ], - [ - "___478", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___479", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_dce_pass" - } - ] - ], - [ - "___480", - [ - "___479" - ] - ], - [ - "___481", - [ - "___478", - "___480" - ] - ], - [ - "___482", - [ - "__builtin_identity__", - "___481" - ] - ], - [ - "___483", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___484", - [ - "___483", - "program" - ] - ], - [ - "___485", - [ - "__builtin_identity__", - "___484" - ] - ], - [ - "___486", - [ - "print", - { - "str": "after-remove-input-output-access_topo_pass" - }, - "program" - ] - ], - [ - "___487", - [ - "__builtin_identity__", - "___486" - ] - ], - [ - "___488", - [ - "__builtin_getattr__", - "program", - { - "str": "empty" - } - ] - ], - [ - "___489", - [ - "___488" - ] - ], - [ - "___490", - [ - "__builtin_return__", - "___489" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___491", - [ - "__builtin_getattr__", - "constraint", - { - "str": "__function__" - } - ] - ], - [ - "___492", - [ - "__builtin_list__", - { - "str": "constraint" - }, - "___491" - ] - ], - [ - "_insert_load_from_global", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "program", - "input_names" - ], - [ - "__builtin_let__", - [ - [ - "___493", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___494", - [ - "___493" - ] - ], - [ - "init_pass_manager", - [ - "__builtin_identity__", - "___494" - ] - ], - [ - "AddPass", - [ - "__builtin_identity__", - [ - "lambda", - [ - "input_name" - ], - [ - "__builtin_let__", - [ - [ - "___495", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "InitNaiveLoadFromGlobalAccessTopoPass" - } - ] - ], - [ - "___496", - [ - "___495", - "input_name" - ] - ], - [ - "ir_pass", - [ - "__builtin_identity__", - "___496" - ] - ], - [ - "___497", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___498", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___499", - [ - "___498", - "ir_pass" - ] - ], - [ - "___500", - [ - "___497", - "___499" - ] - ], - [ - "___501", - [ - "__builtin_identity__", - "___500" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___502", - [ - "map", - "AddPass", - "input_names" - ] - ], - [ - "___503", - [ - "__builtin_identity__", - "___502" - ] - ], - [ - "___504", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "run" - } - ] - ], - [ - "___505", - [ - "___504", - "program" - ] - ], - [ - "___506", - [ - "__builtin_identity__", - "___505" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___507", - [ - "__builtin_getattr__", - "_insert_load_from_global", - { - "str": "__function__" - } - ] - ], - [ - "___508", - [ - "__builtin_list__", - { - "str": "_insert_load_from_global" - }, - "___507" - ] - ], - [ - "_insert_store_to_global", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "program", - "output_names" - ], - [ - "__builtin_let__", - [ - [ - "___509", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___510", - [ - "___509" - ] - ], - [ - "init_pass_manager", - [ - "__builtin_identity__", - "___510" - ] - ], - [ - "___511", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "FakeDataStoreToGlobalForYieldAccessTopoPass" - } - ] - ], - [ - "___512", - [ - "___511", - "output_names" - ] - ], - [ - "ir_pass", - [ - "__builtin_identity__", - "___512" - ] - ], - [ - "___513", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___514", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___515", - [ - "___514", - "ir_pass" - ] - ], - [ - "___516", - [ - "___513", - "___515" - ] - ], - [ - "___517", - [ - "__builtin_identity__", - "___516" - ] - ], - [ - "___518", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "run" - } - ] - ], - [ - "___519", - [ - "___518", - "program" - ] - ], - [ - "___520", - [ - "__builtin_identity__", - "___519" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___521", - [ - "__builtin_getattr__", - "_insert_store_to_global", - { - "str": "__function__" - } - ] - ], - [ - "___522", - [ - "__builtin_list__", - { - "str": "_insert_store_to_global" - }, - "___521" - ] - ], - [ - "_make_kernel_arg_translator", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self" - ], - [ - "__builtin_let__", - [ - [ - "___523", - [ - "__builtin_getattr__", - "matmul_binary_tpl", - { - "str": "make_kernel_arg_translator" - } - ] - ], - [ - "___524", - [ - "___523" - ] - ], - [ - "___525", - [ - "__builtin_return__", - "___524" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___526", - [ - "__builtin_getattr__", - "_make_kernel_arg_translator", - { - "str": "__function__" - } - ] - ], - [ - "___527", - [ - "__builtin_list__", - { - "str": "_make_kernel_arg_translator" - }, - "___526" - ] - ], - [ - "_apply_topo_access_passes", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "mut_program", - "anchor_data_op_name" - ], - [ - "__builtin_let__", - [ - [ - "___528", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___529", - [ - "___528" - ] - ], - [ - "init_pass_manager", - [ - "__builtin_identity__", - "___529" - ] - ], - [ - "___530", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "InitDownSpiderAccessTopoPass" - } - ] - ], - [ - "___531", - [ - "___530", - "anchor_data_op_name" - ] - ], - [ - "init_down_spider", - [ - "__builtin_identity__", - "___531" - ] - ], - [ - "___532", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___533", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___534", - [ - "___533", - "init_down_spider" - ] - ], - [ - "___535", - [ - "___532", - "___534" - ] - ], - [ - "___536", - [ - "__builtin_identity__", - "___535" - ] - ], - [ - "___537", - [ - "__builtin_getattr__", - "init_pass_manager", - { - "str": "run" - } - ] - ], - [ - "___538", - [ - "___537", - "mut_program" - ] - ], - [ - "___539", - [ - "__builtin_identity__", - "___538" - ] - ], - [ - "___540", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___541", - [ - "___540" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___541" - ] - ], - [ - "___542", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___543", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_pass" - } - ] - ], - [ - "___544", - [ - "___543", - { - "str": "default" - } - ] - ], - [ - "___545", - [ - "___542", - "___544" - ] - ], - [ - "___546", - [ - "__builtin_identity__", - "___545" - ] - ], - [ - "___547", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___548", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_dce_pass" - } - ] - ], - [ - "___549", - [ - "___548" - ] - ], - [ - "___550", - [ - "___547", - "___549" - ] - ], - [ - "___551", - [ - "__builtin_identity__", - "___550" - ] - ], - [ - "___552", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___553", - [ - "___552", - "mut_program" - ] - ], - [ - "___554", - [ - "__builtin_identity__", - "___553" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___555", - [ - "__builtin_getattr__", - "_apply_topo_access_passes", - { - "str": "__function__" - } - ] - ], - [ - "___556", - [ - "__builtin_list__", - { - "str": "_apply_topo_access_passes" - }, - "___555" - ] - ], - [ - "_simplify_index_program", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "mut_program" - ], - [ - "__builtin_let__", - [ - [ - "___557", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___558", - [ - "___557" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___558" - ] - ], - [ - "___559", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "ConvertUpSpiderStoreDataOpToYieldOpPass" - } - ] - ], - [ - "___560", - [ - "___559" - ] - ], - [ - "drr_pass", - [ - "__builtin_identity__", - "___560" - ] - ], - [ - "___561", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___562", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___563", - [ - "___562", - "drr_pass" - ] - ], - [ - "___564", - [ - "___561", - "___563" - ] - ], - [ - "___565", - [ - "__builtin_identity__", - "___564" - ] - ], - [ - "___566", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___567", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_dce_pass" - } - ] - ], - [ - "___568", - [ - "___567" - ] - ], - [ - "___569", - [ - "___566", - "___568" - ] - ], - [ - "___570", - [ - "__builtin_identity__", - "___569" - ] - ], - [ - "___571", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___572", - [ - "___571", - "mut_program" - ] - ], - [ - "___573", - [ - "__builtin_identity__", - "___572" - ] - ], - [ - "___574", - [ - "__builtin_return__", - "mut_program" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___575", - [ - "__builtin_getattr__", - "_simplify_index_program", - { - "str": "__function__" - } - ] - ], - [ - "___576", - [ - "__builtin_list__", - { - "str": "_simplify_index_program" - }, - "___575" - ] - ], - [ - "_make_index_func_unique_id2index_program", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "compute_program", - "anchor_data_op_name", - "input_names", - "output_names" - ], - [ - "__builtin_let__", - [ - [ - "___577", - [ - "__builtin_getattr__", - "compute_program", - { - "str": "clone" - } - ] - ], - [ - "___578", - [ - "___577" - ] - ], - [ - "full_index_program", - [ - "__builtin_identity__", - "___578" - ] - ], - [ - "___579", - [ - "__builtin_getattr__", - "self", - { - "str": "_apply_topo_access_passes" - } - ] - ], - [ - "___580", - [ - "___579", - "full_index_program", - "anchor_data_op_name" - ] - ], - [ - "___581", - [ - "__builtin_identity__", - "___580" - ] - ], - [ - "MatchAndCopyInputIndex", - [ - "__builtin_identity__", - [ - "lambda", - [ - "dst_input_name" - ], - [ - "__builtin_let__", - [ - [ - "___582", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___583", - [ - "___582" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___583" - ] - ], - [ - "___584", - [ - "MutableList" - ] - ], - [ - "removed_programs", - [ - "__builtin_identity__", - "___584" - ] - ], - [ - "___585", - [ - "__builtin_list__", - { - "str": "src_data_op_name" - }, - "anchor_data_op_name" - ] - ], - [ - "___586", - [ - "__builtin_list__", - { - "str": "dst_load_from_global_op_name" - }, - "dst_input_name" - ] - ], - [ - "___587", - [ - "__builtin_list__", - "___585", - "___586" - ] - ], - [ - "___588", - [ - "__builtin_list__" - ] - ], - [ - "___589", - [ - "__builtin_PackedArgs__", - "___588", - "___587" - ] - ], - [ - "___590", - [ - "RemoveElementInputIndexPass", - "___589" - ] - ], - [ - "rm_elementwise_drr_pass", - [ - "__builtin_identity__", - "___590" - ] - ], - [ - "___591", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___592", - [ - "__builtin_list__", - { - "str": "matched_pattern_mut_list" - }, - "removed_programs" - ] - ], - [ - "___593", - [ - "__builtin_list__", - "___592" - ] - ], - [ - "___594", - [ - "__builtin_list__", - "rm_elementwise_drr_pass" - ] - ], - [ - "___595", - [ - "__builtin_PackedArgs__", - "___594", - "___593" - ] - ], - [ - "___596", - [ - "___591", - "___595" - ] - ], - [ - "rm_elementwise_ir_pass", - [ - "__builtin_identity__", - "___596" - ] - ], - [ - "___597", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___598", - [ - "___597", - "rm_elementwise_ir_pass" - ] - ], - [ - "___599", - [ - "__builtin_identity__", - "___598" - ] - ], - [ - "___600", - [ - "__builtin_list__", - { - "str": "src_data_op_name" - }, - "anchor_data_op_name" - ] - ], - [ - "___601", - [ - "__builtin_list__", - { - "str": "dst_load_from_global_op_name" - }, - "dst_input_name" - ] - ], - [ - "___602", - [ - "__builtin_list__", - "___600", - "___601" - ] - ], - [ - "___603", - [ - "__builtin_list__" - ] - ], - [ - "___604", - [ - "__builtin_PackedArgs__", - "___603", - "___602" - ] - ], - [ - "___605", - [ - "RemoveBroadcastInputIndexPass", - "___604" - ] - ], - [ - "rm_broadcast_drr_pass", - [ - "__builtin_identity__", - "___605" - ] - ], - [ - "___606", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___607", - [ - "__builtin_list__", - { - "str": "matched_pattern_mut_list" - }, - "removed_programs" - ] - ], - [ - "___608", - [ - "__builtin_list__", - "___607" - ] - ], - [ - "___609", - [ - "__builtin_list__", - "rm_broadcast_drr_pass" - ] - ], - [ - "___610", - [ - "__builtin_PackedArgs__", - "___609", - "___608" - ] - ], - [ - "___611", - [ - "___606", - "___610" - ] - ], - [ - "rm_broadcast_ir_pass", - [ - "__builtin_identity__", - "___611" - ] - ], - [ - "___612", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___613", - [ - "___612", - "rm_broadcast_ir_pass" - ] - ], - [ - "___614", - [ - "__builtin_identity__", - "___613" - ] - ], - [ - "___615", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___616", - [ - "___615", - "full_index_program" - ] - ], - [ - "___617", - [ - "__builtin_identity__", - "___616" - ] - ], - [ - "Converter", - [ - "__builtin_identity__", - [ - "lambda", - [ - "program" - ], - [ - "__builtin_let__", - [ - [ - "___618", - [ - "__builtin_getattr__", - "self", - { - "str": "_simplify_index_program" - } - ] - ], - [ - "___619", - [ - "___618", - "program" - ] - ], - [ - "___620", - [ - "__builtin_list__", - "dst_input_name", - "___619" - ] - ], - [ - "___621", - [ - "__builtin_return__", - "___620" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___622", - [ - "map", - "Converter", - "removed_programs" - ] - ], - [ - "___623", - [ - "__builtin_return__", - "___622" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___624", - [ - "flat_map", - "MatchAndCopyInputIndex", - "input_names" - ] - ], - [ - "input_and_index_programs", - [ - "__builtin_identity__", - "___624" - ] - ], - [ - "MatchAndCopyOutputIndex", - [ - "__builtin_identity__", - [ - "lambda", - [ - "dst_output_name" - ], - [ - "__builtin_let__", - [ - [ - "___625", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___626", - [ - "___625" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___626" - ] - ], - [ - "___627", - [ - "MutableList" - ] - ], - [ - "removed_programs", - [ - "__builtin_identity__", - "___627" - ] - ], - [ - "___628", - [ - "__builtin_list__", - { - "str": "src_data_op_name" - }, - "anchor_data_op_name" - ] - ], - [ - "___629", - [ - "__builtin_list__", - { - "str": "dst_store_to_global_op_name" - }, - "dst_output_name" - ] - ], - [ - "___630", - [ - "__builtin_list__", - "___628", - "___629" - ] - ], - [ - "___631", - [ - "__builtin_list__" - ] - ], - [ - "___632", - [ - "__builtin_PackedArgs__", - "___631", - "___630" - ] - ], - [ - "___633", - [ - "RemoveOutputIndexPass", - "___632" - ] - ], - [ - "drr_pass", - [ - "__builtin_identity__", - "___633" - ] - ], - [ - "___634", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___635", - [ - "__builtin_list__", - { - "str": "matched_pattern_mut_list" - }, - "removed_programs" - ] - ], - [ - "___636", - [ - "__builtin_list__", - "___635" - ] - ], - [ - "___637", - [ - "__builtin_list__", - "drr_pass" - ] - ], - [ - "___638", - [ - "__builtin_PackedArgs__", - "___637", - "___636" - ] - ], - [ - "___639", - [ - "___634", - "___638" - ] - ], - [ - "ir_pass", - [ - "__builtin_identity__", - "___639" - ] - ], - [ - "___640", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___641", - [ - "___640", - "ir_pass" - ] - ], - [ - "___642", - [ - "__builtin_identity__", - "___641" - ] - ], - [ - "___643", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___644", - [ - "___643", - "full_index_program" - ] - ], - [ - "___645", - [ - "__builtin_identity__", - "___644" - ] - ], - [ - "Converter", - [ - "__builtin_identity__", - [ - "lambda", - [ - "program" - ], - [ - "__builtin_let__", - [ - [ - "___646", - [ - "__builtin_getattr__", - "self", - { - "str": "_simplify_index_program" - } - ] - ], - [ - "___647", - [ - "___646", - "program" - ] - ], - [ - "___648", - [ - "__builtin_list__", - "dst_output_name", - "___647" - ] - ], - [ - "___649", - [ - "__builtin_return__", - "___648" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___650", - [ - "map", - "Converter", - "removed_programs" - ] - ], - [ - "___651", - [ - "__builtin_return__", - "___650" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___652", - [ - "flat_map", - "MatchAndCopyOutputIndex", - "output_names" - ] - ], - [ - "output_and_index_programs", - [ - "__builtin_identity__", - "___652" - ] - ], - [ - "___653", - [ - "__builtin_starred__", - "input_and_index_programs" - ] - ], - [ - "___654", - [ - "__builtin_starred__", - "output_and_index_programs" - ] - ], - [ - "___655", - [ - "__builtin_list__", - "___653", - "___654" - ] - ], - [ - "___656", - [ - "OrderedDict", - "___655" - ] - ], - [ - "___657", - [ - "__builtin_return__", - "___656" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___658", - [ - "__builtin_getattr__", - "_make_index_func_unique_id2index_program", - { - "str": "__function__" - } - ] - ], - [ - "___659", - [ - "__builtin_list__", - { - "str": "_make_index_func_unique_id2index_program" - }, - "___658" - ] - ], - [ - "_replace_with_load_from_register", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "mut_program", - "load_ir_value_name", - "register_var_name" - ], - [ - "__builtin_let__", - [ - [ - "___660", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___661", - [ - "___660" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___661" - ] - ], - [ - "___662", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "ReplaceWithLoadFromRegisterPass" - } - ] - ], - [ - "___663", - [ - "__builtin_list__", - { - "str": "name" - }, - "load_ir_value_name" - ] - ], - [ - "___664", - [ - "__builtin_list__", - { - "str": "register_var_name" - }, - "register_var_name" - ] - ], - [ - "___665", - [ - "__builtin_list__", - "___663", - "___664" - ] - ], - [ - "___666", - [ - "__builtin_list__" - ] - ], - [ - "___667", - [ - "__builtin_PackedArgs__", - "___666", - "___665" - ] - ], - [ - "___668", - [ - "___662", - "___667" - ] - ], - [ - "drr_pass", - [ - "__builtin_identity__", - "___668" - ] - ], - [ - "___669", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___670", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___671", - [ - "___670", - "drr_pass" - ] - ], - [ - "___672", - [ - "___669", - "___671" - ] - ], - [ - "___673", - [ - "__builtin_identity__", - "___672" - ] - ], - [ - "___674", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___675", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_dce_pass" - } - ] - ], - [ - "___676", - [ - "___675" - ] - ], - [ - "___677", - [ - "___674", - "___676" - ] - ], - [ - "___678", - [ - "__builtin_identity__", - "___677" - ] - ], - [ - "___679", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___680", - [ - "___679", - "mut_program" - ] - ], - [ - "___681", - [ - "__builtin_identity__", - "___680" - ] - ], - [ - "___682", - [ - "__builtin_return__", - "mut_program" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___683", - [ - "__builtin_getattr__", - "_replace_with_load_from_register", - { - "str": "__function__" - } - ] - ], - [ - "___684", - [ - "__builtin_list__", - { - "str": "_replace_with_load_from_register" - }, - "___683" - ] - ], - [ - "_replace_with_store_to_register", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "mut_program", - "store_ir_value_name", - "register_var_name" - ], - [ - "__builtin_let__", - [ - [ - "___685", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_pass_manager" - } - ] - ], - [ - "___686", - [ - "___685" - ] - ], - [ - "pass_manager", - [ - "__builtin_identity__", - "___686" - ] - ], - [ - "___687", - [ - "__builtin_getattr__", - "topo_drr_pass", - { - "str": "ReplaceWithStoreToRegisterPass" - } - ] - ], - [ - "___688", - [ - "__builtin_list__", - { - "str": "name" - }, - "store_ir_value_name" - ] - ], - [ - "___689", - [ - "__builtin_list__", - { - "str": "register_var_name" - }, - "register_var_name" - ] - ], - [ - "___690", - [ - "__builtin_list__", - "___688", - "___689" - ] - ], - [ - "___691", - [ - "__builtin_list__" - ] - ], - [ - "___692", - [ - "__builtin_PackedArgs__", - "___691", - "___690" - ] - ], - [ - "___693", - [ - "___687", - "___692" - ] - ], - [ - "drr_pass", - [ - "__builtin_identity__", - "___693" - ] - ], - [ - "___694", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___695", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_access_topo_drr_one_step_pass" - } - ] - ], - [ - "___696", - [ - "___695", - "drr_pass" - ] - ], - [ - "___697", - [ - "___694", - "___696" - ] - ], - [ - "___698", - [ - "__builtin_identity__", - "___697" - ] - ], - [ - "___699", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "add_pass" - } - ] - ], - [ - "___700", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "create_dce_pass" - } - ] - ], - [ - "___701", - [ - "___700" - ] - ], - [ - "___702", - [ - "___699", - "___701" - ] - ], - [ - "___703", - [ - "__builtin_identity__", - "___702" - ] - ], - [ - "___704", - [ - "__builtin_getattr__", - "pass_manager", - { - "str": "run" - } - ] - ], - [ - "___705", - [ - "___704", - "mut_program" - ] - ], - [ - "___706", - [ - "__builtin_identity__", - "___705" - ] - ], - [ - "___707", - [ - "__builtin_return__", - "mut_program" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___708", - [ - "__builtin_getattr__", - "_replace_with_store_to_register", - { - "str": "__function__" - } - ] - ], - [ - "___709", - [ - "__builtin_list__", - { - "str": "_replace_with_store_to_register" - }, - "___708" - ] - ], - [ - "_get_program_translator", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "ctx", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___710", - [ - "__builtin_getattr__", - "ir_tools", - { - "str": "copy_fused_ops_to_program" - } - ] - ], - [ - "___711", - [ - "__builtin_getattr__", - "o", - { - "str": "trivial_op" - } - ] - ], - [ - "___712", - [ - "__builtin_list__", - { - "str": "tensor_match_ctx" - }, - "t" - ] - ], - [ - "___713", - [ - "__builtin_list__", - "___712" - ] - ], - [ - "___714", - [ - "__builtin_list__", - "___711" - ] - ], - [ - "___715", - [ - "__builtin_PackedArgs__", - "___714", - "___713" - ] - ], - [ - "___716", - [ - "___710", - "___715" - ] - ], - [ - "mut_program", - [ - "__builtin_identity__", - "___716" - ] - ], - [ - "___717", - [ - "__builtin_getattr__", - "self", - { - "str": "_insert_load_from_global" - } - ] - ], - [ - "___718", - [ - "__builtin_list__", - { - "str": "mm_out" - }, - { - "str": "input2" - } - ] - ], - [ - "___719", - [ - "__builtin_list__", - { - "str": "input_names" - }, - "___718" - ] - ], - [ - "___720", - [ - "__builtin_list__", - "___719" - ] - ], - [ - "___721", - [ - "__builtin_list__", - "mut_program" - ] - ], - [ - "___722", - [ - "__builtin_PackedArgs__", - "___721", - "___720" - ] - ], - [ - "___723", - [ - "___717", - "___722" - ] - ], - [ - "___724", - [ - "__builtin_identity__", - "___723" - ] - ], - [ - "___725", - [ - "__builtin_getattr__", - "self", - { - "str": "_insert_store_to_global" - } - ] - ], - [ - "___726", - [ - "__builtin_list__", - { - "str": "output" - } - ] - ], - [ - "___727", - [ - "__builtin_list__", - { - "str": "output_names" - }, - "___726" - ] - ], - [ - "___728", - [ - "__builtin_list__", - "___727" - ] - ], - [ - "___729", - [ - "__builtin_list__", - "mut_program" - ] - ], - [ - "___730", - [ - "__builtin_PackedArgs__", - "___729", - "___728" - ] - ], - [ - "___731", - [ - "___725", - "___730" - ] - ], - [ - "___732", - [ - "__builtin_identity__", - "___731" - ] - ], - [ - "___733", - [ - "__builtin_getattr__", - "self", - { - "str": "_make_kernel_arg_translator" - } - ] - ], - [ - "___734", - [ - "___733" - ] - ], - [ - "kernel_arg_translator", - [ - "__builtin_identity__", - "___734" - ] - ], - [ - "___735", - [ - "__builtin_getattr__", - "self", - { - "str": "_make_index_func_unique_id2index_program" - } - ] - ], - [ - "___736", - [ - "__builtin_list__", - { - "str": "anchor_data_op_name" - }, - { - "str": "mm_out" - } - ] - ], - [ - "___737", - [ - "__builtin_list__", - { - "str": "input2" - } - ] - ], - [ - "___738", - [ - "__builtin_list__", - { - "str": "input_names" - }, - "___737" - ] - ], - [ - "___739", - [ - "__builtin_list__" - ] - ], - [ - "___740", - [ - "__builtin_list__", - { - "str": "output_names" - }, - "___739" - ] - ], - [ - "___741", - [ - "__builtin_list__", - "___736", - "___738", - "___740" - ] - ], - [ - "___742", - [ - "__builtin_list__", - "mut_program" - ] - ], - [ - "___743", - [ - "__builtin_PackedArgs__", - "___742", - "___741" - ] - ], - [ - "___744", - [ - "___735", - "___743" - ] - ], - [ - "index_func_unique_id2index_program", - [ - "__builtin_identity__", - "___744" - ] - ], - [ - "___745", - [ - "print", - { - "str": "index_func_unique_id2index_program:\n" - }, - "index_func_unique_id2index_program" - ] - ], - [ - "___746", - [ - "__builtin_identity__", - "___745" - ] - ], - [ - "___747", - [ - "__builtin_getattr__", - "index_program_translator_util", - { - "str": "IndexProgramTranslatorMap" - } - ] - ], - [ - "___748", - [ - "__builtin_list__", - { - "str": "index_func_unique_id2index_program" - }, - "index_func_unique_id2index_program" - ] - ], - [ - "___749", - [ - "__builtin_list__", - { - "str": "kernel_arg_translator" - }, - "kernel_arg_translator" - ] - ], - [ - "___750", - [ - "__builtin_getattr__", - "matmul_binary_tpl", - { - "str": "get_anchor_iter_var_names" - } - ] - ], - [ - "___751", - [ - "___750" - ] - ], - [ - "___752", - [ - "__builtin_list__", - { - "str": "anchor_iter_var_names" - }, - "___751" - ] - ], - [ - "___753", - [ - "__builtin_list__", - "___748", - "___749", - "___752" - ] - ], - [ - "___754", - [ - "__builtin_list__" - ] - ], - [ - "___755", - [ - "__builtin_PackedArgs__", - "___754", - "___753" - ] - ], - [ - "___756", - [ - "___747", - "___755" - ] - ], - [ - "index_program_translator_map", - [ - "__builtin_identity__", - "___756" - ] - ], - [ - "___757", - [ - "__builtin_getattr__", - "self", - { - "str": "_replace_with_load_from_register" - } - ] - ], - [ - "___758", - [ - "__builtin_list__", - { - "str": "load_ir_value_name" - }, - { - "str": "mm_out" - } - ] - ], - [ - "___759", - [ - "__builtin_list__", - { - "str": "register_var_name" - }, - { - "str": "x" - } - ] - ], - [ - "___760", - [ - "__builtin_list__", - "___758", - "___759" - ] - ], - [ - "___761", - [ - "__builtin_list__", - "mut_program" - ] - ], - [ - "___762", - [ - "__builtin_PackedArgs__", - "___761", - "___760" - ] - ], - [ - "___763", - [ - "___757", - "___762" - ] - ], - [ - "___764", - [ - "__builtin_identity__", - "___763" - ] - ], - [ - "___765", - [ - "__builtin_getattr__", - "self", - { - "str": "_replace_with_store_to_register" - } - ] - ], - [ - "___766", - [ - "__builtin_list__", - { - "str": "store_ir_value_name" - }, - { - "str": "output" - } - ] - ], - [ - "___767", - [ - "__builtin_list__", - { - "str": "register_var_name" - }, - { - "str": "out" - } - ] - ], - [ - "___768", - [ - "__builtin_list__", - "___766", - "___767" - ] - ], - [ - "___769", - [ - "__builtin_list__", - "mut_program" - ] - ], - [ - "___770", - [ - "__builtin_PackedArgs__", - "___769", - "___768" - ] - ], - [ - "___771", - [ - "___765", - "___770" - ] - ], - [ - "___772", - [ - "__builtin_identity__", - "___771" - ] - ], - [ - "___773", - [ - "print", - { - "str": "mut_program:" - }, - "mut_program" - ] - ], - [ - "___774", - [ - "__builtin_identity__", - "___773" - ] - ], - [ - "___775", - [ - "__builtin_getattr__", - "op_compute_translator_util", - { - "str": "OpComputeTranslatorFactory" - } - ] - ], - [ - "___776", - [ - "___775" - ] - ], - [ - "op_compute_translator_maker", - [ - "__builtin_identity__", - "___776" - ] - ], - [ - "___777", - [ - "__builtin_getattr__", - "program_translator_util", - { - "str": "ProgramTranslator" - } - ] - ], - [ - "___778", - [ - "__builtin_getattr__", - "mut_program", - { - "str": "copy_to_const_program_data" - } - ] - ], - [ - "___779", - [ - "___778" - ] - ], - [ - "___780", - [ - "__builtin_list__", - { - "str": "program_property" - }, - "___779" - ] - ], - [ - "___781", - [ - "__builtin_list__", - { - "str": "kernel_arg_translator" - }, - "kernel_arg_translator" - ] - ], - [ - "___782", - [ - "__builtin_list__", - { - "str": "index_program_translator_map" - }, - "index_program_translator_map" - ] - ], - [ - "___783", - [ - "__builtin_list__", - { - "str": "op_translator_maker" - }, - "op_compute_translator_maker" - ] - ], - [ - "___784", - [ - "__builtin_list__", - "___780", - "___781", - "___782", - "___783" - ] - ], - [ - "___785", - [ - "__builtin_list__" - ] - ], - [ - "___786", - [ - "__builtin_PackedArgs__", - "___785", - "___784" - ] - ], - [ - "___787", - [ - "___777", - "___786" - ] - ], - [ - "program_translator", - [ - "__builtin_identity__", - "___787" - ] - ], - [ - "___788", - [ - "__builtin_return__", - "program_translator" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___789", - [ - "__builtin_getattr__", - "_get_program_translator", - { - "str": "__function__" - } - ] - ], - [ - "___790", - [ - "__builtin_list__", - { - "str": "_get_program_translator" - }, - "___789" - ] - ], - [ - "code_gen", - [ - "__builtin_identity__", - [ - "lambda", - [ - "self", - "ctx", - "o", - "t" - ], - [ - "__builtin_let__", - [ - [ - "___791", - [ - "__builtin_getattr__", - "self", - { - "str": "_get_program_translator" - } - ] - ], - [ - "___792", - [ - "___791", - "ctx", - "o", - "t" - ] - ], - [ - "program_translator", - [ - "__builtin_identity__", - "___792" - ] - ], - [ - "___793", - [ - "__builtin_getattr__", - "kernel_arg_id_util", - { - "str": "KernelArgIdNameRegistry" - } - ] - ], - [ - "___794", - [ - "__builtin_list__", - { - "str": "code_gen_ctx" - }, - "ctx" - ] - ], - [ - "___795", - [ - "__builtin_list__", - { - "str": "tensor_match_ctx" - }, - "t" - ] - ], - [ - "___796", - [ - "__builtin_list__", - { - "str": "name_prefix" - }, - { - "str": "" - } - ] - ], - [ - "___797", - [ - "__builtin_list__", - "___794", - "___795", - "___796" - ] - ], - [ - "___798", - [ - "__builtin_list__" - ] - ], - [ - "___799", - [ - "__builtin_PackedArgs__", - "___798", - "___797" - ] - ], - [ - "___800", - [ - "___793", - "___799" - ] - ], - [ - "mut_kernel_arg_id_registry", - [ - "__builtin_identity__", - "___800" - ] - ], - [ - "___801", - [ - "__builtin_getattr__", - "matmul_binary_tpl", - { - "str": "MatmulBinaryTemplate" - } - ] - ], - [ - "___802", - [ - "__builtin_list__", - { - "str": "program_translator" - }, - "program_translator" - ] - ], - [ - "___803", - [ - "__builtin_list__", - { - "str": "mut_kernel_arg_id_registry" - }, - "mut_kernel_arg_id_registry" - ] - ], - [ - "___804", - [ - "__builtin_list__", - "___802", - "___803" - ] - ], - [ - "___805", - [ - "__builtin_list__" - ] - ], - [ - "___806", - [ - "__builtin_PackedArgs__", - "___805", - "___804" - ] - ], - [ - "___807", - [ - "___801", - "___806" - ] - ], - [ - "template_module", - [ - "__builtin_identity__", - "___807" - ] - ], - [ - "___808", - [ - "__builtin_getattr__", - "template_module", - { - "str": "compile" - } - ] - ], - [ - "___809", - [ - "__builtin_getattr__", - "ctx", - { - "str": "in_tensor_data_ptr_kernel_arg_id" - } - ] - ], - [ - "___810", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___811", - [ - "___809", - "___810" - ] - ], - [ - "___812", - [ - "__builtin_list__", - { - "str": "input_karg" - }, - "___811" - ] - ], - [ - "___813", - [ - "__builtin_getattr__", - "ctx", - { - "str": "in_tensor_data_ptr_kernel_arg_id" - } - ] - ], - [ - "___814", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___815", - [ - "___813", - "___814" - ] - ], - [ - "___816", - [ - "__builtin_list__", - { - "str": "weight_karg" - }, - "___815" - ] - ], - [ - "___817", - [ - "__builtin_getattr__", - "ctx", - { - "str": "out_tensor_data_ptr_kernel_arg_id" - } - ] - ], - [ - "___818", - [ - "__builtin_getattr__", - "t", - { - "str": "output" - } - ] - ], - [ - "___819", - [ - "___817", - "___818" - ] - ], - [ - "___820", - [ - "__builtin_list__", - { - "str": "output_karg" - }, - "___819" - ] - ], - [ - "___821", - [ - "__builtin_getattr__", - "ctx", - { - "str": "dim_expr_kernel_arg_id" - } - ] - ], - [ - "___823", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___822", - [ - "__builtin_getattr__", - "___823", - { - "str": "symbolic_shape_to_list" - } - ] - ], - [ - "___824", - [ - "___822" - ] - ], - [ - "___825", - [ - "__builtin_getitem__", - "___824", - 0 - ] - ], - [ - "___826", - [ - "___821", - "___825" - ] - ], - [ - "___827", - [ - "__builtin_list__", - { - "str": "m_karg" - }, - "___826" - ] - ], - [ - "___828", - [ - "__builtin_getattr__", - "ctx", - { - "str": "dim_expr_kernel_arg_id" - } - ] - ], - [ - "___830", - [ - "__builtin_getattr__", - "t", - { - "str": "input1" - } - ] - ], - [ - "___829", - [ - "__builtin_getattr__", - "___830", - { - "str": "symbolic_shape_to_list" - } - ] - ], - [ - "___831", - [ - "___829" - ] - ], - [ - "___832", - [ - "__builtin_getitem__", - "___831", - 1 - ] - ], - [ - "___833", - [ - "___828", - "___832" - ] - ], - [ - "___834", - [ - "__builtin_list__", - { - "str": "n_karg" - }, - "___833" - ] - ], - [ - "___835", - [ - "__builtin_getattr__", - "ctx", - { - "str": "dim_expr_kernel_arg_id" - } - ] - ], - [ - "___837", - [ - "__builtin_getattr__", - "t", - { - "str": "input0" - } - ] - ], - [ - "___836", - [ - "__builtin_getattr__", - "___837", - { - "str": "symbolic_shape_to_list" - } - ] - ], - [ - "___838", - [ - "___836" - ] - ], - [ - "___839", - [ - "__builtin_getitem__", - "___838", - 1 - ] - ], - [ - "___840", - [ - "___835", - "___839" - ] - ], - [ - "___841", - [ - "__builtin_list__", - { - "str": "k_karg" - }, - "___840" - ] - ], - [ - "___842", - [ - "__builtin_list__", - "___812", - "___816", - "___820", - "___827", - "___834", - "___841" - ] - ], - [ - "___843", - [ - "__builtin_list__" - ] - ], - [ - "___844", - [ - "__builtin_PackedArgs__", - "___843", - "___842" - ] - ], - [ - "___845", - [ - "___808", - "___844" - ] - ], - [ - "___846", - [ - "__builtin_return__", - "___845" - ] - ] - ], - [ - "__builtin_identity__", - null - ] - ] - ] - ] - ], - [ - "___847", - [ - "__builtin_getattr__", - "code_gen", - { - "str": "__function__" - } - ] - ], - [ - "___848", - [ - "__builtin_list__", - { - "str": "code_gen" - }, - "___847" - ] - ], - [ - "___849", - [ - "__builtin_list__" - ] - ], - [ - "___850", - [ - "__builtin_list__", - "___378", - "___394", - "___492", - "___508", - "___522", - "___527", - "___556", - "___576", - "___659", - "___684", - "___709", - "___790", - "___848" - ] - ], - [ - "___851", - [ - "__builtin_PackedArgs__", - "___849", - "___850" - ] - ], - [ - "___852", - [ - "BuiltinSerializableAttrMap", - "___851" - ] - ], - [ - "___853", - [ - "type", - { - "str": "MatmulBinaryFusion" - }, - "___352", - "___852" - ] - ], - [ - "___854", - [ - "__builtin_getattr__", - "abstract_drr", - { - "str": "register_drr_pass" - } - ] - ], - [ - "___855", - [ - "__builtin_list__", - { - "str": "nice" - }, - 0 - ] - ], - [ - "___856", - [ - "__builtin_list__", - "___855" - ] - ], - [ - "___857", - [ - "__builtin_list__", - { - "str": "matmul_binary_fusion" - } - ] - ], - [ - "___858", - [ - "__builtin_PackedArgs__", - "___857", - "___856" - ] - ], - [ - "___859", - [ - "___854", - "___858" - ] - ], - [ - "___860", - [ - "___859", - "___853" - ] - ], - [ - "MatmulBinaryFusion", - [ - "__builtin_identity__", - "___860" - ] - ] - ], - [ - "__builtin_identity__", - "___860" - ] -] \ No newline at end of file diff --git a/tests/ap/test_matmul_binary.sh b/tests/ap/test_matmul_binary.sh index a489a64..4e0a211 100644 --- a/tests/ap/test_matmul_binary.sh +++ b/tests/ap/test_matmul_binary.sh @@ -1,6 +1,6 @@ export CUDA_VISIBLE_DEVICES="0" export NVIDIA_TF32_OVERRIDE=0 -sh make_axpr.sh test_matmul_binary +sh make_axpr.sh FLAGS_enable_ap=1 AP_WORKSPACE_DIR=$(pwd)/ap_workspace AP_PATH=$(pwd)/ FLAGS_check_infer_symbolic=1 FLAGS_enable_pir_api=1 FLAGS_cinn_bucket_compile=True FLAGS_prim_enable_dynamic=true FLAGS_prim_all=True FLAGS_pir_apply_shape_optimization_pass=1 FLAGS_group_schedule_tiling_first=1 FLAGS_cinn_new_group_scheduler=1 python3.9 $(pwd)/paddle-tests/test_matmul_binary.py From 5959c33e93b7cb066ae05a6021b07bce41bbbba7 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 28 Feb 2025 23:07:38 +0800 Subject: [PATCH 5/5] Support autotune in AP generated kernels. --- tests/ap/matmul/all_tuning_configs.h | 12 +- tests/ap/matmul/generate_configs.py | 10 +- tests/ap/matmul/matmul.h | 14 +- tests/ap/matmul/tests/matmul_binary_kernel.cu | 8 +- tests/ap/matmul/tests/matmul_kernel.cu | 6 +- tests/ap/matmul/tests/matmul_unary_kernel.cu | 6 +- tests/ap/matmul_variadic_tpl.py | 127 +++++++++++++----- 7 files changed, 129 insertions(+), 54 deletions(-) diff --git a/tests/ap/matmul/all_tuning_configs.h b/tests/ap/matmul/all_tuning_configs.h index 7b6b4c1..55f97dc 100644 --- a/tests/ap/matmul/all_tuning_configs.h +++ b/tests/ap/matmul/all_tuning_configs.h @@ -20,30 +20,30 @@ template struct SwizzleWrapper { // cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; // }; -#define AP_AUTOTUNE_FP16(func) \ +#define AP_AUTOTUNE_half(func) \ { \ static int selected_config_id = -1; \ - static std::vector> \ + static std::vector> \ matmul_functions = {func<0>, func<1>, func<2>, func<3>, func<4>, \ func<5>, func<6>, func<7>, func<8>, func<9>, \ func<10>, func<11>, func<12>, func<13>, func<14>, \ func<15>, func<16>, func<17>, func<18>, func<19>, \ func<20>, func<21>, func<22>}; \ if (selected_config_id == -1) { \ - selected_config_id = ProfileBestConfig(matmul_functions, params); \ + selected_config_id = ap::ProfileBestConfig(matmul_functions, params); \ } \ matmul_functions[selected_config_id](params); \ } -#define AP_AUTOTUNE_FP32(func) \ +#define AP_AUTOTUNE_float(func) \ { \ static int selected_config_id = -1; \ - static std::vector> \ + static std::vector> \ matmul_functions = {func<0>, func<1>, func<2>, func<3>, func<4>, \ func<5>, func<6>, func<7>, func<8>, func<9>, \ func<10>, func<11>, func<12>}; \ if (selected_config_id == -1) { \ - selected_config_id = ProfileBestConfig(matmul_functions, params); \ + selected_config_id = ap::ProfileBestConfig(matmul_functions, params); \ } \ matmul_functions[selected_config_id](params); \ } diff --git a/tests/ap/matmul/generate_configs.py b/tests/ap/matmul/generate_configs.py index f4e83f8..a62a80d 100644 --- a/tests/ap/matmul/generate_configs.py +++ b/tests/ap/matmul/generate_configs.py @@ -24,12 +24,12 @@ autotune_wrapper_template = """ #define AP_AUTOTUNE_${datatype}(func) { \\ static int selected_config_id = -1; \\ - static std::vector> \\ + static std::vector> \\ matmul_functions = { \\ ${repeat_functions} \\ }; \\ if (selected_config_id == -1) { \\ - selected_config_id = ProfileBestConfig(matmul_functions, params); \\ + selected_config_id = ap::ProfileBestConfig(matmul_functions, params); \\ } \\ matmul_functions[selected_config_id](params); \\ } @@ -260,8 +260,10 @@ def main(): head_code_str = head_template.replace( "${num_configs_fp16}", str(num_fp16_configs) ).replace("${num_configs_fp32}", str(num_fp32_configs)) - fp16_autotune_wrapper_code_str = generate_autotune_wrapper("FP16", num_fp16_configs) - fp32_autotune_wrapper_code_str = generate_autotune_wrapper("FP32", num_fp32_configs) + fp16_autotune_wrapper_code_str = generate_autotune_wrapper("half", num_fp16_configs) + fp32_autotune_wrapper_code_str = generate_autotune_wrapper( + "float", num_fp32_configs + ) with open("all_tuning_configs.h", "w") as f: f.write(head_code_str) f.write(fp16_autotune_wrapper_code_str) diff --git a/tests/ap/matmul/matmul.h b/tests/ap/matmul/matmul.h index 01b2a99..d22bb98 100644 --- a/tests/ap/matmul/matmul.h +++ b/tests/ap/matmul/matmul.h @@ -81,6 +81,8 @@ struct GemmEpilogueParams { cudaStream_t stream; + std::vector input0_shape; + std::vector input1_shape; std::vector epilogue_in_ptrs; std::vector> epilogue_in_shapes; @@ -96,6 +98,9 @@ struct GemmEpilogueParams { ASSERT_CHECK(input_shape.size() >= 2U); ASSERT_CHECK(weight_shape.size() >= 2U); + input0_shape = input_shape; + input1_shape = weight_shape; + batch_count = 1; for (size_t i = 0; i < input_shape.size() - 2; ++i) { batch_count *= input_shape[i]; @@ -146,8 +151,13 @@ struct GemmEpilogueParams { shape_args.ldc_bias = (!bias || is_C_bias) ? 0 : n; } - void SetEpilogues(const std::vector &in_ptrs, - const std::vector> &in_shapes) { + void SetEpilogues(const std::vector &in_ptrs) { + epilogue_in_ptrs = in_ptrs; + } + + void + SetEpilogueAndShapes(const std::vector &in_ptrs, + const std::vector> &in_shapes) { ASSERT_CHECK(in_ptrs.size() == in_shapes.size()); epilogue_in_ptrs = in_ptrs; epilogue_in_shapes = in_shapes; diff --git a/tests/ap/matmul/tests/matmul_binary_kernel.cu b/tests/ap/matmul/tests/matmul_binary_kernel.cu index 2a51c5f..1044c55 100644 --- a/tests/ap/matmul/tests/matmul_binary_kernel.cu +++ b/tests/ap/matmul/tests/matmul_binary_kernel.cu @@ -50,13 +50,13 @@ void MatmulAddBinaryKernel( const std::vector> &epilogue_shapes) { GemmEpilogueParams params(*stream, input, weight, bias, output, input_shape, weight_shape, bias_shape); - params.SetEpilogues(epilogue_ins, epilogue_shapes); + params.SetEpilogueAndShapes(epilogue_ins, epilogue_shapes); -#if AP_ENABLE_AUTO_TUNING +#if AP_ENABLE_AUTOTUNE #if AP_USE_FLOAT16 - AP_AUTOTUNE_FP16(RunMatmulAddBinaryKernel); + AP_AUTOTUNE_half(RunMatmulAddBinaryKernel); #else - AP_AUTOTUNE_FP32(RunMatmulAddBinaryKernel); + AP_AUTOTUNE_float(RunMatmulAddBinaryKernel); #endif #else RunMatmulAddBinaryKernel(params); diff --git a/tests/ap/matmul/tests/matmul_kernel.cu b/tests/ap/matmul/tests/matmul_kernel.cu index 5d342e5..b926e66 100644 --- a/tests/ap/matmul/tests/matmul_kernel.cu +++ b/tests/ap/matmul/tests/matmul_kernel.cu @@ -32,11 +32,11 @@ void MatmulKernel(cudaStream_t *stream, const void *input, const void *weight, input_shape, weight_shape, std::vector{}, false, transpose_b); -#if AP_ENABLE_AUTO_TUNING +#if AP_ENABLE_AUTOTUNE #if AP_USE_FLOAT16 - AP_AUTOTUNE_FP16(RunMatmulKernel); + AP_AUTOTUNE_half(RunMatmulKernel); #else - AP_AUTOTUNE_FP32(RunMatmulKernel); + AP_AUTOTUNE_float(RunMatmulKernel); #endif #else RunMatmulKernel(params); diff --git a/tests/ap/matmul/tests/matmul_unary_kernel.cu b/tests/ap/matmul/tests/matmul_unary_kernel.cu index c4f16ea..ec6e168 100644 --- a/tests/ap/matmul/tests/matmul_unary_kernel.cu +++ b/tests/ap/matmul/tests/matmul_unary_kernel.cu @@ -41,11 +41,11 @@ void MatmulAddUnaryKernel(cudaStream_t *stream, const void *input, GemmEpilogueParams params(*stream, input, weight, bias, output, input_shape, weight_shape, bias_shape, false, transpose_b); -#if AP_ENABLE_AUTO_TUNING +#if AP_ENABLE_AUTOTUNE #if AP_USE_FLOAT16 - AP_AUTOTUNE_FP16(RunMatmulAddUnaryKernel); + AP_AUTOTUNE_half(RunMatmulAddUnaryKernel); #else - AP_AUTOTUNE_FP32(RunMatmulAddUnaryKernel); + AP_AUTOTUNE_float(RunMatmulAddUnaryKernel); #endif #else RunMatmulAddUnaryKernel(params); diff --git a/tests/ap/matmul_variadic_tpl.py b/tests/ap/matmul_variadic_tpl.py index 944a096..1e7c1c0 100644 --- a/tests/ap/matmul_variadic_tpl.py +++ b/tests/ap/matmul_variadic_tpl.py @@ -10,6 +10,11 @@ def get_anchor_iter_var_names(): return ["coord.batch", "coord.row", "coord.column"] +def is_in_tensor_karg(kernel_arg_id): + kernel_arg_id_type_name = f"{type(kernel_arg_id)}".replace("", "") + return kernel_arg_id_type_name == "InTensorDataPtrKernelArgId" + + class MatmulVariadicTemplate: def __init__( self, @@ -30,6 +35,10 @@ def __init__( [DataType.int64_t, "int64_t"], ] ) + self.input_dim_karg_to_shape_access = MutableOrderedDict() + self.input_tensor_karg_to_shape_access = MutableOrderedDict() + self.kernel_name = "MatmulVariadicKernel" + self.library_name = "matmul_variadic_kernel" def _register_name(self, pair): registry = self.mut_kernel_arg_id_registry @@ -142,14 +151,23 @@ def declare_epilogue_arguments_field(pair): map(declare_epilogue_arguments_field, generated_kernel_arg_id_and_names) ) - def get_epilogue_arguments_init_str(self, param_obj_name, indent): + def get_epilogue_arguments_init_str(self, obj_name, params_name, output_dtype, indent): def declare_epilogue_arguments_assign(pair): kernel_arg_id = pair[0] + is_in_tensor_type = is_in_tensor_karg(kernel_arg_id) + var_name = pair[1] field_name = self.kernel_arg_translator.get_param_struct_field_name( var_name ) - return f"{param_obj_name}.{field_name} = {var_name};" + def get_in_tensor_statement(): + param_name_for_var = self.input_tensor_karg_to_shape_access[var_name] + return f"reinterpret_cast({params_name}.{param_name_for_var})" + def get_dim_expr_statement(): + param_name_for_var = self.input_dim_karg_to_shape_access[var_name] + return f"{params_name}.{param_name_for_var}" + statement = get_in_tensor_statement() if is_in_tensor_type else get_dim_expr_statement() + return f"{obj_name}.{field_name} = {statement};" generated_kernel_arg_id_and_names = ( self.mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() @@ -158,10 +176,35 @@ def declare_epilogue_arguments_assign(pair): map(declare_epilogue_arguments_assign, generated_kernel_arg_id_and_names) ) - def get_input_shape_init_str(self, input_name, input_shape_kargs, indent): + def get_params_epilogue_ptrs_init_str(self, obj_name, indent): + in_tensor_id = 0 + def declare_params_epilogue_arguments_assign(pair): + def get_creator(): + return f"{obj_name}[{in_tensor_id}]" + + kernel_arg_id = pair[0] + is_in_tensor_type = is_in_tensor_karg(kernel_arg_id) + def generate_statement(): + self.input_tensor_karg_to_shape_access.get_or_create(pair[1], get_creator) + statement = f"{obj_name}.push_back({pair[1]});" + in_tensor_id = in_tensor_id + 1 + return statement + return generate_statement() if is_in_tensor_type else "" + + generated_kernel_arg_id_and_names = ( + self.mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() + ) + return f"\n{indent}".join( + map(declare_params_epilogue_arguments_assign, generated_kernel_arg_id_and_names) + ) + + def get_params_input_shape_init_str(self, input_name, input_shape_kargs, indent): def init_input_shape_with_args(i): - karg = input_shape_kargs[i] - return f"{indent}{input_name}_shape[{i}] = {self.get_kernel_arg_id_var_name(karg)};" + def get_creator(): + return f"{input_name}_shape[{i}]" + karg_var_name = self.get_kernel_arg_id_var_name(input_shape_kargs[i]) + self.input_dim_karg_to_shape_access.get_or_create(karg_var_name, get_creator) + return f"{indent}{input_name}_shape[{i}] = {karg_var_name};" shape_vector_init_str = ( f"{input_name}_shape.resize({len(input_shape_kargs)});\n" @@ -186,6 +229,7 @@ def make_project( #include #include "cutlass_matmul.cuh" +#include "profile.h" namespace ap { @@ -204,32 +248,45 @@ def make_project( } }; +template +static void RunMatmulWithVariadicKernel(const GemmEpilogueParams ¶ms) { + using ElementT = ${output_dtype}; + using ElementComputeT = float; + + typename ap::VariadicEpilogueFunctor::Arguments epilogue_args; + + AP_EPILOGUE_ARGUMENTS_INIT + + ap::CutlassMatmulAddVariadic(params, epilogue_args); } +} // namespace ap + extern "C" { -void MatmulVariadicKernel(void* stream_ptr, AP_KERNEL_ARGS_DECLARE) { - std::vector $input0_shape; - AP_INPUT0_SHAPE_INIT +void ${kernel_name}(void* stream_ptr, AP_KERNEL_ARGS_DECLARE) { + std::vector ${input0}_shape; + AP_PARAMS_INPUT0_SHAPE_INIT - std::vector $input1_shape; - AP_INPUT1_SHAPE_INIT + std::vector ${input1}_shape; + AP_PARAMS_INPUT1_SHAPE_INIT cudaStream_t* cuda_stream_ptr = reinterpret_cast(stream_ptr); ap::GemmEpilogueParams params( - *cuda_stream_ptr, $input0, $input1, nullptr, $output, $input0_shape, $input1_shape, std::vector{}); + *cuda_stream_ptr, ${input0}, ${input1}, nullptr, ${output}, ${input0}_shape, ${input1}_shape, std::vector{}); - using ElementT = AP_GENERATED_ELEMENT_DTYPE; - using ElementComputeT = float; + std::vector epilogue_in_ptrs; + AP_PARAMS_EPILOGUE_PTRS_INIT - typename ap::VariadicEpilogueFunctor::Arguments epilogue_args; - - AP_EPILOGUE_ARGUMENTS_INIT + params.SetEpilogues(epilogue_in_ptrs); - ap::CutlassMatmulAddVariadic(params, epilogue_args); +#if AP_ENABLE_AUTOTUNE + AP_AUTOTUNE_${output_dtype}(ap::RunMatmulWithVariadicKernel); +#else + ap::RunMatmulWithVariadicKernel(params); +#endif } } - """ output_dtype = self.dtype2type_name[output_karg.type.data_type] @@ -237,26 +294,31 @@ def make_project( code_template.replace( "AP_GENERATED_BINARY_EPILOGUE_STRING", trivial_code_str ) - .replace("AP_GENERATED_ELEMENT_DTYPE", output_dtype) .replace("AP_KERNEL_ARGS_DECLARE", self.get_kernel_arg_list_str()) .replace( - "AP_INPUT0_SHAPE_INIT", - self.get_input_shape_init_str("$input0", input0_shape_kargs, indent=" "), + "AP_PARAMS_INPUT0_SHAPE_INIT", + self.get_params_input_shape_init_str("${input0}", input0_shape_kargs, indent=" "), + ) + .replace( + "AP_PARAMS_INPUT1_SHAPE_INIT", + self.get_params_input_shape_init_str("${input1}", input1_shape_kargs, indent=" "), ) .replace( - "AP_INPUT1_SHAPE_INIT", - self.get_input_shape_init_str("$input1", input1_shape_kargs, indent=" "), + "AP_PARAMS_EPILOGUE_PTRS_INIT", + self.get_params_epilogue_ptrs_init_str("epilogue_in_ptrs", indent=" "), ) .replace( "AP_EPILOGUE_ARGUMENTS_FIELDS", self.get_epilogue_arguments_fields_str(indent=" ") ) .replace( "AP_EPILOGUE_ARGUMENTS_INIT", - self.get_epilogue_arguments_init_str("epilogue_args", indent=" "), + self.get_epilogue_arguments_init_str("epilogue_args", "params", output_dtype, indent=" "), ) - .replace("$input0", self.get_kernel_arg_id_var_name(input0_karg)) - .replace("$input1", self.get_kernel_arg_id_var_name(input1_karg)) - .replace("$output", self.get_kernel_arg_id_var_name(output_karg)) + .replace("${kernel_name}", self.kernel_name) + .replace("${input0}", self.get_kernel_arg_id_var_name(input0_karg)) + .replace("${input1}", self.get_kernel_arg_id_var_name(input1_karg)) + .replace("${output}", self.get_kernel_arg_id_var_name(output_karg)) + .replace("${output_dtype}", output_dtype) ) source_dir = "/work/abstract_pass/Athena/tests/ap/matmul" @@ -269,26 +331,27 @@ def make_project( compile_cmd = compile_cmd + " -I " + source_dir compile_cmd = ( compile_cmd - + " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0 " + + " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0" ) + compile_cmd = compile_cmd + " -DAP_ENABLE_AUTOTUNE=1 -DAP_ENABLE_DEBUG=0" compile_cmd = ( compile_cmd - + " --shared matmul_variadic_kernel.cu -o libmatmul_variadic_kernel.so" + + f" --shared {self.library_name}.cu -o lib{self.library_name}.so" ) return CodeModule( FuncDeclare( DataType.void, - "MatmulVariadicKernel", + self.kernel_name, [PointerType.void_ptr, *self.get_kernel_arg_types()], ), Project( nested_files=Project.Directory( - ["matmul_variadic_kernel.cu", Project.FileContent(code)], + [f"{self.library_name}.cu", Project.FileContent(code)], ["make.sh", Project.FileContent(compile_cmd)], ), compile_cmd="sh make.sh", - so_relative_path="libmatmul_variadic_kernel.so", + so_relative_path=f"lib{self.library_name}.so", ), )