Skip to content

[AP] Add paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only#72525

Merged
Xreki merged 55 commits into
PaddlePaddle:developfrom
lixinqi:ap_facade
May 8, 2025
Merged

[AP] Add paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only#72525
Xreki merged 55 commits into
PaddlePaddle:developfrom
lixinqi:ap_facade

Conversation

@lixinqi

@lixinqi lixinqi commented Apr 27, 2025

Copy link
Copy Markdown
Contributor

PR Category

CINN

PR Types

New features

Description

pcard-76996

动机

当我们将用户的custom op升级成apass时,要求现在pir/phi系统里引入该custom op。用户不得不重新编译paddle代码,这就和ap的理念不符,且拖慢开发速度。
为此,我们设计了通用的pd_op.ap_facadeap_op.facade两个op,前者用于python层配置,后者用于drr pass匹配。实际上,用户用paddle.cc.ap.FacadeOp机制能更便利地添加新扩展算子。

使用示例

一个能跑通的示例,lixinqi/Athena#37

  • 使用facade_op实现facade_matmul,infermeta和infer_symbolic利用特殊shape hack实现,能跑通pir转换、代码生成、program执行全流程。

python

# file path: ./tuple_identity_op.py
import paddle.incubate.cc as pcc
import paddle.incubate.cc.typing as pct

import paddle

class TupleIdentityOp(pcc.ap.FacadeOp):

    def __init__(self):
        super().__init__()

    def custom_op_name(self) -> str:
        return "ap_custom_op.tuple_identity"

    def infer_meta(self) -> str:
        return "tuple_identity_util.infer_meta"

    def infer_symbolic(self) -> str:
        return "tuple_identity_util.infer_symbolic"

    def num_inputs(self) -> int:
        return -1 # 可变

    def num_outputs(self, args) -> int:
        return len(args)

    def attributes_schema(self):
        # annotations matter.
        pass

tie = TupleIdentityOp()

N = pct.DimVar(1024)
M = pct.DimVar(256)
DType = pct.DTypeVar("T","float32")

def foo(
    x: pct.Tensor([N], DType),
    y: pct.Tensor([M], DType)
):
    x, _ = tie([x, y])
    return x

fused_foo = pcc.compile(
    foo,
    ap_path=f"os.path.dirname(__file__)/axpr",
)

axpr

从代码中可以看出,TupleIdentityOp需要axpr代码文件tuple_identity_util.py的infer_meta和infer_symbolic函数。

# file path: ./axpr/tuple_identity_util.py

def infer_symbolic(infer_ctx, inputs, attrs):
    return inputs

def infer_meta(inputs, attrs, mut_outputs):

    def copy_meta(i):
        mut_outputs[i].dims = inputs[i].dims
        mut_outputs[i].dtype = inputs[i].dtype

    map(copy_meta, range(len(inputs)))

运行方法

python -m paddle.incubate.cc.ap.py_to_axpr_json $(pwd)/axpr/tuple_identity_util.py $(pwd)/axpr/tuple_identity_util.py.json 
AP_PATH=$(pwd)/axpr python3 tuple_identity_op.py

之后我们会在输出日志里看到如下日志(可能不是位于末尾):

