diff --git a/tests/ap/code_gen_value_util.py b/tests/ap/code_gen_value_util.py index 545757c..46569e6 100644 --- a/tests/ap/code_gen_value_util.py +++ b/tests/ap/code_gen_value_util.py @@ -1,14 +1,24 @@ +class TensorCodeGenValue: + def __init__(self, pir_type, var_name): + self.pir_type = pir_type + self.var_name = var_name + self.const_value = None -class CodeGenValue: - def __init__(self, pir_type, var_name): - self.pir_type = pir_type - self.var_name = var_name - self.const_value = None + def get_dtype(self): + def convert_to_dtype(pir_dtype, shape, data_layout): + return pir_dtype.convert_to_dtype() - def get_dtype(self): - def convert_to_dtype(pir_dtype, shape, data_layout): - return pir_dtype.convert_to_dtype() - return self.pir_type.match(t_dtensor=convert_to_dtype) + return self.pir_type.match(t_dtensor=convert_to_dtype) - def is_dense_tensor_type(self): - return self.pir_type.get_type_name() == "t_dtensor" + def is_dense_tensor_type(self): + return self.pir_type.get_type_name() == "t_dtensor" + + +class AttrCodeGenValue: + def __init__(self, data_type, var_name): + self.data_type = data_type + self.var_name = var_name + self.const_value = None + + def get_dtype(self): + return self.data_type diff --git a/tests/ap/op_compute_translator_util.py b/tests/ap/op_compute_translator_util.py index 16b2f5a..8b0094d 100644 --- a/tests/ap/op_compute_translator_util.py +++ b/tests/ap/op_compute_translator_util.py @@ -20,7 +20,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): def get_out_cg_val(self, i): register_var_name_attr = self.op_property.attributes.register_var_name register_var_name = register_var_name_attr.match(a_str=lambda x:x) - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, register_var_name ) @@ -55,7 +55,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -102,7 +102,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): def get_out_cg_val(self, i): name = self.op_property.attributes.name.match(a_str=lambda x:x) - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, name ) @@ -128,7 +128,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -163,7 +163,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -188,7 +188,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -213,7 +213,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -238,7 +238,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -264,7 +264,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -289,7 +289,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -308,22 +308,33 @@ def __init__(self, self.index_program_translator_map = index_program_translator_map def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): - scale = self.op_property.attributes.scale.match(a_f32=lambda x:x) - bias = self.op_property.attributes.bias.match(a_f32=lambda x:x) + scale_value = self.op_property.attributes.scale.match(a_f32=lambda x:x) + scale_var_name = f"op{self.op_property.op_index}_scale" + scale = self.get_tmp_cg_val(scale_var_name) + mut_lir_code_gen_ctx.let(scale, f"{scale_value}") + + bias_value = self.op_property.attributes.bias.match(a_f32=lambda x:x) + bias_var_name = f"op{self.op_property.op_index}_bias" + bias = self.get_tmp_cg_val(bias_var_name) + mut_lir_code_gen_ctx.let(bias, f"{bias_value}") + bias_after_scale = self.op_property.attributes.bias_after_scale.match(a_bool=lambda x:x) in_name = inputs[0].var_name - true_str = f"{scale} * {in_name} + {bias}" - false_str = f"{scale} * ({in_name} + {bias})" + true_str = f"{scale_var_name} * {in_name} + {bias_var_name}" + false_str = f"{scale_var_name} * ({in_name} + {bias_var_name})" out = self.get_out_cg_val(0) mut_lir_code_gen_ctx.let(out, true_str if bias_after_scale else false_str) return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) + def get_tmp_cg_val(self, name): + return code_gen_value_util.AttrCodeGenValue(DataType.float, f"{name}") + class PdOpSubstractCodeGen: def __init__(self, @@ -346,7 +357,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -373,7 +384,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -400,7 +411,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -427,7 +438,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -454,7 +465,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" ) @@ -511,7 +522,7 @@ def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx): return [out] def get_out_cg_val(self, i): - return code_gen_value_util.CodeGenValue( + return code_gen_value_util.TensorCodeGenValue( self.output_properties[i].type, f"op{self.op_property.op_index}_out{i}" )