Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
0ec2dec
abstract pass initial commit
lixinqi Feb 13, 2025
df6414a
Merge branch 'develop' of github.com:lixinqi/Paddle into ap
lixinqi Feb 13, 2025
96fea1c
remove unused index_expr code
lixinqi Feb 13, 2025
f15a766
Fix compiling error on CI.
Xreki Feb 17, 2025
aaeed58
Merge branch 'ap' of https://github.com/lixinqi/Paddle into ap
Xreki Feb 19, 2025
35196b4
Change the log level.
Xreki Feb 19, 2025
639a8b3
Rename ap_lower_fusion_op_pass to ap_generic_drr_pass.
Xreki Feb 19, 2025
1e0a138
Rename ap_unary -> ap_variadic, ApUnary -> ApVariadic.
Xreki Feb 19, 2025
01e2686
Fallback to cinn when ap fails, and disable fuse_gemm_epilogue when a…
Xreki Feb 20, 2025
34d490a
Fix compiling error in CI, including: using std::memcpy instead of re…
Xreki Feb 20, 2025
4832e4c
Fix narrowing conversion error and unused value error.
Xreki Feb 21, 2025
4db56bc
Fix missing-field-initializers and unused-result error on CI.
Xreki Feb 21, 2025
a441f80
Fix some sign-compare error on CI.
Xreki Feb 21, 2025
60676fe
Add cmake dependent.
Xreki Feb 25, 2025
6b786d9
Fix some using statement without creating an alias.
Xreki Feb 25, 2025
f36dfe8
Support experimental/type_traits for WIN32.
Xreki Feb 25, 2025
b950804
Fix an unused-but-set-parameter and remove the `typename` in using st…
Xreki Feb 26, 2025
5f6e67a
Disable AP when cinn is not enabled.
Xreki Feb 26, 2025
6b7446c
Merge branch 'develop' into ap
Xreki Mar 20, 2025
4691ca4
Remove the use of Reciprocal because Reciprocal is deleted by #70376.
Xreki Mar 20, 2025
a8e051d
Merge branch 'develop' into ap
Xreki Mar 21, 2025
b031281
Fix "basic_string::_M_construct null not valid" error.
Xreki Mar 21, 2025
ba5c576
Fix typo.
Xreki Mar 21, 2025
9f3b79c
Support meticulous matching (with input/outoput number). Submit by hx…
Xreki Mar 21, 2025
6408d08
remove redundant sentence
hxzd5568 Apr 1, 2025
d058177
Using void* as StreamT.
Xreki Apr 2, 2025
0629cb4
Fix compiling error related to std::optional<StreamT> on gcc12.
Xreki Apr 7, 2025
1119b12
Merge branch 'develop' into ap
Xreki Apr 7, 2025
2255c2f
support no-extra-use for temporary ir value in source pattern
lixinqi Apr 10, 2025
2def5b4
minor fix
lixinqi Apr 10, 2025
2e2b230
minor fix
lixinqi Apr 11, 2025
2598a93
Return the address of stream instead.
Xreki Apr 14, 2025
be2389b
support paddle.cc.*
lixinqi Apr 16, 2025
407122e
fix adt::WeakPtrLock bug
lixinqi Apr 18, 2025
b8afbfe
Merge branch 'ap' of github.com:lixinqi/Paddle into ap
lixinqi Apr 18, 2025
4beae5c
rename all non python standard api's to __builtin__xxx
lixinqi Apr 18, 2025
e457617
Merge branch 'ap' into pcc
lixinqi Apr 18, 2025
8ffef38
move paddle.cc into paddle.incubate.cc
lixinqi Apr 18, 2025
3ebbd5b
Add pcc to setup.py.
Xreki Apr 21, 2025
b0b377d
Add missing modules and return partial_program_layer directly in pcc.…
Xreki Apr 21, 2025
07e0652
Fix the mismatched output numerical order issue
hxzd5568 Apr 21, 2025
9ddfcef
Merge pull request #3 from hxzd5568/order
lixinqi Apr 21, 2025
f5e47cb
Remove force_register_fusion related codes in pcc api.
Xreki Apr 21, 2025
07f7034
Add an argument train.
Xreki Apr 21, 2025
0d357f9
support ap.facade and infer_symbolic/infer_meta in python
lixinqi Apr 25, 2025
a9354f0
merge develop
lixinqi Apr 25, 2025
2abc1ee
minor fix
lixinqi Apr 25, 2025
b9356ff
paddle.cc.ap.FacadeOp
lixinqi Apr 27, 2025
6b9a7c8
support zero inputs for pd_op.ap_facade
lixinqi Apr 28, 2025
b6376b7
Merge branch 'develop' into ap_facade
Xreki May 7, 2025
f7562a1
Fix compiling error in CI and refine some error messages.
Xreki May 7, 2025
a1beb59
Correct the copyright.
Xreki May 7, 2025
01eec7c
Polish error messages and remove some unused header files.
Xreki May 7, 2025
d220539
Fix compiling when cinn is not enabled.
Xreki May 7, 2025
1ecf10f
Add InferMeta to list.
Xreki May 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion paddle/ap/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ cc_library(
SRCS ${ap_pir_srcs}
DEPS ${AP_COMMON_DEPS} ${ap_pir_deps})