E0427 15:36:22.238061 43694 add_cinn_pass.cc:231] before ConvertPdFacadeToApFacadePass:
{
    (%0) = "pd_op.data" () {dtype:float32,name:"x",place:Place(undefined:0),shape:[1024],stop_gradient:[false]} : () -> tensor<1024xf32>
    (%1) = "pd_op.data" () {dtype:float32,name:"y",place:Place(undefined:0),shape:[256],stop_gradient:[false]} : () -> tensor<256xf32>
    (%2) = "builtin.combine" (%0, %1) {stop_gradient:[false]} : (tensor<1024xf32>, tensor<256xf32>) -> vec[tensor<1024xf32>,tensor<256xf32>]
    (%3) = "pd_op.ap_facade" (%2) {custom_op_name:"ap_custom_op.tuple_identity",infer_meta_func_name:"tuple_identity_util.infer_meta",infer_symbolic_func_name:"tuple_identity_util.infer_symbolic",num_outputs:2,serialized_attributes:"[\"__builtin_let__\", [[\"___0\", [\"__builtin__AttrMap\"]], [\"___1\", [\"__builtin_identity__\", \"___0\"]]], [\"__builtin_identity__\", \"___1\"]]",stop_gradient:[false]} : (vec[tensor<1024xf32>,tensor<256xf32>]) -> vec[tensor<1024xf32>,tensor<256xf32>]
    (%4, %5) = "builtin.split" (%3) {stop_gradient:[false,false]} : (vec[tensor<1024xf32>,tensor<256xf32>]) -> tensor<1024xf32>, tensor<256xf32>
    () = "builtin.shadow_output" (%4) {output_name:"output_0"} : (tensor<1024xf32>) -> 
}
E0427 15:36:22.238247 43694 add_cinn_pass.cc:236] after ConvertPdFacadeToApFacadePass:
{
    (%0) = "pd_op.data" () {dtype:float32,name:"x",place:Place(undefined:0),shape:[1024],stop_gradient:[false]} : () -> tensor<1024xf32>
    (%1) = "pd_op.data" () {dtype:float32,name:"y",place:Place(undefined:0),shape:[256],stop_gradient:[false]} : () -> tensor<256xf32>
    (%2, %3) = "ap_op.facade" (%0, %1) {__original_serialized_attributes__:"[\"__builtin_let__\", [[\"___0\", [\"__builtin__AttrMap\"]], [\"___1\", [\"__builtin_identity__\", \"___0\"]]], [\"__builtin_identity__\", \"___1\"]]",custom_op_name:"ap_custom_op.tuple_identity",infer_meta_func_name:"tuple_identity_util.infer_meta",infer_symbolic_func_name:"tuple_identity_util.infer_symbolic"} : (tensor<1024xf32>, tensor<256xf32>) -> tensor<1024xf32>, tensor<256xf32>
    () = "builtin.shadow_output" (%2) {output_name:"output_0"} : (tensor<1024xf32>) -> 
}

从日志中可以看出,pd_op.ap_facade前后总是伴随builtin.combine和builtin.split,这是因为它的输入输出是tensor list。显然ap_op.facade比它更容易在drr_pass内用起来。
注意看,上述日志中ap_op.facade算子的custom_op_name的值为"ap_custom_op.tuple_identity",此字段在drr pass的用法为:

class DrrPass:
    def source_pattern(self, o, t):
        o.tuple_identity = o.ap_native_op("ap_op.facade")
        o.tuple_identity.custom_op_name = pir.a_str("ap_custom_op.tuple_identity") # 无须在constrain函数里再做约束
        ...

infer_meta 相关接口

输入的MetaTensor

// file path: paddle/ap/include/paddle/const_meta_tensor_ptr_method_class.h

inline axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>>
GetConstMetaTensorPtrClass() {
  using Impl = ConstMetaTensorPtrMethodClass;
  static auto cls(axpr::MakeBuiltinClass<axpr::Value>(
      "ConstMetaTensorPtr", [&](const auto& Define) {
        Define("__str__", &Impl::ToString);
        Define("__hash__", &Impl::Hash);
        Define("__getattr__", &Impl::GetAttr);
      }));
  return axpr::MakeGlobalNaiveClassOps<typename Impl::Self>(cls);
}

输出的MetaTensor

// file path: paddle/ap/include/paddle/meta_tensor_ptr_method_class.h


inline axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>>
GetMetaTensorPtrClass() {
  using Impl = MetaTensorPtrMethodClass;
  static auto cls(axpr::MakeBuiltinClass<axpr::Value>(
      "MetaTensorPtr", [&](const auto& Define) {
        Define("__str__", &Impl::ToString);
        Define("__hash__", &Impl::Hash);
        Define("__getattr__", &Impl::GetAttr);
        Define("__setattr__", &Impl::SetAttr);
      }));
  return axpr::MakeGlobalNaiveClassOps<typename Impl::Self>(cls);
}

infer_symbolic 相关接口

pir::InferSymbolicShapeContext*

// file path: paddle/ap/src/paddle/pir/infer_symbolic_shape_context_method_class.cc

axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>>
GetPirInferSymbolicShapeContextClass() {
  static auto cls(axpr::MakeBuiltinClass<axpr::Value>(
      "PirInferSymbolicShapeContext", [&](const auto& Yield) {
        Yield("max", &Max);
        Yield("min", &Min);
        Yield("broadcast", &Broadcast);
        Yield("new_symbolic_name", &NewSymbolicName);
        Yield("add_equal_cstr", &AddEqualCstr);
        Yield("is_equal", &IsEqual);
        Yield("add_greater_than_one_cstr", &AddGreatThanOneCstr);
        Yield("is_greater_than_one", &IsGreatThanOne);
        Yield("add_broadcastable_cstr", &AddBroadcastableCstr);
        Yield("is_broadcastable", &IsBroadcastable);
      }));
  return axpr::MakeGlobalNaiveClassOps<pir::InferSymbolicShapeContext*>(cls);
}

