Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions tests/ap/code_gen_value_util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
class TensorCodeGenValue:
def __init__(self, pir_type, var_name):
self.pir_type = pir_type
self.var_name = var_name
self.const_value = None

class CodeGenValue:
def __init__(self, pir_type, var_name):
self.pir_type = pir_type
self.var_name = var_name
self.const_value = None
def get_dtype(self):
def convert_to_dtype(pir_dtype, shape, data_layout):
return pir_dtype.convert_to_dtype()

def get_dtype(self):
def convert_to_dtype(pir_dtype, shape, data_layout):
return pir_dtype.convert_to_dtype()
return self.pir_type.match(t_dtensor=convert_to_dtype)
return self.pir_type.match(t_dtensor=convert_to_dtype)

def is_dense_tensor_type(self):
return self.pir_type.get_type_name() == "t_dtensor"
def is_dense_tensor_type(self):
return self.pir_type.get_type_name() == "t_dtensor"


class AttrCodeGenValue:
def __init__(self, data_type, var_name):
self.data_type = data_type
self.var_name = var_name
self.const_value = None

def get_dtype(self):
return self.data_type
53 changes: 32 additions & 21 deletions tests/ap/op_compute_translator_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
def get_out_cg_val(self, i):
register_var_name_attr = self.op_property.attributes.register_var_name
register_var_name = register_var_name_attr.match(a_str=lambda x:x)
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
register_var_name
)
Expand Down Expand Up @@ -55,7 +55,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand Down Expand Up @@ -102,7 +102,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):

def get_out_cg_val(self, i):
name = self.op_property.attributes.name.match(a_str=lambda x:x)
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
name
)
Expand All @@ -128,7 +128,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand Down Expand Up @@ -163,7 +163,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -188,7 +188,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -213,7 +213,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -238,7 +238,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -264,7 +264,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -289,7 +289,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -308,22 +308,33 @@ def __init__(self,
self.index_program_translator_map = index_program_translator_map

def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
scale = self.op_property.attributes.scale.match(a_f32=lambda x:x)
bias = self.op_property.attributes.bias.match(a_f32=lambda x:x)
scale_value = self.op_property.attributes.scale.match(a_f32=lambda x:x)
scale_var_name = f"op{self.op_property.op_index}_scale"
scale = self.get_tmp_cg_val(scale_var_name)
mut_lir_code_gen_ctx.let(scale, f"{scale_value}")

bias_value = self.op_property.attributes.bias.match(a_f32=lambda x:x)
bias_var_name = f"op{self.op_property.op_index}_bias"
bias = self.get_tmp_cg_val(bias_var_name)
mut_lir_code_gen_ctx.let(bias, f"{bias_value}")

bias_after_scale = self.op_property.attributes.bias_after_scale.match(a_bool=lambda x:x)
in_name = inputs[0].var_name
true_str = f"{scale} * {in_name} + {bias}"
false_str = f"{scale} * ({in_name} + {bias})"
true_str = f"{scale_var_name} * {in_name} + {bias_var_name}"
false_str = f"{scale_var_name} * ({in_name} + {bias_var_name})"
out = self.get_out_cg_val(0)
mut_lir_code_gen_ctx.let(out, true_str if bias_after_scale else false_str)
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)

def get_tmp_cg_val(self, name):
return code_gen_value_util.AttrCodeGenValue(DataType.float, f"{name}")


class PdOpSubstractCodeGen:
def __init__(self,
Expand All @@ -346,7 +357,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -373,7 +384,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -400,7 +411,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -427,7 +438,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand All @@ -454,7 +465,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand Down Expand Up @@ -511,7 +522,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
return [out]

def get_out_cg_val(self, i):
return code_gen_value_util.CodeGenValue(
return code_gen_value_util.TensorCodeGenValue(
self.output_properties[i].type,
f"op{self.op_property.op_index}_out{i}"
)
Expand Down