file(GLOB_RECURSE ap_hlir_srcs "src/paddle/hlir/*.cc")
set(ap_hlir_deps axpr ap_drr ap_pir)
cc_library(
ap_hlir
SRCS ${ap_hlir_srcs}
DEPS ${AP_COMMON_DEPS} ${ap_hlir_deps})

file(GLOB_RECURSE ap_reified_drr_srcs "src/reified_drr/*.cc")
set(ap_reified_drr_deps axpr ap_drr ap_code_module ap_code_gen)
cc_library(
Expand All @@ -60,7 +67,7 @@ cc_library(
DEPS ${AP_COMMON_DEPS} ${ap_reified_drr_deps})

file(GLOB_RECURSE ap_pass_srcs "src/paddle/pass/*.cc")
set(ap_pass_deps axpr ap_pir ap_drr ap_code_module ap_code_gen ap_reified_drr)
set(ap_pass_deps axpr ap_hlir ap_drr ap_code_module ap_code_gen ap_reified_drr)
cc_library(
ap_pass
SRCS ${ap_pass_srcs}
Expand Down
25 changes: 24 additions & 1 deletion paddle/ap/include/axpr/attr_map_method_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,35 @@ struct AttrMapMethodClass {
}
};

template <typename ValueT>
struct TypeImplBuiltinAttrMapMethodClass {
using This = TypeImplBuiltinAttrMapMethodClass;
using Self = TypeImpl<AttrMap<ValueT>>;

adt::Result<ValueT> Call(const Self&) { return &This::StaticConstruct; }

static adt::Result<ValueT> StaticConstruct(const ValueT&,
const std::vector<ValueT>& args) {
return This{}.Construct(args);
}

adt::Result<ValueT> Construct(const std::vector<ValueT>& args) {
const auto& packed_args = CastToPackedArgs(args);
const auto& [pos_args, kwargs] = *packed_args;
ADT_CHECK(pos_args->empty())
<< adt::errors::TypeError{std::string() +
"the construct of AttrMap "
"takes no positional arguments."};
return kwargs;
}
};

template <typename ValueT>
struct MethodClassImpl<ValueT, AttrMap<ValueT>>
: public AttrMapMethodClass<ValueT> {};

template <typename ValueT>
struct MethodClassImpl<ValueT, TypeImpl<AttrMap<ValueT>>>
: public EmptyMethodClass<ValueT> {};
: public TypeImplBuiltinAttrMapMethodClass<ValueT> {};

} // namespace ap::axpr
1 change: 1 addition & 0 deletions paddle/ap/include/axpr/binary_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace ap::axpr {
_(Sub, -) \
_(Mul, *) \
_(Div, /) \
_(FloorDiv, /) \
_(Mod, %) \
_(EQ, ==) \
_(NE, !=) \
Expand Down
52 changes: 52 additions & 0 deletions paddle/ap/include/axpr/builtin_class_instance_method_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,58 @@ struct MethodClassImpl<ValueT, BuiltinClassInstance<ValueT>> {
return class_ops->Equals(self, rhs_val);
}

adt::Result<ValueT> Add(InterpreterBase<ValueT>* interpreter,
const Self& self,
const ValueT& rhs_val) {
const auto& opt_func = GetClassAttr(self, "__add__");
const auto& class_attrs = self.type.class_attrs();
ADT_CHECK(opt_func.has_value())
<< adt::errors::AttributeError{std::string() + class_attrs->class_name +
" class has no attribute '__add__'"};
std::vector<ValueT> args{rhs_val};
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
return ret;
}

adt::Result<ValueT> Sub(InterpreterBase<ValueT>* interpreter,
const Self& self,
const ValueT& rhs_val) {
const auto& opt_func = GetClassAttr(self, "__sub__");
const auto& class_attrs = self.type.class_attrs();
ADT_CHECK(opt_func.has_value())
<< adt::errors::AttributeError{std::string() + class_attrs->class_name +
" class has no attribute '__sub__'"};
std::vector<ValueT> args{rhs_val};
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
return ret;
}

adt::Result<ValueT> Mul(InterpreterBase<ValueT>* interpreter,
const Self& self,
const ValueT& rhs_val) {
const auto& opt_func = GetClassAttr(self, "__mul__");
const auto& class_attrs = self.type.class_attrs();
ADT_CHECK(opt_func.has_value())
<< adt::errors::AttributeError{std::string() + class_attrs->class_name +
" class has no attribute '__mul__'"};
std::vector<ValueT> args{rhs_val};
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
return ret;
}

adt::Result<ValueT> FloorDiv(InterpreterBase<ValueT>* interpreter,
const Self& self,
const ValueT& rhs_val) {
const auto& opt_func = GetClassAttr(self, "__floordiv__");
const auto& class_attrs = self.type.class_attrs();
ADT_CHECK(opt_func.has_value()) << adt::errors::AttributeError{
std::string() + class_attrs->class_name +
" class has no attribute '__floordiv__'"};
std::vector<ValueT> args{rhs_val};
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
return ret;
}

adt::Result<ValueT> GetItem(InterpreterBase<ValueT>* interpreter,
const Self& self,
const ValueT& idx_val) {
Expand Down
1 change: 1 addition & 0 deletions paddle/ap/include/axpr/builtin_frame_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void VisitEachBuiltinFrameAttr(const YieldT& Yield) {
Yield("__builtin_not__", &BuiltinNot);

Yield("__builtin__foreach", &ForEach);

auto YieldTwice = [&](const auto& name, const auto& value) {
Yield(name, value);
Yield(std::string("__builtin__") + name, value);
Expand Down
38 changes: 38 additions & 0 deletions paddle/ap/include/axpr/dim_expr_method_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"

namespace ap::axpr {
template <typename ValueT>
axpr::TypeImpl<axpr::BuiltinClassInstance<ValueT>> GetDimExprClass();

template <typename ValueT>
struct DimExprMethodClass {
Expand All @@ -41,6 +43,38 @@ struct DimExprMethodClass {
return hash_value;
}

static adt::Result<ValueT> Add(const ValueT& self_val,
const std::vector<ValueT>& args) {
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
ADT_CHECK(args.size() == 1);
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
return GetDimExprClass<ValueT>().New(lhs + rhs);
}

static adt::Result<ValueT> Sub(const ValueT& self_val,
const std::vector<ValueT>& args) {
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
ADT_CHECK(args.size() == 1);
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
return GetDimExprClass<ValueT>().New(lhs - rhs);
}

static adt::Result<ValueT> Mul(const ValueT& self_val,
const std::vector<ValueT>& args) {
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
ADT_CHECK(args.size() == 1);
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
return GetDimExprClass<ValueT>().New(lhs * rhs);
}

static adt::Result<ValueT> FloorDiv(const ValueT& self_val,
const std::vector<ValueT>& args) {
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
ADT_CHECK(args.size() == 1);
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
return GetDimExprClass<ValueT>().New(lhs / rhs);
}

static adt::Result<ValueT> Match(axpr::InterpreterBase<ValueT>* interpreter,
const ValueT& self_val,
const std::vector<ValueT>& packed_args_val) {
Expand Down Expand Up @@ -93,6 +127,10 @@ axpr::TypeImpl<axpr::BuiltinClassInstance<ValueT>> GetDimExprClass() {
static auto cls(
axpr::MakeBuiltinClass<ValueT>("DimExpr", [&](const auto& Define) {
Define("__str__", &Impl::ToString);
Define("__add__", &Impl::Add);
Define("__sub__", &Impl::Sub);
Define("__mul__", &Impl::Mul);
Define("__floordiv__", &Impl::FloorDiv);
Define("__hash__", &Impl::Hash);
Define("match", &Impl::Match);
}));
Expand Down
1 change: 1 addition & 0 deletions paddle/ap/include/axpr/type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ AttrMap<ValueT> GetObjectTypeName2Type() {
OrderedDict<ValueT>,
MutableOrderedDict<ValueT>,
AttrMap<axpr::SerializableValue>,
AttrMap<ValueT>,
ValueImplTypes...>::Call(&object);
return object;
}
Expand Down
5 changes: 1 addition & 4 deletions paddle/ap/include/drr/builtin_frame_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ void VisitEachBuiltinFrameClass(const DoEachT& DoEach) {
DoEach(drr::Type<DrrCtx>{}.GetClass());
}

template <typename VisitorT>
ap::axpr::AttrMap<axpr::Value> MakeBuiltinFrameAttrMap(
const VisitorT& Visitor) {
inline ap::axpr::AttrMap<axpr::Value> MakeBuiltinFrameAttrMap() {
ap::axpr::AttrMap<axpr::Value> attr_map;
ap::axpr::VisitEachBuiltinFrameAttr<axpr::Value>(
[&](const std::string& k, const axpr::Value& v) { attr_map->Set(k, v); });
Expand All @@ -38,7 +36,6 @@ ap::axpr::AttrMap<axpr::Value> MakeBuiltinFrameAttrMap(
attr_map->Set(std::string("__builtin__") + cls.Name(), cls);
};
VisitEachBuiltinFrameClass(Insert);
Visitor(Insert);
return attr_map;
}

Expand Down
7 changes: 2 additions & 5 deletions paddle/ap/include/drr/drr_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@ namespace ap::drr {

class DrrInterpreter {
public:
explicit DrrInterpreter(
const axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>>&
backend_ir_ctx,
const std::weak_ptr<ap::memory::CirclableRefListBase>&
circlable_ref_list);
explicit DrrInterpreter(const std::weak_ptr<ap::memory::CirclableRefListBase>&
circlable_ref_list);

using Function = ap::axpr::Value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "paddle/ap/include/axpr/attr_map.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/include/core/builder.h"
Expand All @@ -25,6 +26,22 @@

namespace ap::dialect {

class IR_API FacadeOp
: public pir::Op<FacadeOp, ::paddle::dialect::InferSymbolicShapeInterface> {
public:
using Op::Op;
static const char *name() { return "ap_op.facade"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const pir::AttributeMap &attributes,
const std::vector<pir::Type> &output_types);
void VerifySig() const {}
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IR_API UpSpiderOp
: public pir::Op<UpSpiderOp,
pir::SideEffectTrait,
Expand Down Expand Up @@ -134,6 +151,7 @@ class IR_API StoreToGlobalOp

} // namespace ap::dialect

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::FacadeOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp);
Expand Down
7 changes: 4 additions & 3 deletions paddle/ap/include/paddle/meta_tensor_ptr_method_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ struct MetaTensorPtrMethodClass {

adt::Result<axpr::Value> SetDims(const Self& self,
const axpr::Value& dims_val) {
if (dims_val.CastableTo<DDim>()) {
ADT_LET_CONST_REF(ddim, dims_val.CastTo<DDim>());
return SetDimsByDDim(self, ddim);
}
return dims_val.Match(
[&](const DDim& ddims) -> adt::Result<axpr::Value> {
return SetDimsByDDim(self, ddims);
},
[&](const adt::List<axpr::Value>& list) -> adt::Result<axpr::Value> {
return SetDimsByIntList(self, list);
},
Expand Down
41 changes: 41 additions & 0 deletions paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <optional>
#include "paddle/pir/include/pass/pass.h"

namespace ap::memory {

class CirclableRefListBase;

}

namespace ap::axpr {

struct Value;

}

namespace cinn {
namespace dialect {
namespace ir {

std::unique_ptr<::pir::Pass> CreateConvertPdFacadeToApFacadePass();

} // namespace ir
} // namespace dialect
} // namespace cinn
2 changes: 1 addition & 1 deletion paddle/ap/include/paddle/pass/ir_helper_method_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include "paddle/ap/include/axpr/callable_helper.h"
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
#include "paddle/ap/include/drr/drr_value_helper.h"
#include "paddle/ap/include/paddle/hlir/op_dialect.h"
#include "paddle/ap/include/paddle/pass/ap_drr_helper.h"
#include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h"
#include "paddle/ap/include/paddle/pass/ir_helper.h"
#include "paddle/ap/include/paddle/pir/op_dialect.h"
#include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h"
#include "paddle/ap/include/paddle/pir/pass_manager_method_class.h"
#include "paddle/ap/include/paddle/pir/pass_method_class.h"
Expand Down
9 changes: 9 additions & 0 deletions paddle/ap/include/paddle/phi/ap_infer_meta_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
#pragma once

#include "paddle/ap/include/adt/adt.h"
#include "paddle/ap/include/axpr/attr_map.h"
#include "paddle/ap/include/axpr/core_expr.h"
#include "paddle/ap/include/axpr/value.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/pir/include/core/operation_utils.h"

namespace phi {

Expand All @@ -29,6 +32,12 @@ struct ApInferMetaHelper {
adt::Result<adt::Ok> InferMeta(const std::string& lambda,
const std::vector<const MetaTensor*>* inputs,
std::vector<MetaTensor*>* outputs);

adt::Result<adt::Ok> InferMetaByAxprHook(
const ::paddle::optional<std::vector<const MetaTensor*>>& inputs,
const std::string& infer_meta_func_name,
const std::string& serialized_attributes,
const std::vector<MetaTensor*>& outputs);
};

} // namespace phi
Loading
Loading