Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/ap/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# import test_trivial_reduce
# import test_binary_trivial_reduce
import test_matmul_binary
import matmul_unary_pattern
import matmul_binary_pattern
32 changes: 21 additions & 11 deletions tests/ap/code_gen_value_util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
class TensorCodeGenValue:
def __init__(self, pir_type, var_name):
self.pir_type = pir_type
self.var_name = var_name
self.const_value = None

class CodeGenValue:
def __init__(self, pir_type, var_name):
self.pir_type = pir_type
self.var_name = var_name
self.const_value = None
def get_dtype(self):
def convert_to_dtype(pir_dtype, shape, data_layout):
return pir_dtype.convert_to_dtype()

def get_dtype(self):
def convert_to_dtype(pir_dtype, shape, data_layout):
return pir_dtype.convert_to_dtype()
return self.pir_type.match(t_dtensor=convert_to_dtype)
return self.pir_type.match(t_dtensor=convert_to_dtype)

def is_dense_tensor_type(self):
return self.pir_type.get_type_name() == "t_dtensor"
def is_dense_tensor_type(self):
return self.pir_type.get_type_name() == "t_dtensor"


class AttrCodeGenValue:
def __init__(self, data_type, var_name):
self.data_type = data_type
self.var_name = var_name
self.const_value = None

def get_dtype(self):
return self.data_type
5 changes: 3 additions & 2 deletions tests/ap/make_axpr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ FILENAMES_ARRAY=(
"access_topo_drr"
"abstract_drr"
"ap_tpl_codegen"
"matmul_binary_tpl"
"test_matmul_binary"
"matmul_variadic_tpl"
"matmul_unary_pattern"
"matmul_binary_pattern"
)
for filename in "${FILENAMES_ARRAY[@]}"
do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import access_topo_drr
import topo_drr_pass
import op_convertion_drr_pass
import matmul_binary_tpl
import matmul_variadic_tpl
import ir_tools
import index_program_translator_util
import op_compute_translator_util
Expand Down Expand Up @@ -250,7 +250,7 @@ def _insert_store_to_global(self, program, output_names):
init_pass_manager.run(program)

def _make_kernel_arg_translator(self):
return matmul_binary_tpl.make_kernel_arg_translator()
return matmul_variadic_tpl.make_kernel_arg_translator()

def _apply_topo_access_passes(self, mut_program, anchor_data_op_name):
init_pass_manager = ir_tools.create_pass_manager()
Expand Down Expand Up @@ -349,6 +349,7 @@ def _get_program_translator(self, ctx, o, t):
mut_program = ir_tools.copy_fused_ops_to_program(
o.trivial_op, tensor_match_ctx=t
)
print("origin-program_translator", mut_program)
self._insert_load_from_global(
mut_program,
input_names=["mm_out", "input2"]
Expand All @@ -368,7 +369,7 @@ def _get_program_translator(self, ctx, o, t):
index_program_translator_map = index_program_translator_util.IndexProgramTranslatorMap(
index_func_unique_id2index_program=index_func_unique_id2index_program,
kernel_arg_translator=kernel_arg_translator,
anchor_iter_var_names=matmul_binary_tpl.get_anchor_iter_var_names()
anchor_iter_var_names=matmul_variadic_tpl.get_anchor_iter_var_names()
)
self._replace_with_load_from_register(
mut_program,
Expand All @@ -380,7 +381,7 @@ def _get_program_translator(self, ctx, o, t):
store_ir_value_name="output",
register_var_name="out"
)
print("mut_program:", mut_program)
print("after-insert-load-store-program_translator", mut_program)
op_compute_translator_maker = op_compute_translator_util.OpComputeTranslatorFactory()
program_translator = program_translator_util.ProgramTranslator(
program_property=mut_program.copy_to_const_program_data(),
Expand All @@ -397,7 +398,7 @@ def code_gen(self, ctx, o, t):
tensor_match_ctx=t,
name_prefix=""
)
template_module = matmul_binary_tpl.MatmulBinaryTemplate(
template_module = matmul_variadic_tpl.MatmulVariadicTemplate(
program_translator=program_translator,
mut_kernel_arg_id_registry=mut_kernel_arg_id_registry,
)
Expand Down
Loading