-
Notifications
You must be signed in to change notification settings - Fork 4
support arbitrary input/output and non-3-dim mm_out #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: ap
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1 @@ | ||
| # import test_trivial_reduce | ||
| # import test_binary_trivial_reduce | ||
| import test_matmul_binary | ||
| import test_matmul_epilogue |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ def get_or_create_kernel_arg_id_manul_var_name(self, kernel_arg_id, cpp_var_name | |
| return self.all_kernel_arg_id2unique_name.get_or_create(kernel_arg_id, create) | ||
|
|
||
| def get_in_tensor_data_ptr_var_name(self, in_ir_value_name): | ||
| print('in_ir_value_name: ', in_ir_value_name) | ||
| ir_value = getattr(self.tensor_match_ctx, in_ir_value_name) | ||
| kernel_arg_id = self.code_gen_ctx.in_tensor_data_ptr_kernel_arg_id(ir_value) | ||
| create = self._get_creator(kernel_arg_id, self._create_in_tensor_data_ptr_var_name) | ||
|
|
@@ -29,6 +30,7 @@ def _create_in_tensor_data_ptr_var_name(self): | |
| return name | ||
|
|
||
| def get_out_tensor_data_ptr_var_name(self, out_ir_value_name): | ||
| out_ir_value_name = out_ir_value_name.replace("out", "output") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. out_{i} 是程序的局部变量, output_{i} 是context中注册的ir_name,此命名可以相互区分。采用replace可以使得,局部变量找到全局变量。
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里没有解释为啥需要有"out", "output"字面量 |
||
| ir_value = getattr(self.tensor_match_ctx, out_ir_value_name) | ||
| kernel_arg_id = self.code_gen_ctx.out_tensor_data_ptr_kernel_arg_id(ir_value) | ||
| create = self._get_creator(kernel_arg_id, self._create_out_tensor_data_ptr_var_name) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,11 @@ def is_in_tensor_karg(kernel_arg_id): | |
| ) | ||
| return kernel_arg_id_type_name == "InTensorDataPtrKernelArgId" | ||
|
|
||
| def is_out_tensor_karg(kernel_arg_id): | ||
| kernel_arg_id_type_name = f"{type(kernel_arg_id)}".replace("<class '", "").replace( | ||
| "'>", "" | ||
| ) | ||
| return kernel_arg_id_type_name == "OutTensorDataPtrKernelArgId" | ||
|
Comment on lines
+19
to
+23
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的代码质量不行,如果需要判断kernel_arg_id的类型,那就在c++代码里导出OutTensorDataPtrKernelArgId类变量到python层。
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为啥要这个is_out_tensor_karg的逻辑? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
我上个PR中对输入指针参数是这么判断的,我的锅,我来改下。
因为之前为Autotune功能设计的 static int ProfileBestConfig(
const std::vector<std::function<void(const GemmEpilogueParams &)>>
&gemm_functions,
const GemmEpilogueParams &params);最近想到 |
||
|
|
||
| class MatmulBinaryTemplate: | ||
| def __init__( | ||
|
|
@@ -39,6 +44,7 @@ def __init__( | |
| ) | ||
| self.input_dim_karg_to_shape_access = MutableOrderedDict() | ||
| self.input_tensor_karg_to_shape_access = MutableOrderedDict() | ||
| self.output_tensor_karg_to_shape_access = MutableOrderedDict() | ||
| self.kernel_name = "MatmulBinaryKernel" | ||
| self.library_name = "matmul_binary_kernel" | ||
|
|
||
|
|
@@ -105,6 +111,11 @@ def get_kernel_arg_runtime_getters(self): | |
| lambda pair: pair[0].runtime_getter, all_kernel_arg_id_and_unique_names | ||
| ) | ||
|
|
||
| def init_outputs(self): | ||
| out_tensor_data_nums = self.mut_kernel_arg_id_registry.out_tensor_data_ptr_seq_no | ||
| stmt = map(lambda i: f"out{i}", range(out_tensor_data_nums + 1)) | ||
| return "T " + f", ".join(stmt) + ";" | ||
|
|
||
| def get_kernel_arg_types(self): | ||
| all_kernel_arg_id_and_unique_names = ( | ||
| self.mut_kernel_arg_id_registry.all_kernel_arg_id2unique_name.items() | ||
|
|
@@ -159,6 +170,7 @@ def get_epilogue_arguments_init_str( | |
| def declare_epilogue_arguments_assign(pair): | ||
| kernel_arg_id = pair[0] | ||
| is_in_tensor_type = is_in_tensor_karg(kernel_arg_id) | ||
| is_out_tensor_type = is_out_tensor_karg(kernel_arg_id) | ||
|
|
||
| var_name = pair[1] | ||
| field_name = self.kernel_arg_translator.get_param_struct_field_name( | ||
|
|
@@ -169,30 +181,36 @@ def get_in_tensor_statement(): | |
| param_name_for_var = self.input_tensor_karg_to_shape_access[var_name] | ||
| return f"reinterpret_cast<const {output_dtype} *>({params_name}.{param_name_for_var})" | ||
|
|
||
| def get_out_tensor_statement(): | ||
| param_name_for_var = self.output_tensor_karg_to_shape_access[var_name] | ||
| return f"reinterpret_cast<{output_dtype} *>({params_name}.{param_name_for_var})" | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里难道不会有性能问题吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里只是在Host端赋值一次,应该不会有太大开销,对应PR描述示例代码里面的如下部分: epilogue_args.in_ptr_0 = reinterpret_cast<const float *>(params.epilogue_in_ptrs[0]);
epilogue_args.out_ptr_0 = reinterpret_cast<float *>(params.epilogue_out_ptrs[0]); |
||
|
|
||
| def get_dim_expr_statement(): | ||
| param_name_for_var = self.input_dim_karg_to_shape_access[var_name] | ||
| return f"{params_name}.{param_name_for_var}" | ||
|
|
||
| statement = ( | ||
| get_in_tensor_statement() | ||
| if is_in_tensor_type | ||
| else get_dim_expr_statement() | ||
| else get_out_tensor_statement() | ||
| if is_out_tensor_type | ||
| else get_dim_expr_statement() | ||
| ) | ||
| return f"{obj_name}.{field_name} = {statement};" | ||
|
|
||
| generated_kernel_arg_id_and_names = ( | ||
| self.mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() | ||
| ) | ||
|
|
||
| return f"\n{indent}".join( | ||
| map(declare_epilogue_arguments_assign, generated_kernel_arg_id_and_names) | ||
| ) | ||
|
|
||
| def get_params_epilogue_ptrs_init_str(self, obj_name, indent): | ||
| def get_params_epilogue_ptrs_init_str(self, in_obj_name, out_obj_name, indent): | ||
| in_tensor_id = 0 | ||
|
|
||
| def declare_params_epilogue_arguments_assign(pair): | ||
| def declare_in_params_epilogue_arguments_assign(pair): | ||
| def get_creator(): | ||
| return f"{obj_name}[{in_tensor_id}]" | ||
| return f"{in_obj_name}[{in_tensor_id}]" | ||
|
|
||
| kernel_arg_id = pair[0] | ||
| is_in_tensor_type = is_in_tensor_karg(kernel_arg_id) | ||
|
|
@@ -201,7 +219,7 @@ def generate_statement(): | |
| self.input_tensor_karg_to_shape_access.get_or_create( | ||
| pair[1], get_creator | ||
| ) | ||
| statement = f"{obj_name}.push_back({pair[1]});" | ||
| statement = f"{in_obj_name}.push_back({pair[1]});" | ||
| in_tensor_id = in_tensor_id + 1 | ||
| return statement | ||
|
|
||
|
|
@@ -210,13 +228,39 @@ def generate_statement(): | |
| generated_kernel_arg_id_and_names = ( | ||
| self.mut_kernel_arg_id_registry.generated_kernel_arg_id2unique_name.items() | ||
| ) | ||
| return f"\n{indent}".join( | ||
| map( | ||
| declare_params_epilogue_arguments_assign, | ||
| in_str_list = map( | ||
| declare_in_params_epilogue_arguments_assign, | ||
| generated_kernel_arg_id_and_names, | ||
| ) | ||
| ) | ||
|
|
||
| out_tensor_id = 0 | ||
| def declare_out_params_epilogue_arguments_assign(pair): | ||
| def get_creator(): | ||
| return f"{out_obj_name}[{out_tensor_id}]" | ||
|
|
||
| kernel_arg_id = pair[0] | ||
| is_out_tensor_type = is_out_tensor_karg(kernel_arg_id) | ||
|
|
||
| def generate_statement(): | ||
| self.output_tensor_karg_to_shape_access.get_or_create( | ||
| pair[1], get_creator | ||
| ) | ||
| statement = f"{out_obj_name}.push_back({pair[1]});" | ||
| out_tensor_id = out_tensor_id + 1 | ||
| return statement | ||
|
|
||
| return generate_statement() if is_out_tensor_type else "" | ||
|
|
||
| out_str_list = map( | ||
| declare_out_params_epilogue_arguments_assign, | ||
| generated_kernel_arg_id_and_names, | ||
| ) | ||
| str_list = filter( | ||
| lambda ss: ss != "", | ||
| [*in_str_list, *out_str_list] | ||
| ) | ||
| return f"\n{indent}".join(str_list) | ||
|
|
||
| def get_params_input_shape_init_str(self, input_name, input_shape_kargs, indent): | ||
| def init_input_shape_with_args(i): | ||
| def get_creator(): | ||
|
|
@@ -264,9 +308,9 @@ def make_project( | |
| // Note: need to support vectorized operation | ||
| __forceinline__ __host__ __device__ | ||
| T operator()(T x, const Arguments& args, const MatrixCoord& coord) const { | ||
| T out; | ||
| AP_OUTPUTS_INIT | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感觉应该命名成 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 原来的代码模板中对要替换的对象命名没有制定规范,可以制定一个,然后按规范来 |
||
| AP_GENERATED_BINARY_EPILOGUE_STRING | ||
| return out; | ||
| return out0; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里生成的代码是什么样子,在PR描述里面贴一个例子吧
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| } | ||
| }; | ||
|
|
||
|
|
@@ -303,9 +347,10 @@ def make_project( | |
| *cuda_stream_ptr, ${input0}, ${input1}, nullptr, ${output}, ${input0}_shape, ${input1}_shape, std::vector<int64_t>{}); | ||
|
|
||
| std::vector<const void *> epilogue_in_ptrs; | ||
| std::vector<void *> epilogue_out_ptrs; | ||
| AP_PARAMS_EPILOGUE_PTRS_INIT | ||
|
|
||
| params.SetEpilogues(epilogue_in_ptrs); | ||
| params.SetEpilogues(epilogue_in_ptrs, epilogue_out_ptrs); | ||
|
|
||
| #if AP_ENABLE_AUTOTUNE | ||
| AP_AUTOTUNE_${output_dtype}(ap::RunMatmulWithVariadicKernel); | ||
|
|
@@ -321,6 +366,7 @@ def make_project( | |
| code_template.replace( | ||
| "AP_GENERATED_BINARY_EPILOGUE_STRING", trivial_code_str | ||
| ) | ||
| .replace("AP_OUTPUTS_INIT", self.init_outputs()) | ||
| .replace("AP_KERNEL_ARGS_DECLARE", self.get_kernel_arg_list_str()) | ||
| .replace( | ||
| "AP_PARAMS_INPUT0_SHAPE_INIT", | ||
|
|
@@ -336,7 +382,7 @@ def make_project( | |
| ) | ||
| .replace( | ||
| "AP_PARAMS_EPILOGUE_PTRS_INIT", | ||
| self.get_params_epilogue_ptrs_init_str("epilogue_in_ptrs", indent=" "), | ||
| self.get_params_epilogue_ptrs_init_str("epilogue_in_ptrs", "epilogue_out_ptrs", indent=" "), | ||
| ) | ||
| .replace( | ||
| "AP_EPILOGUE_ARGUMENTS_FIELDS", | ||
|
|
@@ -356,7 +402,7 @@ def make_project( | |
| .replace("${k_value}", f"{input0_shape_kargs[-1].value}") | ||
| .replace("${n_value}", f"{input1_shape_kargs[-1].value}") | ||
| ) | ||
|
|
||
| print('cuda code is: ', code) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这地方的代码应该去掉 |
||
| source_dir = "/work/abstract_pass/Athena/tests/ap/matmul" | ||
| cutlass_dir = "/work/abstract_pass/Athena/tests/ap/matmul/cutlass" | ||
| compile_cmd = ( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