symbol::ShapeOrDataDimExprs

其中的match用法类似axpr python里的pir attribute

// file path: paddle/ap/src/paddle/pir/shape_or_data_method_class.cc

axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>>
GetPirShapeOrDataClass() {
  static auto cls(axpr::MakeBuiltinClass<axpr::Value>(
      "PirShapeOrData", [&](const auto& Yield) {
        Yield("__str__", &PirShapeOrDataString);
        Yield("get_type_name", &PirShapeOrDataGetTypeName);
        Yield("match", &PirShapeOrDataMatch);
      }));
  return axpr::MakeGlobalNaiveClassOps<symbol::ShapeOrDataDimExprs>(cls);
}

ShapeOrDataDimExprs具体alternative对象的构造统一用 pir模块进行

// file path: paddle/ap/src/paddle/pir/shape_or_data_method_class.cc

template <typename Builder>
void DefineMethods(Builder* m) {

  ...

  ForEachShapeOrDataMaker(
      [&](const auto& name, const auto& value) { m->Def(name, value); });
}
// file path: paddle/ap/include/paddle/pir/shape_or_data_method_class.h

template <typename YieldT>
void ForEachShapeOrDataMaker(const YieldT& Yield) {
  Yield("s_null", &MakeNullShapeOrDataDimExpr);
  Yield("s_tensor_shape_or_data", &MakeTensorShapeOrDataDimExprs);
  Yield("s_tensor_list_shape_or_data", &MakeTensorListShapeOrDataDimExprs);
  Yield("s_ranked_tensor_array_shape_or_data",
        &MakeRankedTensorArrayShapeOrDataDimExprs);
}

symbol::DimExpr

file path: paddle/ap/include/axpr/dim_expr_method_class.h


template <typename ValueT>
axpr::TypeImpl<axpr::BuiltinClassInstance<ValueT>> GetDimExprClass() {
  using Impl = DimExprMethodClass<ValueT>;
  static auto cls(
      axpr::MakeBuiltinClass<ValueT>("DimExpr", [&](const auto& Define) {
        Define("__str__", &Impl::ToString);
        Define("__add__", &Impl::Add);
        Define("__sub__", &Impl::Sub);
        Define("__mul__", &Impl::Mul);
        Define("__floordiv__", &Impl::FloorDiv);
        Define("__hash__", &Impl::Hash);
        Define("match", &Impl::Match);
      }));
  return axpr::MakeGlobalNaiveClassOps<typename Impl::Self>(cls);
}

lixinqi and others added 30 commits February 13, 2025 07:24
@paddle-bot

paddle-bot Bot commented Apr 27, 2025

Copy link
Copy Markdown

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot Bot added the contributor External developers label Apr 27, 2025
@codecov-commenter

Copy link
Copy Markdown

Codecov Report

Attention: Patch coverage is 30.23715% with 353 lines in your changes missing coverage. Please review.

Please upload report for BASE (develop@a091b78). Learn more about missing BASE report.

Files with missing lines Patch % Lines
python/paddle/incubate/cc/ap/py_to_axpr_json.py 25.86% 235 Missing ⚠️
...thon/paddle/incubate/cc/ap/pir_attrs_serializer.py 38.75% 79 Missing ⚠️
python/paddle/incubate/cc/ap/facade_op.py 33.33% 30 Missing ⚠️
paddle/phi/infermeta/multiary.cc 0.00% 6 Missing ⚠️
...tor/interface/infer_symbolic_shape/ap_infer_sym.cc 0.00% 2 Missing ⚠️
python/paddle/incubate/cc/data_type_util.py 66.66% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (30.23%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #72525   +/-   ##
==========================================
  Coverage           ?   30.23%           
==========================================
  Files              ?        9           
  Lines              ?      506           
  Branches           ?        0           
==========================================
  Hits               ?      153           
  Misses             ?      353           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@XiaoguangHu01 XiaoguangHu01 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Xreki Xreki changed the title paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only [AP] paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only May 8, 2025
@Xreki Xreki changed the title [AP] paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only [AP] Add paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only May 8, 2025
@Xreki Xreki merged commit 30e5f93 into PaddlePaddle:develop May 8, 2025
53 of 56 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants