Support non-3d tensors. #25

Open
Xreki wants to merge 4 commits into lixinqi:ap from Xreki:support_non3d
@Xreki Xreki commented Apr 17, 2025

Background

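The AP system currently supports only 3-D inputs; for non-3-D inputs, generating the Epilogue index computation code raises an error. There are two ways to handle non-3-D inputs (using the following subgraph as an example):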

{
    (%0) = "pd_op.data" () {dtype:float16,name:"x",place:Place(undefined:0),shape:[4,16,4096,128],stop_gradient:[false]} : () -> tensor<4x16x4096x128xf16>
    (%1) = "pd_op.data" () {dtype:float16,name:"y",place:Place(undefined:0),shape:[128,32],stop_gradient:[false]} : () -> tensor<128x32xf16>                                        
    (%2) = "pd_op.data" () {dtype:float16,name:"b",place:Place(undefined:0),shape:[4,16,4096,32],stop_gradient:[false]} : () -> tensor<4x16x4096x32xf16>
    (%3) = "pd_op.matmul" (%0, %1) {stop_gradient:[false],transpose_x:false,transpose_y:false} : (tensor<4x16x4096x128xf16>, tensor<128x32xf16>) -> tensor<4x16x4096x32xf16>
    (%4) = "cinn_op.fusion" () -> tensor<4x16x4096x32xf16> {
        (%5) = "pd_op.add" (%3, %2) {stop_gradient:[false]} : (tensor<4x16x4096x32xf16>, tensor<4x16x4096x32xf16>) -> tensor<4x16x4096x32xf16>
        (%6) = "pd_op.full" () {dtype:float16,place:Place(undefined:0),shape:[4,16,4096,32],stop_gradient:[true],value:0} : () -> tensor<4x16x4096x32xf16>
        (%7) = "pd_op.maximum" (%5, %6) {comp_op_name:"pd_op.relu",stop_gradient:[false]} : (tensor<4x16x4096x32xf16>, tensor<4x16x4096x32xf16>) -> tensor<4x16x4096x32xf16>
        () = "cf.yield" (%7) {} : (tensor<4x16x4096x32xf16>) ->  
    }   
    () = "builtin.shadow_output" (%4) {output_name:"output_0"} : (tensor<4x16x4096x32xf16>) ->  
}
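  • Insert reshape operators before and after the subgraph during code generation. The work involved includes:
    1. Inserting 2 reshape operators, whose target shapes must be computed accurately. The transformed subgraph is shown below:
      [image]

    2. Running shape inference on the subgraph to update the input and output shapes of the other operators

    3. Updating the shape attribute of special operators such as full

    4. Adding a code generator for the reshape operator's computation

  • Correct the index computation according to the actual shape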
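
To make the first option concrete, here is a minimal sketch of how the reshape target shapes might be computed. It assumes the reshape simply collapses all leading batch dimensions into one, so that [b0, ..., bx, M, N] becomes [B, M, N]; the PR text does not spell this out, and the helper name collapse_to_3d is hypothetical, not part of the AP codebase.

```python
from math import prod


def collapse_to_3d(shape):
    """Collapse [b0, ..., bx, M, N] into [B, M, N], where B = b0 * ... * bx.

    Hypothetical helper for illustration only; assumes every leading
    dimension is a batch dimension.
    """
    if len(shape) < 2:
        raise ValueError("expected a shape with at least 2 dimensions")
    *batch, m, n = shape
    # prod([]) == 1, so a 2-D shape [M, N] becomes [1, M, N].
    return [prod(batch), m, n]


# Shapes taken from the example subgraph above.
x_3d = collapse_to_3d([4, 16, 4096, 128])
out_3d = collapse_to_3d([4, 16, 4096, 32])
```

Under this assumption, the reshape inserted after the subgraph would restore the original [4, 16, 4096, 32] output shape, which is why the shape attribute of operators such as full also has to be rewritten inside the subgraph.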

Solution

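Inserting reshape operators requires a fair amount of work and is error-prone, so this PR instead corrects the index computation. The approach is as follows: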

  • When the matmul input Tensor is 2-D, there are in fact only 2 coordinates, so get_anchor_iter_var_names() is modified to return only ["coord.row", "coord.column"]. The generated Epilogue code is as follows:
    float op2_out0 = static_cast<float>(args.in_ptr_0[(coord.row * args.input1_dim1 + coord.column)]);
    float op3_out0 = static_cast<float>(x + op2_out0);
    float op4_out0 = static_cast<float>((op3_out0 > 0 ? op3_out0 : 0) );
    out = op4_out0;
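  • When the matmul input Tensor is 4-D or higher, introduce a list anchor_iter_dim_splits, of the same length as anchor_iter_var_names, that records how many dimensions of the original Tensor shape each coordinate in anchor_iter_var_names spans.
    • For Tensors that are 4-D or higher, only [b0, ..., bx, M, N] is currently supported, i.e. the leading Rank-2 dimensions are all Batch dimensions, so anchor_iter_dim_splits = [Rank - 2, 1, 1]. The [b0, ..., bx, M0, ..., My, N0, ..., Nz] case can be supported by a later extension.
    • iter_dim_splits also needs to be stored in IndexCodeGenValue; the output iter_dim_splits of the sum operator will change.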
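
The role of the split list can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: it assumes a contiguous row-major tensor, assumes that a coordinate spanning several dimensions is already a flattened index over them, and uses the hypothetical names make_dim_splits, flat_index_expr and "coord.batch" (only "coord.row" and "coord.column" appear in the PR text).

```python
def make_dim_splits(rank):
    """Return the per-coordinate dimension counts for a [b0, ..., bx, M, N] tensor."""
    if rank < 2:
        raise ValueError("matmul operands need at least 2 dimensions")
    if rank == 2:
        return [1, 1]            # only row and column coordinates
    return [rank - 2, 1, 1]      # batch coordinate spans all leading dims


def flat_index_expr(var_names, dim_splits, dim_names):
    """Build a row-major flat index expression as a C++ source string.

    Each coordinate is multiplied by the sizes of all dimensions that
    belong to the coordinates after it.
    """
    if len(var_names) != len(dim_splits):
        raise ValueError("var_names and dim_splits must have the same length")
    if sum(dim_splits) != len(dim_names):
        raise ValueError("dim_splits must cover every dimension exactly once")
    terms = []
    end = 0
    for var, split in zip(var_names, dim_splits):
        end += split                  # index just past this coordinate's dims
        trailing = dim_names[end:]    # dims owned by later coordinates
        terms.append(" * ".join([var] + trailing))
    return "(" + " + ".join(terms) + ")"


dims_2d = [f"args.input1_dim{i}" for i in range(2)]
expr_2d = flat_index_expr(["coord.row", "coord.column"], make_dim_splits(2), dims_2d)

dims_4d = [f"args.input1_dim{i}" for i in range(4)]
expr_4d = flat_index_expr(
    ["coord.batch", "coord.row", "coord.column"],  # "coord.batch" is an assumed name
    make_dim_splits(4),
    dims_4d,
)
```

For the 2-D case this reproduces the index expression in the generated Epilogue code above; for the 4-D case the batch coordinate is scaled by the product of the M and N extents.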

@@ -1,4 +1,5 @@
 class IndexCodeGenValue:
-    def __init__(self, iter_var_names):
+    def __init__(self, iter_var_names, iter_dim_splits):

Owner

iter_dim_splits goes beyond the semantics of IndexCodeGenValue.

lixinqi commented Apr 18, 2025

Owner

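To support matmul on 4-D and higher-rank tensors, we have made the low-level operators accept parameters with unclear semantics. This is unacceptable.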

lixinqi commented Apr 18, 2025

Owner

Neither IndexCodeGenValue nor the sum operator has iter_dim_splits semantics. Forcing this extension onto them will only introduce more bugs.
