
Support arbitrary input/output and non-3-dim mm_out #14

Open
hxzd5568 wants to merge 1 commit into lixinqi:ap from hxzd5568:upload3


@hxzd5568 commented Apr 9, 2025

Changes:

  1. Support arbitrary inputs/outputs
  2. Support non-3-dim mm_out (only for matmul in b-m-n mode)
  3. Support dynamic ReLU via the umprime module
  4. Rename test_matmul_binary_outs to test_matmul_epilogue
  5. Extract remove-function from MatmulEpilogueFusion into a separate module

CUDA example

// auto generated codes
#include <cuda.h>
#include <cuda_fp16.h>
#include <vector>

#include "cutlass_matmul.cuh"
#include "profile.h"

namespace ap {

template <typename T>
struct VariadicEpilogueFunctor {
  struct Arguments {
    const float* in_ptr_0;
    float* out_ptr_0;
    int64_t input0_dim1;
    int64_t input1_dim1;
  };

  // Note: need to support vectorized operation
  __forceinline__ __host__ __device__
  T operator()(T x, const Arguments& args, const MatrixCoord& coord) const {
    T out0, out1;
    float op1_out0 = (args.in_ptr_0[(coord.column)]);
    float op4_out0 = (x + op1_out0);
    out0 = op4_out0;
    float op6_out0 = (0.707107);
    float op7_out0 = (op4_out0 * op6_out0);
    float op8_out0 = (erf(op7_out0));
    float op9_out0 = (1.000000);
    float op10_out0 = (op9_out0 + op8_out0);
    float op11_out0 = (0.500000);
    float op12_out0 = (op4_out0 * op11_out0);
    float op13_out0 = (op12_out0 * op10_out0);
    out1 = op13_out0;
    args.out_ptr_0[(coord.batch * args.input0_dim1 * args.input1_dim1 + coord.row * args.input1_dim1 + coord.column)] = static_cast<float>(out1);
    return out0;
  }
};

template <int TuningConfigId>
static void RunMatmulWithVariadicKernel(const GemmEpilogueParams &params) {
  using ElementT = float;
  using ElementComputeT = float;

  typename VariadicEpilogueFunctor<ElementComputeT>::Arguments epilogue_args;

  epilogue_args.in_ptr_0 = reinterpret_cast<const float *>(params.epilogue_in_ptrs[0]);
  epilogue_args.out_ptr_0 = reinterpret_cast<float *>(params.epilogue_out_ptrs[0]);
  epilogue_args.input0_dim1 = params.input0_shape[1];
  epilogue_args.input1_dim1 = params.input1_shape[1];

  constexpr int AlignA = AP_ALIGNMENT_float(192);
  constexpr int AlignB = AP_ALIGNMENT_float(768);

  CutlassMatmulAddVariadic<ElementT, ElementComputeT, VariadicEpilogueFunctor,
                           AlignA, AlignB, TuningConfigId>(params,
                                                           epilogue_args);
}

} // namespace ap

extern "C" {

void MatmulBinaryKernel(void* stream_ptr, const float* input0, const float* input1, float* output, int64_t input0_dim0, int64_t input0_dim1, int64_t input0_dim2, int64_t input1_dim1, const float* in_ptr_0, float* out_ptr_0) {
  std::vector<int64_t> input0_shape;
  input0_shape.resize(3);
  input0_shape[0] = input0_dim0;
  input0_shape[1] = input0_dim1;
  input0_shape[2] = input0_dim2;

  std::vector<int64_t> input1_shape;
  input1_shape.resize(2);
  input1_shape[0] = input0_dim2;
  input1_shape[1] = input1_dim1;

  cudaStream_t* cuda_stream_ptr = reinterpret_cast<cudaStream_t*>(stream_ptr);
  ap::GemmEpilogueParams params(
      *cuda_stream_ptr, input0, input1, nullptr, output, input0_shape, input1_shape, std::vector<int64_t>{});

  std::vector<const void *> epilogue_in_ptrs;
  std::vector<void *> epilogue_out_ptrs;
  epilogue_in_ptrs.push_back(in_ptr_0);
  epilogue_out_ptrs.push_back(out_ptr_0);
  
  

  params.SetEpilogues(epilogue_in_ptrs, epilogue_out_ptrs);

#if AP_ENABLE_AUTOTUNE
  AP_AUTOTUNE_float(ap::RunMatmulWithVariadicKernel);
#else
  ap::RunMatmulWithVariadicKernel<ap::DefaultConfig::kConfigId>(params);
#endif
}
}
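
For readers checking the generated functor by hand, the following is a minimal Python sketch of what the example above appears to compute for a single element: a bias add (returned as out0) followed by an erf-based GELU (stored as out1 at a row-major offset). The function name and its parameters are hypothetical and not part of the PR; m and n stand for input0_dim1 and input1_dim1.

```python
import math


def epilogue_reference(x, bias, batch, row, col, m, n):
    """Mirror the generated epilogue for one element (illustrative only).

    x is the matmul accumulator at (batch, row, col); bias has n entries.
    Returns (out0, out1, flat_index), where flat_index is the offset the
    functor uses when writing out1 through out_ptr_0.
    """
    # op1/op4: load the bias by column and add it; this is the returned out0.
    out0 = x + bias[col]
    # op6..op13: 0.5 * v * (1 + erf(v * 0.707107)), using the same constant
    # as the generated code (an approximation of 1/sqrt(2)).
    out1 = 0.5 * out0 * (1.0 + math.erf(out0 * 0.707107))
    # Row-major offset: batch * M * N + row * N + column.
    flat_index = batch * m * n + row * n + col
    return out0, out1, flat_index
```

A reference like this can be compared element by element against the kernel output in a test, within a tolerance that allows for float32 rounding on the device.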

@Xreki left a comment

Please clean up the debug code in all files.

Comment thread tests/ap/kernel_arg_id_util.py Outdated
ir_value = getattr(self.tensor_match_ctx, in_ir_value_name)
print('ir_value: ', ir_value)
kernel_arg_id = self.code_gen_ctx.in_tensor_data_ptr_kernel_arg_id(ir_value)
print('kernel_arg_id: ', kernel_arg_id)

Do these prints need to be kept?

Author

Removed here; the others have been trimmed down.

return name

def get_out_tensor_data_ptr_var_name(self, out_ir_value_name):
    out_ir_value_name = out_ir_value_name.replace("out", "output")

Is this replace really necessary?

@hxzd5568 commented Apr 10, 2025
Author

out_{i} is a local variable in the program, while output_{i} is the ir_name registered in the context; the two naming schemes keep them distinct. Using replace lets the local variable name resolve to its global counterpart.

Owner

This does not explain why the "out" and "output" literals are needed.
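
To make the concern about the literals concrete, here is a small sketch contrasting the str.replace rewrite with an explicit lookup table. OutputNameMap and naive_output_name are hypothetical names introduced for illustration only; the PR does not contain them.

```python
def naive_output_name(local_name):
    # The reviewed approach: derive the registered name by substring rewrite.
    # Every occurrence of "out" is rewritten, not only a leading prefix.
    return local_name.replace("out", "output")


class OutputNameMap:
    """Explicit table from local variable names to registered IR names."""

    def __init__(self):
        self._local_to_ir = {}

    def register(self, local_name, ir_name):
        self._local_to_ir[local_name] = ir_name

    def ir_name(self, local_name):
        # Raises KeyError for names that were never registered,
        # instead of silently producing a wrong string.
        return self._local_to_ir[local_name]
```

With the substring rewrite, a name that already contains "output", or contains "out" elsewhere (for example inside "layout"), is mangled; the table only answers for names that were registered on purpose.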

Comment thread tests/ap/make_axpr.sh Outdated
TEST_TPL_FILENAME=`echo ${TEST_FILENAME/test_/}`

echo "-- Write 'import ${TEST_FILENAME}' to __main__.py"
echo "import ${TEST_FILENAME}" > __main__.py

Please delete L3 - L9. Just add every file that needs a generated JSON to FILENAMES_ARRAY, and add the pattern files to __main__.py.

Author

done

AP_OUTPUTS_INIT
AP_GENERATED_BINARY_EPILOGUE_STRING
return out;
return out0;

What does the generated code look like here? Please paste an example in the PR description.

Author

done

Comment thread tests/ap/matmul_binary_tpl.py Outdated
code_template.replace(
"AP_GENERATED_BINARY_EPILOGUE_STRING", trivial_code_str
)
.replace("AP_GENERATED_ELEMENT_DTYPE", output_dtype)

I removed this; please confirm whether it is still needed.

Author

done

@@ -0,0 +1,158 @@
import access_topo_drr

This file name does not seem quite appropriate.

Author

Could @lixinqi suggest a name? The ones I came up with all feel too long, for example matmul_epilogue_simplify_homeomorphic_subgraph.

Owner

It really is hard to come up with a good name.

Owner

Just call it matmul_epilogue_pass.py.

Comment thread tests/ap/op_compute_translator_util.py Outdated
arg_name = mut_kernel_arg_id_registry.get_in_tensor_data_ptr_var_name(data_op_name)
print('arg_name is: ', arg_name)
ptr_var_name = self.kernel_arg_translator.get_use_name(arg_name)
print('ptr_var_name is: ', ptr_var_name)

Should these prints be deleted?

Author

This log will be needed when further strategies are added later; renamed it to print('ptr_var of OpLoadFromGlobal is: ', ptr_var_name).
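
One possible middle ground between deleting the message and keeping a print is the standard logging module, which stays silent unless debug logging is enabled. The sketch below is an assumption about how that could look; the logger name and the helper function are hypothetical, and get_use_name is passed in as a plain callable standing in for self.kernel_arg_translator.get_use_name.

```python
import logging

# Hypothetical logger name; any module-level logger would work.
logger = logging.getLogger("ap.op_compute_translator")


def load_from_global_ptr_name(arg_name, get_use_name):
    """Resolve the pointer variable name and record it at DEBUG level."""
    ptr_var_name = get_use_name(arg_name)
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    logger.debug("ptr_var of OpLoadFromGlobal is: %s", ptr_var_name)
    return ptr_var_name
```

Calling logging.basicConfig(level=logging.DEBUG) before running the code generator would turn the message on during strategy development without changing the source.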

Comment thread tests/ap/umprime.py Outdated
import pir

@access_topo_drr.register_drr_pass("pd_op_static_relu", tag="umprime")
class PdOpCastAccessTopoPass(access_topo_drr.DrrPass):

Please rename the class.

Author

done

return self.all_kernel_arg_id2unique_name.get_or_create(kernel_arg_id, create)

def get_in_tensor_data_ptr_var_name(self, in_ir_value_name):
    print('in_ir_value_name: ', in_ir_value_name)

Owner

Remove this.


Comment on lines +19 to +23
def is_out_tensor_karg(kernel_arg_id):
    kernel_arg_id_type_name = f"{type(kernel_arg_id)}".replace("<class '", "").replace(
        "'>", ""
    )
    return kernel_arg_id_type_name == "OutTensorDataPtrKernelArgId"

Owner

The code quality here is not acceptable. If the type of kernel_arg_id needs to be checked, export the OutTensorDataPtrKernelArgId class from the C++ code to the Python layer.

Owner

Why is this is_out_tensor_arg logic needed?

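@Xreki commented Apr 10, 2025

> The code quality here is not acceptable. If the type of kernel_arg_id needs to be checked, export the OutTensorDataPtrKernelArgId class from the C++ code to the Python layer.

That is how I checked the input pointer arguments in my previous PR. My fault; I will fix it.

> Why is this is_out_tensor_arg logic needed?

Because the ProfileBestConfig function designed earlier for the Autotune feature is declared as follows and can only accept functions of the form void(const GemmEpilogueParams &). The parameter list of the kernel functions generated by AP is not fixed, so all dimension and pointer arguments must first be stored in GemmEpilogueParams, which is why the karg types have to be distinguished when saving them.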

static int ProfileBestConfig(
    const std::vector<std::function<void(const GemmEpilogueParams &)>>
        &gemm_functions,
    const GemmEpilogueParams &params);

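It recently occurred to me that ProfileBestConfig should be able to support variadic parameter lists even without being generated; this can be optimized later when there is time.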
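
The two points in this thread can be sketched together: classifying kernel args with isinstance rather than by comparing type-name strings, and mapping each pointer argument to a slot in the parameter struct. The two classes below are hypothetical stand-ins defined only so the example runs; in the real code they would be the classes exported from C++ (the name InTensorDataPtrKernelArgId is an assumption), and assign_ptr_slots is not a function from the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class InTensorDataPtrKernelArgId:  # hypothetical stand-in
    name: str


@dataclass(frozen=True)
class OutTensorDataPtrKernelArgId:  # hypothetical stand-in
    name: str


def assign_ptr_slots(kernel_arg_ids):
    """Map each pointer karg name to its access expression in the params."""
    in_count = 0
    out_count = 0
    access = {}
    for karg in kernel_arg_ids:
        # isinstance replaces the string comparison on type(kernel_arg_id).
        if isinstance(karg, OutTensorDataPtrKernelArgId):
            access[karg.name] = f"epilogue_out_ptrs[{out_count}]"
            out_count += 1
        elif isinstance(karg, InTensorDataPtrKernelArgId):
            access[karg.name] = f"epilogue_in_ptrs[{in_count}]"
            in_count += 1
        else:
            raise TypeError(f"unsupported kernel arg: {karg!r}")
    return access
```

The resulting strings match the shape of the accesses in the PR description example, such as params.epilogue_in_ptrs[0] and params.epilogue_out_ptrs[0].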


def get_out_tensor_statement():
    param_name_for_var = self.output_tensor_karg_to_shape_access[var_name]
    return f"reinterpret_cast<{output_dtype} *>({params_name}.{param_name_for_var})"

Owner

Won't this cause a performance problem?

This is only a one-time assignment on the host side, so the overhead should be small. It corresponds to the following part of the example code in the PR description:

  epilogue_args.in_ptr_0 = reinterpret_cast<const float *>(params.epilogue_in_ptrs[0]);
  epilogue_args.out_ptr_0 = reinterpret_cast<float *>(params.epilogue_out_ptrs[0]);

__forceinline__ __host__ __device__
T operator()(T x, const Arguments& args, const MatrixCoord& coord) const {
T out;
AP_OUTPUTS_INIT

Owner

I think this should be named $AP_OUTPUTS_INIT to set it apart from ordinary C++ macros.

The original code template never set a naming convention for the placeholders to be replaced; we could define one and then follow it.
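
One convention that would fit the $AP_OUTPUTS_INIT suggestion is the placeholder syntax of Python's standard string.Template, which also covers the ${n_value} form already used elsewhere in the template. The sketch below is only an illustration under that assumption; CODE_TEMPLATE and render are hypothetical names, and the template body is a shortened stand-in for the real one.

```python
from string import Template

# Shortened stand-in for the real CUDA template; placeholders use $NAME.
CODE_TEMPLATE = Template(
    "T operator()(T x, const Arguments& args, const MatrixCoord& coord) const {\n"
    "  $AP_OUTPUTS_INIT\n"
    "  $AP_GENERATED_BINARY_EPILOGUE_STRING\n"
    "  return out0;\n"
    "}\n"
)


def render(outputs_init, epilogue_body):
    # substitute() raises KeyError if any placeholder is left unfilled,
    # unlike chained str.replace calls, which fail silently.
    return CODE_TEMPLATE.substitute(
        AP_OUTPUTS_INIT=outputs_init,
        AP_GENERATED_BINARY_EPILOGUE_STRING=epilogue_body,
    )
```

A design note: because a literal dollar sign must be written as $$ in a Template, this convention works best when the C++ template itself contains no dollar signs.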

.replace("${n_value}", f"{input1_shape_kargs[-1].value}")
)

print('cuda code is: ', code)

Owner

This code should be removed.


mut_lir_code_gen_ctx=mut_lir_code_gen_ctx,
)
data_op_name = inputs[0].var_name
print('data_name of OpLoadFromGlobal is: ', data_op_name)

Owner

Remove all of these.

Owner

Same below.

def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
    mut_lir_code_gen_ctx.stmts.append(f"{self.get_out_var_name()} = {inputs[0].var_name};")
    out_name = self.get_out_var_name()
    mut_kernel_arg_id_registry.get_out_tensor_data_ptr_var_name(out_name) if out_name != "out0" else None

Owner

Why is special-casing always needed? Is the "out0" literal that important? Why does it have to be done this way?
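
One way to avoid comparing against the "out0" literal is to carry the distinction as data. The sketch below assumes, based only on the generated example in the PR description, that out0 is the value the functor returns while every other output is written through a pointer argument. EpilogueOutput and outputs_needing_pointer are hypothetical names, not part of the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EpilogueOutput:
    var_name: str
    # True for the value returned by the functor (the matmul output itself);
    # this is an assumed reading of the generated example.
    is_matmul_result: bool


def outputs_needing_pointer(outputs):
    """Names of outputs that need an out-pointer kernel argument."""
    return [o.var_name for o in outputs if not o.is_matmul_result]
```

With such a flag set once where the outputs are created, later code can ask about the role of an output without knowing how its variable happens to be named.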

def __call__(self, inputs, mut_kernel_arg_id_registry, mut_lir_code_gen_ctx):
    return inputs

class CinnOpExpandCodeGen:

Owner

Please submit these uncontroversial parts in a separate PR first.


3 participants