diff --git a/paddle/ap/CMakeLists.txt b/paddle/ap/CMakeLists.txt index 32bf8680011333..e70cec9ce7fb98 100644 --- a/paddle/ap/CMakeLists.txt +++ b/paddle/ap/CMakeLists.txt @@ -52,6 +52,13 @@ cc_library( SRCS ${ap_pir_srcs} DEPS ${AP_COMMON_DEPS} ${ap_pir_deps}) +file(GLOB_RECURSE ap_hlir_srcs "src/paddle/hlir/*.cc") +set(ap_hlir_deps axpr ap_drr ap_pir) +cc_library( + ap_hlir + SRCS ${ap_hlir_srcs} + DEPS ${AP_COMMON_DEPS} ${ap_hlir_deps}) + file(GLOB_RECURSE ap_reified_drr_srcs "src/reified_drr/*.cc") set(ap_reified_drr_deps axpr ap_drr ap_code_module ap_code_gen) cc_library( @@ -60,7 +67,7 @@ cc_library( DEPS ${AP_COMMON_DEPS} ${ap_reified_drr_deps}) file(GLOB_RECURSE ap_pass_srcs "src/paddle/pass/*.cc") -set(ap_pass_deps axpr ap_pir ap_drr ap_code_module ap_code_gen ap_reified_drr) +set(ap_pass_deps axpr ap_hlir ap_drr ap_code_module ap_code_gen ap_reified_drr) cc_library( ap_pass SRCS ${ap_pass_srcs} diff --git a/paddle/ap/include/axpr/attr_map_method_class.h b/paddle/ap/include/axpr/attr_map_method_class.h index 546b0491698daf..efb9de852a415c 100644 --- a/paddle/ap/include/axpr/attr_map_method_class.h +++ b/paddle/ap/include/axpr/attr_map_method_class.h @@ -50,12 +50,35 @@ struct AttrMapMethodClass { } }; +template +struct TypeImplBuiltinAttrMapMethodClass { + using This = TypeImplBuiltinAttrMapMethodClass; + using Self = TypeImpl>; + + adt::Result Call(const Self&) { return &This::StaticConstruct; } + + static adt::Result StaticConstruct(const ValueT&, + const std::vector& args) { + return This{}.Construct(args); + } + + adt::Result Construct(const std::vector& args) { + const auto& packed_args = CastToPackedArgs(args); + const auto& [pos_args, kwargs] = *packed_args; + ADT_CHECK(pos_args->empty()) + << adt::errors::TypeError{std::string() + + "the construct of AttrMap " + "takes no positional arguments."}; + return kwargs; + } +}; + template struct MethodClassImpl> : public AttrMapMethodClass {}; template struct MethodClassImpl>> - : public EmptyMethodClass {}; + : public TypeImplBuiltinAttrMapMethodClass {}; } // namespace ap::axpr diff --git a/paddle/ap/include/axpr/binary_func.h b/paddle/ap/include/axpr/binary_func.h index 3a2b7f3308ebb7..ba7110417d8bdc 100644 --- a/paddle/ap/include/axpr/binary_func.h +++ b/paddle/ap/include/axpr/binary_func.h @@ -23,6 +23,7 @@ namespace ap::axpr { _(Sub, -) \ _(Mul, *) \ _(Div, /) \ + _(FloorDiv, /) \ _(Mod, %) \ _(EQ, ==) \ _(NE, !=) \ diff --git a/paddle/ap/include/axpr/builtin_class_instance_method_class.h b/paddle/ap/include/axpr/builtin_class_instance_method_class.h index 66b280e0a82a7b..e6b49f48798a55 100644 --- a/paddle/ap/include/axpr/builtin_class_instance_method_class.h +++ b/paddle/ap/include/axpr/builtin_class_instance_method_class.h @@ -114,6 +114,58 @@ struct MethodClassImpl> { return class_ops->Equals(self, rhs_val); } + adt::Result Add(InterpreterBase* interpreter, + const Self& self, + const ValueT& rhs_val) { + const auto& opt_func = GetClassAttr(self, "__add__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__add__'"}; + std::vector args{rhs_val}; + ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args)); + return ret; + } + + adt::Result Sub(InterpreterBase* interpreter, + const Self& self, + const ValueT& rhs_val) { + const auto& opt_func = GetClassAttr(self, "__sub__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__sub__'"}; + std::vector args{rhs_val}; + ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args)); + return ret; + } + + adt::Result Mul(InterpreterBase* interpreter, + const Self& self, + const ValueT& rhs_val) { + const auto& opt_func = GetClassAttr(self, "__mul__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_func.has_value()) + << adt::errors::AttributeError{std::string() + class_attrs->class_name + + " class has no attribute '__mul__'"}; + std::vector args{rhs_val}; + ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args)); + return ret; + } + + adt::Result FloorDiv(InterpreterBase* interpreter, + const Self& self, + const ValueT& rhs_val) { + const auto& opt_func = GetClassAttr(self, "__floordiv__"); + const auto& class_attrs = self.type.class_attrs(); + ADT_CHECK(opt_func.has_value()) << adt::errors::AttributeError{ + std::string() + class_attrs->class_name + + " class has no attribute '__floordiv__'"}; + std::vector args{rhs_val}; + ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args)); + return ret; + } + adt::Result GetItem(InterpreterBase* interpreter, const Self& self, const ValueT& idx_val) { diff --git a/paddle/ap/include/axpr/builtin_frame_util.h b/paddle/ap/include/axpr/builtin_frame_util.h index 4e3f87c39b4fc0..fdc8c9431cf414 100644 --- a/paddle/ap/include/axpr/builtin_frame_util.h +++ b/paddle/ap/include/axpr/builtin_frame_util.h @@ -41,6 +41,7 @@ void VisitEachBuiltinFrameAttr(const YieldT& Yield) { Yield("__builtin_not__", &BuiltinNot); Yield("__builtin__foreach", &ForEach); + auto YieldTwice = [&](const auto& name, const auto& value) { Yield(name, value); Yield(std::string("__builtin__") + name, value); diff --git a/paddle/ap/include/axpr/dim_expr_method_class.h b/paddle/ap/include/axpr/dim_expr_method_class.h index d4f39eceff7721..32cf1593dff3b8 100644 --- a/paddle/ap/include/axpr/dim_expr_method_class.h +++ b/paddle/ap/include/axpr/dim_expr_method_class.h @@ -22,6 +22,8 @@ #include "paddle/pir/include/dialect/shape/utils/dim_expr.h" namespace ap::axpr { +template +axpr::TypeImpl> GetDimExprClass(); template struct DimExprMethodClass { @@ -41,6 +43,38 @@ struct DimExprMethodClass { return hash_value; } + static adt::Result Add(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(lhs, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(rhs, args.at(0).template CastTo()); + return GetDimExprClass().New(lhs + rhs); + } + + static adt::Result Sub(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(lhs, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(rhs, args.at(0).template CastTo()); + return GetDimExprClass().New(lhs - rhs); + } + + static adt::Result Mul(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(lhs, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(rhs, args.at(0).template CastTo()); + return GetDimExprClass().New(lhs * rhs); + } + + static adt::Result FloorDiv(const ValueT& self_val, + const std::vector& args) { + ADT_LET_CONST_REF(lhs, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(rhs, args.at(0).template CastTo()); + return GetDimExprClass().New(lhs / rhs); + } + static adt::Result Match(axpr::InterpreterBase* interpreter, const ValueT& self_val, const std::vector& packed_args_val) { @@ -93,6 +127,10 @@ axpr::TypeImpl> GetDimExprClass() { static auto cls( axpr::MakeBuiltinClass("DimExpr", [&](const auto& Define) { Define("__str__", &Impl::ToString); + Define("__add__", &Impl::Add); + Define("__sub__", &Impl::Sub); + Define("__mul__", &Impl::Mul); + Define("__floordiv__", &Impl::FloorDiv); Define("__hash__", &Impl::Hash); Define("match", &Impl::Match); })); diff --git a/paddle/ap/include/axpr/type_util.h b/paddle/ap/include/axpr/type_util.h index 5b48f9728897de..6bef6021a17690 100644 --- a/paddle/ap/include/axpr/type_util.h +++ b/paddle/ap/include/axpr/type_util.h @@ -98,6 +98,7 @@ AttrMap GetObjectTypeName2Type() { OrderedDict, MutableOrderedDict, AttrMap, + AttrMap, ValueImplTypes...>::Call(&object); return object; } diff --git a/paddle/ap/include/drr/builtin_frame_util.h b/paddle/ap/include/drr/builtin_frame_util.h index b899b49e7cd9de..f4a20eb1b5673d 100644 --- a/paddle/ap/include/drr/builtin_frame_util.h +++ b/paddle/ap/include/drr/builtin_frame_util.h @@ -27,9 +27,7 @@ void VisitEachBuiltinFrameClass(const DoEachT& DoEach) { DoEach(drr::Type{}.GetClass()); } -template -ap::axpr::AttrMap MakeBuiltinFrameAttrMap( - const VisitorT& Visitor) { +inline ap::axpr::AttrMap MakeBuiltinFrameAttrMap() { ap::axpr::AttrMap attr_map; ap::axpr::VisitEachBuiltinFrameAttr( [&](const std::string& k, const axpr::Value& v) { attr_map->Set(k, v); }); @@ -38,7 +36,6 @@ ap::axpr::AttrMap MakeBuiltinFrameAttrMap( attr_map->Set(std::string("__builtin__") + cls.Name(), cls); }; VisitEachBuiltinFrameClass(Insert); - Visitor(Insert); return attr_map; } diff --git a/paddle/ap/include/drr/drr_interpreter.h b/paddle/ap/include/drr/drr_interpreter.h index 0fc33a22fbe590..85febf37c2e6c9 100644 --- a/paddle/ap/include/drr/drr_interpreter.h +++ b/paddle/ap/include/drr/drr_interpreter.h @@ -23,11 +23,8 @@ namespace ap::drr { class DrrInterpreter { public: - explicit DrrInterpreter( - const axpr::TypeImpl>& - backend_ir_ctx, - const std::weak_ptr& - circlable_ref_list); + explicit DrrInterpreter(const std::weak_ptr& + circlable_ref_list); using Function = ap::axpr::Value; diff --git a/paddle/ap/include/paddle/pir/manual_op.h b/paddle/ap/include/paddle/hlir/manual_op.h similarity index 88% rename from paddle/ap/include/paddle/pir/manual_op.h rename to paddle/ap/include/paddle/hlir/manual_op.h index 75dd12529ab314..6ac2aa81cb2e17 100644 --- a/paddle/ap/include/paddle/pir/manual_op.h +++ b/paddle/ap/include/paddle/hlir/manual_op.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/ap/include/axpr/attr_map.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/include/core/builder.h" @@ -25,6 +26,22 @@ namespace ap::dialect { +class IR_API FacadeOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "ap_op.facade"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const pir::AttributeMap &attributes, + const std::vector &output_types); + void VerifySig() const {} + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); +}; + class IR_API UpSpiderOp : public pir::Op SetDims(const Self& self, const axpr::Value& dims_val) { + if (dims_val.CastableTo()) { + ADT_LET_CONST_REF(ddim, dims_val.CastTo()); + return SetDimsByDDim(self, ddim); + } return dims_val.Match( - [&](const DDim& ddims) -> adt::Result { - return SetDimsByDDim(self, ddims); - }, [&](const adt::List& list) -> adt::Result { return SetDimsByIntList(self, list); }, diff --git a/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h b/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h new file mode 100644 index 00000000000000..79786f072f1a5d --- /dev/null +++ b/paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h @@ -0,0 +1,41 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/pir/include/pass/pass.h" + +namespace ap::memory { + +class CirclableRefListBase; + +} + +namespace ap::axpr { + +struct Value; + +} + +namespace cinn { +namespace dialect { +namespace ir { + +std::unique_ptr<::pir::Pass> CreateConvertPdFacadeToApFacadePass(); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/ap/include/paddle/pass/ir_helper_method_class.h b/paddle/ap/include/paddle/pass/ir_helper_method_class.h index fc234cc74a654b..6344336e831b21 100644 --- a/paddle/ap/include/paddle/pass/ir_helper_method_class.h +++ b/paddle/ap/include/paddle/pass/ir_helper_method_class.h @@ -18,10 +18,10 @@ #include "paddle/ap/include/axpr/callable_helper.h" #include "paddle/ap/include/axpr/lambda_expr_builder.h" #include "paddle/ap/include/drr/drr_value_helper.h" +#include "paddle/ap/include/paddle/hlir/op_dialect.h" #include "paddle/ap/include/paddle/pass/ap_drr_helper.h" #include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h" #include "paddle/ap/include/paddle/pass/ir_helper.h" -#include "paddle/ap/include/paddle/pir/op_dialect.h" #include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h" #include "paddle/ap/include/paddle/pir/pass_manager_method_class.h" #include "paddle/ap/include/paddle/pir/pass_method_class.h" diff --git a/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h b/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h index e39ae72c322863..cbc09cf2fa9dba 100644 --- a/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h +++ b/paddle/ap/include/paddle/phi/ap_infer_meta_helper.h @@ -15,8 +15,11 @@ #pragma once #include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/attr_map.h" #include "paddle/ap/include/axpr/core_expr.h" +#include "paddle/ap/include/axpr/value.h" #include "paddle/phi/core/meta_tensor.h" +#include "paddle/pir/include/core/operation_utils.h" namespace phi { @@ -29,6 +32,12 @@ struct ApInferMetaHelper { adt::Result InferMeta(const std::string& lambda, const std::vector* inputs, std::vector* outputs); + + adt::Result InferMetaByAxprHook( + const ::paddle::optional>& inputs, + const std::string& infer_meta_func_name, + const std::string& serialized_attributes, + const std::vector& outputs); }; } // namespace phi diff --git a/paddle/ap/include/paddle/pir/ap_pir_attribute.h b/paddle/ap/include/paddle/pir/ap_pir_attribute.h new file mode 100644 index 00000000000000..e33e9c100127fc --- /dev/null +++ b/paddle/ap/include/paddle/pir/ap_pir_attribute.h @@ -0,0 +1,107 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/include/core/builtin_attribute.h" + +namespace ap::dialect { + +template +using ApPirAttributeImpl = std::variant>; + +struct ApPirAttribute : public ApPirAttributeImpl { + using ApPirAttributeImpl::ApPirAttributeImpl; + ADT_DEFINE_VARIANT_METHODS(ApPirAttributeImpl); + + static std::optional OptCastFrom(const axpr::Value& value) { + const auto& ret = CastFrom(value); + if (ret.HasError()) return std::nullopt; + return ret.GetOkValue(); + } + + static adt::Result CastFrom(const axpr::Value& value) { + using RetT = adt::Result; + return value.Match( + [](bool impl) -> RetT { return impl; }, + [](int64_t impl) -> RetT { return impl; }, + [](double impl) -> RetT { return impl; }, + [](const std::string& impl) -> RetT { return impl; }, + [](const axpr::DataType& impl) -> RetT { return impl; }, + [](const adt::List& lst) -> RetT { + adt::List ret; + ret->reserve(lst->size()); + for (const auto& elt_val : *lst) { + ADT_LET_CONST_REF(ap_attr, ApPirAttribute::CastFrom(elt_val)); + ret->emplace_back(ap_attr); + } + return ret; + }, + [](const auto&) -> RetT { + return adt::errors::TypeError{ + "couldn't cast object from axpr::Value to ApPirAttribute"}; + }); + } + + adt::Result CastToPirAttribute() const { + return Match([&](const auto& impl) -> adt::Result { + return CastToPirAttributeImpl(impl); + }); + } + + adt::Result CastToPirAttributeImpl(bool impl) const { + return pir::BoolAttribute::get(pir::IrContext::Instance(), impl); + } + + adt::Result CastToPirAttributeImpl(int64_t impl) const { + return pir::Int64Attribute::get(pir::IrContext::Instance(), impl); + } + + adt::Result CastToPirAttributeImpl(double impl) const { + return pir::DoubleAttribute::get(pir::IrContext::Instance(), impl); + } + + adt::Result CastToPirAttributeImpl( + const std::string& impl) const { + return pir::StrAttribute::get(pir::IrContext::Instance(), impl); + } + + adt::Result CastToPirAttributeImpl( + const axpr::DataType& impl) const { + ADT_LET_CONST_REF(phi_data_type, axpr::GetPhiDataTypeFromDataType(impl)); + return ::paddle::dialect::DataTypeAttribute::get(pir::IrContext::Instance(), + phi_data_type); + } + + adt::Result CastToPirAttributeImpl( + const adt::List& impl) const { + std::vector vec; + vec.resize(impl->size()); + for (const auto& ap_attr : *impl) { + ADT_LET_CONST_REF(elt_attr, ap_attr.CastToPirAttribute()); + vec.emplace_back(elt_attr); + } + return pir::ArrayAttribute::get(pir::IrContext::Instance(), vec); + } +}; + +} // namespace ap::dialect diff --git a/paddle/ap/include/paddle/pir/infer_symbolic_shape_context_method_class.h b/paddle/ap/include/paddle/pir/infer_symbolic_shape_context_method_class.h new file mode 100644 index 00000000000000..c583df6522ded6 --- /dev/null +++ b/paddle/ap/include/paddle/pir/infer_symbolic_shape_context_method_class.h @@ -0,0 +1,33 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ap/include/adt/adt.h" +#include "paddle/ap/include/axpr/naive_class_ops.h" +#include "paddle/ap/include/axpr/value.h" +#include "paddle/ap/include/paddle/pir/type.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace ap::paddle { + +axpr::TypeImpl> +GetPirInferSymbolicShapeContextClass(); + +} // namespace ap::paddle diff --git a/paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h b/paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h new file mode 100644 index 00000000000000..8dd2d481f1c0a6 --- /dev/null +++ b/paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h @@ -0,0 +1,32 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/common/ddim.h" +#include "paddle/common/layout.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace ap::dialect { + +// for ap_op.facade +bool ApOpFacadeOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context); + +// for pd_op.ap_facade +bool PdOpApFacadeOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context); + +} // namespace ap::dialect diff --git a/paddle/ap/include/paddle/pir/pir_method_class.h b/paddle/ap/include/paddle/pir/pir_method_class.h index 95572c35d7fa9b..e8edd2a20fffc4 100644 --- a/paddle/ap/include/paddle/pir/pir_method_class.h +++ b/paddle/ap/include/paddle/pir/pir_method_class.h @@ -22,6 +22,7 @@ #include "paddle/ap/include/paddle/phi/place_method_class.h" #include "paddle/ap/include/paddle/pir/attribute_method_class.h" #include "paddle/ap/include/paddle/pir/pir.h" +#include "paddle/ap/include/paddle/pir/shape_or_data_method_class.h" #include "paddle/ap/include/paddle/pir/type_method_class.h" namespace ap::paddle { diff --git a/paddle/ap/include/paddle/pir/shape_or_data_method_class.h b/paddle/ap/include/paddle/pir/shape_or_data_method_class.h index 8bf23b88836582..39d4ef1ada8faf 100644 --- a/paddle/ap/include/paddle/pir/shape_or_data_method_class.h +++ b/paddle/ap/include/paddle/pir/shape_or_data_method_class.h @@ -30,4 +30,25 @@ namespace ap::paddle { axpr::TypeImpl> GetPirShapeOrDataClass(); +adt::Result MakeNullShapeOrDataDimExpr( + const axpr::Value&, const std::vector& args); + +adt::Result MakeTensorShapeOrDataDimExprs( + const axpr::Value&, const std::vector& args); + +adt::Result MakeTensorListShapeOrDataDimExprs( + const axpr::Value&, const std::vector& args); + +adt::Result MakeRankedTensorArrayShapeOrDataDimExprs( + const axpr::Value&, const std::vector& args); + +template +void ForEachShapeOrDataMaker(const YieldT& Yield) { + Yield("s_null", &MakeNullShapeOrDataDimExpr); + Yield("s_tensor_shape_or_data", &MakeTensorShapeOrDataDimExprs); + Yield("s_tensor_list_shape_or_data", &MakeTensorListShapeOrDataDimExprs); + Yield("s_ranked_tensor_array_shape_or_data", + &MakeRankedTensorArrayShapeOrDataDimExprs); +} + } // namespace ap::paddle diff --git a/paddle/ap/src/drr/drr_interpreter.cc b/paddle/ap/src/drr/drr_interpreter.cc index 8383d8f6074d8e..8cf5035883d9b0 100644 --- a/paddle/ap/src/drr/drr_interpreter.cc +++ b/paddle/ap/src/drr/drr_interpreter.cc @@ -37,12 +37,8 @@ using DrrCtx = ap::drr::DrrCtx; } // namespace DrrInterpreter::DrrInterpreter( - const axpr::TypeImpl>& - backend_ir_ctx, const std::weak_ptr& circlable_ref_list) - : interpreter_(ap::drr::MakeBuiltinFrameAttrMap( - [&](const auto& Insert) { Insert(backend_ir_ctx); }), - circlable_ref_list) {} + : interpreter_(ap::drr::MakeBuiltinFrameAttrMap(), circlable_ref_list) {} adt::Result DrrInterpreter::InterpretDrrCtxMaker( const Function& lambda, const std::vector& args) { diff --git a/paddle/ap/src/paddle/pir/manual_op.cc b/paddle/ap/src/paddle/hlir/manual_op.cc similarity index 86% rename from paddle/ap/src/paddle/pir/manual_op.cc rename to paddle/ap/src/paddle/hlir/manual_op.cc index 4e96d190e08305..c61bcd1d195faf 100644 --- a/paddle/ap/src/paddle/pir/manual_op.cc +++ b/paddle/ap/src/paddle/hlir/manual_op.cc @@ -14,7 +14,8 @@ #include -#include "paddle/ap/include/paddle/pir/manual_op.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" +#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h" #include "paddle/common/enforce.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_type.h" @@ -22,6 +23,24 @@ namespace ap::dialect { +const char* FacadeOp::attributes_name[FacadeOp::attributes_num] = { + "custom_op_name", "infer_meta", "infer_symbolic"}; + +void FacadeOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + const std::vector& inputs, + const pir::AttributeMap& attributes, + const std::vector& output_types) { + argument.inputs = inputs; + argument.attributes = attributes; + argument.output_types = output_types; +} + +bool FacadeOp::InferSymbolicShape( + pir::InferSymbolicShapeContext* infer_context) { + return ApOpFacadeOpInferSymbolicShape(*this, infer_context); +} + void UpSpiderOp::Build(pir::Builder& builder, // NOLINT pir::OperationArgument& argument, // NOLINT pir::Value lhs, @@ -145,6 +164,7 @@ bool StoreToGlobalOp::InferSymbolicShape( } // namespace ap::dialect +IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::FacadeOp); IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp); IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp); IR_DEFINE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp); diff --git a/paddle/ap/src/paddle/pir/op_dialect.cc b/paddle/ap/src/paddle/hlir/op_dialect.cc similarity index 89% rename from paddle/ap/src/paddle/pir/op_dialect.cc rename to paddle/ap/src/paddle/hlir/op_dialect.cc index ee0d69720303b1..a8036c2c6f3b73 100644 --- a/paddle/ap/src/paddle/pir/op_dialect.cc +++ b/paddle/ap/src/paddle/hlir/op_dialect.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ap/include/paddle/pir/op_dialect.h" -#include "paddle/ap/include/paddle/pir/manual_op.h" +#include "paddle/ap/include/paddle/hlir/op_dialect.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" namespace ap { namespace dialect { @@ -25,6 +25,7 @@ OperatorDialect::OperatorDialect(::pir::IrContext *context) } void OperatorDialect::initialize() { + RegisterOp(); RegisterOp(); RegisterOp(); RegisterOp(); diff --git a/paddle/ap/src/paddle/pass/ap_drr_helper.cc b/paddle/ap/src/paddle/pass/ap_drr_helper.cc index bf551dc81a8fc5..f4b596c59cc536 100644 --- a/paddle/ap/src/paddle/pass/ap_drr_helper.cc +++ b/paddle/ap/src/paddle/pass/ap_drr_helper.cc @@ -39,7 +39,7 @@ using DrrCtx = ap::drr::DrrCtx; ApDrrHelper::ApDrrHelper( const std::weak_ptr& circlable_ref_list) - : drr_interpreter_(ap::paddle::GetPirClass(), circlable_ref_list) {} + : drr_interpreter_(circlable_ref_list) {} adt::Result ApDrrHelper::InterpretDrrCtxMaker( const Function& lambda, const std::vector& args) { diff --git a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc index 9c7a4028a1fedd..17aa852358b88b 100644 --- a/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc +++ b/paddle/ap/src/paddle/pass/ap_generic_drr_pass.cc @@ -39,11 +39,11 @@ #include "paddle/ap/include/ir_match/ir_match_ctx.h" #include "paddle/ap/include/ir_match/op_match_ctx_method_class.h" #include "paddle/ap/include/ir_match/tensor_match_ctx_method_class.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" #include "paddle/ap/include/paddle/pass/ap_drr_helper.h" #include "paddle/ap/include/paddle/pass/ap_kernel_define_helper.h" #include "paddle/ap/include/paddle/pass/ap_registry_helper.h" #include "paddle/ap/include/paddle/pass/ir_helper_method_class.h" -#include "paddle/ap/include/paddle/pir/manual_op.h" #include "paddle/ap/include/paddle/pir/pir_method_class.h" #include "paddle/ap/include/paddle/pir/pir_node_matched_src_ptn_ctx_helper.h" #include "paddle/ap/include/paddle/pir/pir_to_anf_expr_helper.h" @@ -967,9 +967,17 @@ struct ApRewriter { const std::vector& dim_exprs) const { std::vector anf_dims; for (const auto& dim_expr : dim_exprs) { - ADT_LET_CONST_REF(anf_dim_expr, - ConstructDDimDimExpr(ctx, infer_meta_ctx, dim_expr)); - anf_dims.emplace_back(anf_dim_expr); + const auto& anf_dim_expr = + ConstructDDimDimExpr(ctx, infer_meta_ctx, dim_expr); + if (anf_dim_expr.HasOkValue()) { + anf_dims.emplace_back(anf_dim_expr.GetOkValue()); + } else if (anf_dim_expr.GetError() + .template Has()) { + // TODO(lixinqi): anf_dims.emplace_back(AnfExpr{ctx->Int64(-1)}); + return anf_dim_expr.GetError(); + } else { + return anf_dim_expr.GetError(); + } } return ctx->Call(ap::axpr::kBuiltinList(), anf_dims); } @@ -990,7 +998,9 @@ struct ApRewriter { const OpInferMetaCtx& infer_meta_ctx, const symbol::DimExpr& dim_expr) const { const auto& idx_iter = infer_meta_ctx.dim_expr2in_dim_index.find(dim_expr); - ADT_CHECK(idx_iter != infer_meta_ctx.dim_expr2in_dim_index.end()); + if (idx_iter == infer_meta_ctx.dim_expr2in_dim_index.end()) { + return adt::errors::MismatchError{}; + } auto anf_expr_iter = infer_meta_ctx.dim_expr2anf_expr.find(dim_expr); if (anf_expr_iter == infer_meta_ctx.dim_expr2anf_expr.end()) { const auto& in_dim = ConstructInDimExpr(ctx, idx_iter->second); diff --git a/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc b/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc new file mode 100644 index 00000000000000..6046cd76c97b5c --- /dev/null +++ b/paddle/ap/src/paddle/pass/convert_pd_facade_to_ap_facade.cc @@ -0,0 +1,183 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" + +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/atomic.h" +#include "paddle/ap/include/axpr/builtin_frame_util.h" +#include "paddle/ap/include/axpr/builtin_serializable_attr_map_to_axpr_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/paddle/builtin_frame_util.h" +#include "paddle/ap/include/paddle/pir/ap_pir_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace cinn::dialect::ir { + +namespace adt = ap::adt; + +namespace { + +class ConvertPdFacadeToApFacadePattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::ApFacadeOp pd_facade_op, + pir::PatternRewriter& rewriter) const override { + const auto& ret = TryMatchAndRewrite(pd_facade_op, &rewriter); + PADDLE_ENFORCE_EQ( + ret.HasError(), + false, + phi::errors::Fatal( + "ConvertPdFacadeToApFacadePattern::MatchAndRewrite failed. " + "\nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); + return ret.GetOkValue(); + } + + adt::Result TryMatchAndRewrite(paddle::dialect::ApFacadeOp pd_facade_op, + pir::PatternRewriter* rewriter) const { + std::vector inputs{}; + pir::Operation* upstream_op = nullptr; + if (pd_facade_op->operand_source(0)) { + upstream_op = pd_facade_op->operand_source(0).defining_op(); + ADT_CHECK(upstream_op != nullptr); + ADT_CHECK(upstream_op->isa()) << adt::errors::TypeError{ + "the upstream of pd_op.ap_facade should builtin.combine"}; + inputs = upstream_op->dyn_cast().inputs(); + } + ADT_CHECK(pd_facade_op->result(0).use_count() == 1); + auto* downstream_op = pd_facade_op->result(0).first_use().owner(); + ADT_CHECK(downstream_op != nullptr); + ADT_CHECK(downstream_op->isa()) << adt::errors::TypeError{ + "the downstream of pd_op.ap_facade should builtin.split"}; + ADT_LET_CONST_REF(attributes, GetFacadeOpAttributes(pd_facade_op)); + const auto old_outputs = downstream_op->dyn_cast().outputs(); + std::vector output_types{}; + output_types.reserve(old_outputs.size()); + for (const auto& output : old_outputs) { + output_types.emplace_back(output.type()); + } + auto ap_facade_op = rewriter->Build( + inputs, attributes, output_types); + for (int i = 0; i < old_outputs.size(); ++i) { + rewriter->ReplaceAllUsesWith(old_outputs.at(i), ap_facade_op->result(i)); + } + rewriter->EraseOp(downstream_op); + rewriter->EraseOp(pd_facade_op); + if (upstream_op != nullptr) { + rewriter->EraseOp(upstream_op); + } + return true; + } + + adt::Result GetFacadeOpAttributes( + paddle::dialect::ApFacadeOp pd_facade_op) const { + ADT_LET_CONST_REF(serialized_attributes, + GetFacadeOpSerializedAttributes(pd_facade_op)); + ADT_LET_CONST_REF(lambda, CastStrToLambda(serialized_attributes)); + ADT_LET_CONST_REF(attr_map, RunLambda(lambda)); + return CastToPirAttributeMap(pd_facade_op, attr_map, serialized_attributes); + } + + adt::Result GetFacadeOpSerializedAttributes( + paddle::dialect::ApFacadeOp op) const { + const auto& iter = op->attributes().find("serialized_attributes"); + ADT_CHECK(iter != op->attributes().end()); + ADT_CHECK(iter->second.template isa()); + return iter->second.template dyn_cast().AsString(); + } + + adt::Result> CastStrToLambda( + const std::string& serialized_attributes) const { + ADT_LET_CONST_REF( + anf_expr, ap::axpr::MakeAnfExprFromJsonString(serialized_attributes)); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + std::vector> args{}; + return ap::axpr::Lambda{args, core_expr}; + } + + adt::Result> RunLambda( + const ap::axpr::Lambda& lambda) const { + ap::memory::Guard guard{}; + ap::axpr::Interpreter interpreter( + ap::paddle::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + ADT_LET_CONST_REF(ret_val, interpreter.Interpret(lambda, {})); + ADT_LET_CONST_REF( + attr_map_val, + ret_val.template CastTo>()); + return attr_map_val; + } + + adt::Result CastToPirAttributeMap( + paddle::dialect::ApFacadeOp pd_facade_op, + const ap::axpr::AttrMap& attr_map, + const std::string& serialized_attributes) const { + pir::AttributeMap attributes{}; + for (const auto& [name, val] : attr_map->storage) { + ADT_LET_CONST_REF(pir_attr, CastToPirAttribute(val)); + attributes[name] = pir_attr; + } + const auto& CopyAttribute = + [&](const auto& attr_name) -> adt::Result { + const auto& iter = pd_facade_op->attributes().find(attr_name); + ADT_CHECK(iter != pd_facade_op->attributes().end()); + attributes[attr_name] = iter->second; + return adt::Ok{}; + }; + ADT_RETURN_IF_ERR(CopyAttribute("custom_op_name")); + ADT_RETURN_IF_ERR(CopyAttribute("infer_meta_func_name")); + ADT_RETURN_IF_ERR(CopyAttribute("infer_symbolic_func_name")); + attributes["__original_serialized_attributes__"] = pir::StrAttribute::get( + pir::IrContext::Instance(), serialized_attributes); + return attributes; + } + + adt::Result CastToPirAttribute( + const ap::axpr::Value& val) const { + ADT_LET_CONST_REF(ap_pir_attr, ap::dialect::ApPirAttribute::CastFrom(val)); + return ap_pir_attr.CastToPirAttribute(); + } +}; + +class ConvertPdFacadeToApFacadePass : public pir::PatternRewritePass { + public: + ConvertPdFacadeToApFacadePass() + : pir::PatternRewritePass("convert_pd_facade_to_ap_facade_pass", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + return ps; + } +}; + +} // namespace + +std::unique_ptr<::pir::Pass> CreateConvertPdFacadeToApFacadePass() { + return std::make_unique(); +} + +} // namespace cinn::dialect::ir diff --git a/paddle/ap/src/paddle/pass/op_factory.cc b/paddle/ap/src/paddle/pass/op_factory.cc index e39cd6d9c3d4e5..2f06aec7b14358 100644 --- a/paddle/ap/src/paddle/pass/op_factory.cc +++ b/paddle/ap/src/paddle/pass/op_factory.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/ap/src/paddle/pass/op_factory.h" -#include "paddle/ap/include/paddle/pir/manual_op.h" +#include "paddle/ap/include/paddle/hlir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/builtin_attribute.h" diff --git a/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc b/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc index 95a3dfba82e23a..a8ddeacd047cfc 100644 --- a/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc +++ b/paddle/ap/src/paddle/phi/ap_infer_meta_helper.cc @@ -16,8 +16,10 @@ #include #include "paddle/ap/include/adt/adt.h" #include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/attr_map.h" #include "paddle/ap/include/axpr/data_type.h" #include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" #include "paddle/ap/include/axpr/value.h" #include "paddle/ap/include/axpr/value_method_class.h" #include "paddle/ap/include/paddle/builtin_frame_util.h" @@ -34,6 +36,8 @@ namespace phi { namespace { +namespace adt = ap::adt; + using CoreExpr = ap::axpr::CoreExpr; using Lambda = ap::axpr::Lambda; @@ -52,44 +56,137 @@ adt::Result InferMetaByLambda( return adt::Ok{}; } -adt::Result MakeLambda(const std::string& lambda_str) { - ADT_LET_CONST_REF(anf_expr, ap::axpr::MakeAnfExprFromJsonString(lambda_str)); +adt::Result InferMetaByLambda( + const Lambda& lambda, + const ::paddle::optional>& inputs, + const ap::axpr::AttrMap& attrs, + const std::vector& outputs) { + ap::memory::Guard guard{}; + ap::axpr::Interpreter interpreter( + ap::paddle::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + adt::List inputs_val{}; + if (inputs.is_initialized()) { + inputs_val->reserve(inputs->size()); + for (const auto& input : *inputs) { + inputs_val->emplace_back( + ap::paddle::GetConstMetaTensorPtrClass().New(input)); + } + } + adt::List outputs_val{}; + outputs_val->reserve(outputs.size()); + for (const auto& output : outputs) { + outputs_val->emplace_back(ap::paddle::GetMetaTensorPtrClass().New(output)); + } + ADT_RETURN_IF_ERR( + interpreter.Interpret(lambda, {inputs_val, attrs, outputs_val})); + return adt::Ok{}; +} + +template +using MakeT = adt::Result (*)(const std::string& str); + +template Make> +adt::Result CacheResult(const std::string& serialized_attributes) { + static std::unordered_map> cache; + static std::mutex mutex; + std::unique_lock lock(mutex); + auto iter = cache.find(serialized_attributes); + if (iter == cache.end()) { + iter = + cache.emplace(serialized_attributes, Make(serialized_attributes)).first; + } + ADT_LET_CONST_REF(ret, iter->second); + return ret; +} + +adt::Result MakeLambda(const std::string& serialized_attributes) { + ADT_LET_CONST_REF(anf_expr, + ap::axpr::MakeAnfExprFromJsonString(serialized_attributes)); const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); ADT_LET_CONST_REF(atomic, core_expr.TryGet>()) << adt::errors::TypeError{ std::string() + - "lambda_str can not be converted to atomic AnfExpr."}; + "serialized_attributes can not be converted to atomic AnfExpr."}; ADT_LET_CONST_REF(lambda, atomic.TryGet>()); return lambda; } -using MakeLambdaT = adt::Result (*)(const std::string& lambda_str); +constexpr auto CastToLambda = &CacheResult; -template -adt::Result CacheConvertResult(const std::string& lambda_str) { - static std::unordered_map> cache; - static std::mutex mutex; - std::unique_lock lock(mutex); - auto iter = cache.find(lambda_str); - if (iter == cache.end()) { - iter = cache.emplace(lambda_str, Make(lambda_str)).first; - } - ADT_LET_CONST_REF(lambda, iter->second); - return lambda; +adt::Result> MakeAttrMap( + const std::string& serialized_attributes) { + ADT_LET_CONST_REF(anf_expr, + ap::axpr::MakeAnfExprFromJsonString(serialized_attributes)); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + std::vector> args{}; + ap::axpr::Lambda lambda{args, core_expr}; + ap::memory::Guard guard{}; + ap::axpr::Interpreter interpreter( + ap::paddle::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + ADT_LET_CONST_REF(ret, interpreter.Interpret(lambda, {})); + ADT_LET_CONST_REF(attrs, + ret.template CastTo>()); + return attrs; +} + +constexpr auto CastToAttrMap = + &CacheResult, &MakeAttrMap>; + +adt::Result MakeInferMetaLambda( + const std::string& infer_meta_func_name) { + auto dot_pos = infer_meta_func_name.find('.'); + ADT_CHECK(dot_pos != std::string::npos); + const auto& module_name = infer_meta_func_name.substr(0, dot_pos); + const auto& func_name = infer_meta_func_name.substr(dot_pos + 1); + ADT_CHECK(func_name.find('.') == std::string::npos); + ap::axpr::LambdaExprBuilder lmd; + const ap::axpr::AnfExpr anf_expr = + lmd.Lambda({"inputs", "attrs", "mut_outputs"}, [&](auto& ctx) { + auto& infer_hooks = ctx.Var("import").Call(ctx.String(module_name)); + auto& method = infer_hooks.Attr(func_name); + auto& inputs = ctx.Var("inputs"); + auto& attrs = ctx.Var("attrs"); + auto& mut_outputs = ctx.Var("mut_outputs"); + auto& ret = method.Call(inputs, attrs, mut_outputs); + return ret; + }); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + const auto& atomic = core_expr.Get>(); + return atomic.Get>(); } -constexpr MakeLambdaT CastToLambda = &CacheConvertResult<&MakeLambda>; +constexpr auto GetInferMetaLambda = &CacheResult; + +adt::Result InferMetaByAxprHookImpl( + const std::string& infer_meta_func_name, + const ::paddle::optional>& inputs, + const ap::axpr::AttrMap& attrs, + const std::vector& outputs) { + ADT_LET_CONST_REF(lambda, GetInferMetaLambda(infer_meta_func_name)); + return InferMetaByLambda(lambda, inputs, attrs, outputs); +} } // namespace adt::Result ApInferMetaHelper::InferMeta( - const std::string& lambda_str, + const std::string& serialized_attributes, const std::vector* inputs, std::vector* outputs) { - ADT_LET_CONST_REF(lambda, CastToLambda(lambda_str)); + ADT_LET_CONST_REF(lambda, CastToLambda(serialized_attributes)); return InferMetaByLambda(lambda, inputs, outputs); } +adt::Result ApInferMetaHelper::InferMetaByAxprHook( + const ::paddle::optional>& inputs, + const std::string& infer_meta_func_name, + const std::string& serialized_attributes, + const std::vector& outputs) { + ADT_LET_CONST_REF(attrs, CastToAttrMap(serialized_attributes)); + return InferMetaByAxprHookImpl(infer_meta_func_name, inputs, attrs, outputs); +} + } // namespace phi diff --git a/paddle/ap/src/paddle/pir/infer_symbolic_shape_context_method_class.cc b/paddle/ap/src/paddle/pir/infer_symbolic_shape_context_method_class.cc new file mode 100644 index 00000000000000..71ddee2daf79e5 --- /dev/null +++ b/paddle/ap/src/paddle/pir/infer_symbolic_shape_context_method_class.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_context_method_class.h" +#include "paddle/ap/include/axpr/callable_helper.h" +#include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/dim_expr.h" +#include "paddle/ap/include/paddle/pir/type_adt_type_id.h" +#include "paddle/ap/include/paddle/pir/type_method_class.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace ap::paddle { + +namespace { + +adt::Result Max(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + symbol::DimExpr ret{symbol::Max{ + symbol::List{lhs, rhs}, + }}; + return axpr::GetDimExprClass().New(symbol::SimplifyDimExpr(ret)); +} + +adt::Result Min(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + symbol::DimExpr ret{symbol::Min{ + symbol::List{lhs, rhs}, + }}; + return axpr::GetDimExprClass().New(symbol::SimplifyDimExpr(ret)); +} + +adt::Result Broadcast(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + symbol::DimExpr ret{symbol::Broadcast{ + symbol::List{lhs, rhs}, + }}; + return axpr::GetDimExprClass().New(symbol::SimplifyDimExpr(ret)); +} + +adt::Result NewSymbolicName(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 0); + return self->GetNextSymName(); +} + +adt::Result AddEqualCstr(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + self->AddEqualCstr(lhs, rhs); + return adt::Nothing{}; +} + +adt::Result IsEqual(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + return self->IsEqual(lhs, rhs); +} + +adt::Result AddGreatThanOneCstr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(operand, args.at(0).template CastTo()); + self->AddGreatThanOneCstr(operand); + return adt::Nothing{}; +} + +adt::Result IsGreatThanOne(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(operand, args.at(0).template CastTo()); + return self->IsGreatThanOne(operand); +} + +adt::Result AddBroadcastableCstr( + const axpr::Value& self_val, const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + self->AddBroadcastableCstr(lhs, rhs); + return adt::Nothing{}; +} + +adt::Result IsBroadcastable(const axpr::Value& self_val, + const std::vector& args) { + ADT_LET_CONST_REF( + self, self_val.template CastTo()); + ADT_CHECK(args.size() == 2); + ADT_LET_CONST_REF(lhs, args.at(0).template CastTo()); + ADT_LET_CONST_REF(rhs, args.at(1).template CastTo()); + return self->IsBroadcastable(lhs, rhs); +} + +} // namespace + +axpr::TypeImpl> +GetPirInferSymbolicShapeContextClass() { + static auto cls(axpr::MakeBuiltinClass( + "PirInferSymbolicShapeContext", [&](const auto& Yield) { + Yield("max", &Max); + Yield("min", &Min); + Yield("broadcast", &Broadcast); + Yield("new_symbolic_name", &NewSymbolicName); + Yield("add_equal_cstr", &AddEqualCstr); + Yield("is_equal", &IsEqual); + Yield("add_greater_than_one_cstr", &AddGreatThanOneCstr); + Yield("is_greater_than_one", &IsGreatThanOne); + Yield("add_broadcastable_cstr", &AddBroadcastableCstr); + Yield("is_broadcastable", &IsBroadcastable); + })); + return axpr::MakeGlobalNaiveClassOps(cls); +} + +} // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/infer_symbolic_shape_util.cc b/paddle/ap/src/paddle/pir/infer_symbolic_shape_util.cc new file mode 100644 index 00000000000000..c84d1e14401b53 --- /dev/null +++ b/paddle/ap/src/paddle/pir/infer_symbolic_shape_util.cc @@ -0,0 +1,279 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h" +#include "paddle/ap/include/axpr/abstract_list.h" +#include "paddle/ap/include/axpr/anf_expr_util.h" +#include "paddle/ap/include/axpr/interpreter.h" +#include "paddle/ap/include/axpr/lambda_expr_builder.h" +#include "paddle/ap/include/memory/guard.h" +#include "paddle/ap/include/paddle/builtin_frame_util.h" +#include "paddle/ap/include/paddle/pir/attribute_method_class.h" +#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_context_method_class.h" +#include "paddle/ap/include/paddle/pir/shape_or_data_method_class.h" + +namespace ap::dialect { + +namespace { + +using Lambda = axpr::Lambda; + +adt::Result GetInferCtxVal( + pir::InferSymbolicShapeContext* infer_context) { + return ap::paddle::GetPirInferSymbolicShapeContextClass().New(infer_context); +} + +adt::Result GetApOpFacadeOpInputsVal( + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { + adt::List lst; + lst->reserve(op->num_operands()); + for (int i = 0; i < op->num_operands(); ++i) { + const auto& shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(i)); + lst->emplace_back(ap::paddle::GetPirShapeOrDataClass().New(shape_or_data)); + } + return axpr::Value{lst}; +} + +adt::Result GetPdOpApFacadeOpInputsVal( + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { + ADT_CHECK(op->num_operands() == 1); + if (!op->operand_source(0)) { + return adt::List{}; + } + const auto& shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + ADT_CHECK( + shape_or_data.template isa()); + const auto& tensor_list_shape_or_data = + shape_or_data.template dyn_cast(); + adt::List lst; + lst->reserve(tensor_list_shape_or_data.size()); + for (const auto& elt_tensor_shape_or_data : tensor_list_shape_or_data) { + symbol::ShapeOrDataDimExprs elt_shape_or_data{elt_tensor_shape_or_data}; + lst->emplace_back( + ap::paddle::GetPirShapeOrDataClass().New(elt_shape_or_data)); + } + return axpr::Value{lst}; +} + +adt::Result CastStrToLambda(const std::string& serialized_attributes) { + ADT_LET_CONST_REF(anf_expr, + axpr::MakeAnfExprFromJsonString(serialized_attributes)); + const auto& core_expr = axpr::ConvertAnfExprToCoreExpr(anf_expr); + std::vector> args{}; + return Lambda{args, core_expr}; +} + +adt::Result Unserialize(const std::string& serialized_attributes) { + ADT_LET_CONST_REF(lambda, CastStrToLambda(serialized_attributes)); + ap::memory::Guard guard{}; + ap::axpr::Interpreter interpreter( + ap::paddle::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + ADT_LET_CONST_REF(ret_val, interpreter.Interpret(lambda, {})); + ADT_CHECK(ret_val.template CastableTo>()); + return ret_val; +} + +adt::Result GetApOpApFacadeOpSerializedAttributes( + pir::Operation* op) { + const auto& iter = + op->attributes().find("__original_serialized_attributes__"); + ADT_CHECK(iter != op->attributes().end()); + ADT_CHECK(iter->second.template isa()); + const auto serialized_attributes = + iter->second.template dyn_cast().AsString(); + return serialized_attributes; +} + +adt::Result GetApOpFacadeOpAttrsVal(pir::Operation* op) { + ADT_LET_CONST_REF(serialized_attributes, + GetApOpApFacadeOpSerializedAttributes(op)); + return Unserialize(serialized_attributes); +} + +adt::Result GetPdOpApFacadeOpSerializedAttributes( + pir::Operation* op) { + const auto& iter = op->attributes().find("serialized_attributes"); + ADT_CHECK(iter != op->attributes().end()); + ADT_CHECK(iter->second.template isa()); + const auto serialized_attributes = + iter->second.template dyn_cast().AsString(); + return serialized_attributes; +} + +adt::Result GetPdOpApFacadeOpAttrsVal(pir::Operation* op) { + ADT_LET_CONST_REF(serialized_attributes, + GetPdOpApFacadeOpSerializedAttributes(op)); + return Unserialize(serialized_attributes); +} + +template +using MakeT = adt::Result (*)(const std::string& str); + +template Make> +adt::Result CacheResult(const std::string& serialized_attributes) { + static std::unordered_map> cache; + static std::mutex mutex; + std::unique_lock lock(mutex); + auto iter = cache.find(serialized_attributes); + if (iter == cache.end()) { + iter = + cache.emplace(serialized_attributes, Make(serialized_attributes)).first; + } + ADT_LET_CONST_REF(ret, iter->second); + return ret; +} + +adt::Result MakeInferSymbolicLambda( + const std::string& infer_symbolic_func_name) { + auto dot_pos = infer_symbolic_func_name.find('.'); + ADT_CHECK(dot_pos != std::string::npos); + const auto& module_name = infer_symbolic_func_name.substr(0, dot_pos); + const auto& func_name = infer_symbolic_func_name.substr(dot_pos + 1); + ADT_CHECK(func_name.find('.') == std::string::npos); + ap::axpr::LambdaExprBuilder lmd; + const ap::axpr::AnfExpr anf_expr = + lmd.Lambda({"infer_ctx", "inputs", "attrs"}, [&](auto& ctx) { + auto& infer_hooks = ctx.Var("import").Call(ctx.String(module_name)); + auto& method = infer_hooks.Attr(func_name); + auto& infer_ctx = ctx.Var("infer_ctx"); + auto& inputs = ctx.Var("inputs"); + auto& attrs = ctx.Var("attrs"); + auto& ret = method.Call(infer_ctx, inputs, attrs); + return ret; + }); + const auto& core_expr = ap::axpr::ConvertAnfExprToCoreExpr(anf_expr); + const auto& atomic = core_expr.Get>(); + return atomic.Get>(); +} + +const auto GetInferSymbolicLambda = + &CacheResult; + +adt::Result GetInferSymbolicFuncName(const pir::Operation* op) { + const auto& attrs = op->attributes(); + const auto& iter = attrs.find("infer_symbolic_func_name"); + ADT_CHECK(iter != attrs.end()); + const auto& attr = iter->second; + ADT_CHECK(attr.isa()); + return attr.dyn_cast().AsString(); +} + +adt::Result> +InferOutputsShapeOrValue(const std::string& infer_symbolic_func_name, + const axpr::Value& infer_ctx_val, + const axpr::Value& inputs_val, + const axpr::Value& attrs_val) { + ADT_LET_CONST_REF(lambda, GetInferSymbolicLambda(infer_symbolic_func_name)); + ap::memory::Guard guard{}; + ap::axpr::Interpreter interpreter( + ap::paddle::MakeBuiltinFrameAttrMap(), + guard.circlable_ref_list()); + ADT_LET_CONST_REF( + ret_val, + interpreter.Interpret(lambda, {infer_ctx_val, inputs_val, attrs_val})); + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(ret_val)); + std::vector ret{}; + ADT_LET_CONST_REF(lst_size, lst.size()); + ret.reserve(lst_size); + for (int i = 0; i < lst_size; ++i) { + ADT_LET_CONST_REF(elt_val, lst.at(i)); + ADT_LET_CONST_REF(shape_or_data, + elt_val.template CastTo()); + ADT_CHECK(shape_or_data.template isa()); + ret.emplace_back( + shape_or_data.template dyn_cast()); + } + return ret; +} + +adt::Result TryApOpFacadeOpInferSymbolicShape( + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { + ADT_LET_CONST_REF(infer_ctx_val, GetInferCtxVal(infer_context)); + ADT_LET_CONST_REF(inputs_val, GetApOpFacadeOpInputsVal(op, infer_context)); + ADT_LET_CONST_REF(attrs_val, GetApOpFacadeOpAttrsVal(op)); + ADT_LET_CONST_REF(infer_symbolic_func_name, GetInferSymbolicFuncName(op)); + ADT_LET_CONST_REF( + outputs_shape_or_value, + InferOutputsShapeOrValue( + infer_symbolic_func_name, infer_ctx_val, inputs_val, attrs_val)); + ADT_CHECK(op->num_results() == outputs_shape_or_value.size()); + for (int i = 0; i < op->num_results(); ++i) { + infer_context->SetShapeOrDataForValue(op->result(i), + outputs_shape_or_value.at(i)); + } + return adt::Ok{}; +} + +adt::Result TryPdOpApFacadeOpInferSymbolicShape( + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { + ADT_LET_CONST_REF(infer_ctx_val, GetInferCtxVal(infer_context)); + ADT_LET_CONST_REF(inputs_val, GetPdOpApFacadeOpInputsVal(op, infer_context)); + ADT_LET_CONST_REF(attrs_val, GetPdOpApFacadeOpAttrsVal(op)); + ADT_LET_CONST_REF(infer_symbolic_func_name, GetInferSymbolicFuncName(op)); + ADT_LET_CONST_REF( + outputs_shape_or_value, + InferOutputsShapeOrValue( + infer_symbolic_func_name, infer_ctx_val, inputs_val, attrs_val)); + std::size_t num_outputs = 0; + { + const auto iter = op->attributes().find("num_outputs"); + ADT_CHECK(iter != op->attributes().end()); + const auto& num_outputs_attr = iter->second; + ADT_CHECK(num_outputs_attr.isa()); + num_outputs = num_outputs_attr.dyn_cast().data(); + } + ADT_CHECK(num_outputs == outputs_shape_or_value.size()); + ADT_CHECK(op->num_results(), 1); + symbol::ShapeOrDataDimExprs shape_or_value{outputs_shape_or_value}; + infer_context->SetShapeOrDataForValue(op->result(0), shape_or_value); + return adt::Ok{}; +} + +} // namespace + +bool ApOpFacadeOpInferSymbolicShape( + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { + const auto& ret = TryApOpFacadeOpInferSymbolicShape(op, infer_context); + bool success = !ret.HasError(); + PADDLE_ENFORCE_EQ(success, + true, + phi::errors::Fatal("ApOpFacadeOpInferSymbolicShape failed. " + "\nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); + return success; +} + +bool PdOpApFacadeOpInferSymbolicShape( + pir::Operation* op, pir::InferSymbolicShapeContext* infer_context) { + const auto& ret = TryPdOpApFacadeOpInferSymbolicShape(op, infer_context); + bool success = !ret.HasError(); + PADDLE_ENFORCE_EQ( + success, + true, + phi::errors::Fatal("PdOpApFacadeOpInferSymbolicShape failed. " + "\nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); + return success; +} + +} // namespace ap::dialect diff --git a/paddle/ap/src/paddle/pir/pir_method_class.cc b/paddle/ap/src/paddle/pir/pir_method_class.cc index 0c479d7f609359..be64573673bcd3 100644 --- a/paddle/ap/src/paddle/pir/pir_method_class.cc +++ b/paddle/ap/src/paddle/pir/pir_method_class.cc @@ -39,30 +39,10 @@ void DefineMethods(Builder* m) { #define DEF_MAKE_TYPE(cls) m->Def(cls::name(), &MakePirTypeImpl::Call); FOR_EACH_PIR_ALTERNATIVE_TYPE(DEF_MAKE_TYPE); #undef DEF_MAKE_TYPE + ForEachShapeOrDataMaker( + [&](const auto& name, const auto& value) { m->Def(name, value); }); } REGISTER_AP_BUILTIN_MODULE("pir", [](auto* m) { DefineMethods(m); }); -axpr::TypeImpl> GetPirClass() { - static auto cls( - axpr::MakeBuiltinClass("pir", [&](const auto& Yield) { - Yield("UndefinedPlace", &CreateUndefinedPlace); - Yield("CPUPlace", &CreateCPUPlace); - Yield("GPUPlace", &CreateGPUPlace); - Yield("GPUPinnedPlace", &CreateGPUPinnedPlace); - Yield("XPUPlace", &CreateXPUPlace); - Yield("IPUPlace", &CreateIPUPlace); - Yield("CustomPlace", &CreateCustomPlace); -#define YIELD_MAKE_ATTRIBUTE(attr_type) \ - Yield(attr_type::name(), &MakePirAttributeImpl::Call); - FOR_EACH_PIR_ATTRIBUTE_TYPE(YIELD_MAKE_ATTRIBUTE); -#undef YIELD_MAKE_ATTRIBUTE - -#define YIELD_MAKE_TYPE(cls) Yield(cls::name(), &MakePirTypeImpl::Call); - FOR_EACH_PIR_ALTERNATIVE_TYPE(YIELD_MAKE_TYPE); -#undef YIELD_MAKE_TYPE - })); - return axpr::MakeGlobalNaiveClassOps(cls); -} - } // namespace ap::paddle diff --git a/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc b/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc index c637a9eb72316e..ad068694c7219e 100644 --- a/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc +++ b/paddle/ap/src/paddle/pir/pir_node_matched_src_ptn_ctx_helper.cc @@ -278,8 +278,7 @@ std::unique_ptr MakeSourcePatternCtxBuilder( node_arena, std::map{}, drr_ctx}, drr::TensorPatternCtx{ node_arena, std::map{}, drr_ctx}}; - const auto& builtin_frame = - ap::drr::MakeBuiltinFrameAttrMap([&](const auto&) {}); + const auto& builtin_frame = ap::drr::MakeBuiltinFrameAttrMap(); auto interpreter = std::make_unique( builtin_frame, drr_ctx->circlable_ref_list); return std::make_unique(src_ptn_ctx, diff --git a/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc b/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc index d4debbff0af3f8..99ec2f4f31705f 100644 --- a/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc +++ b/paddle/ap/src/paddle/pir/shape_or_data_method_class.cc @@ -12,13 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/ap/include/paddle/pir/shape_or_data_method_class.h" +#include "paddle/ap/include/axpr/abstract_list.h" #include "paddle/ap/include/axpr/callable_helper.h" #include "paddle/ap/include/axpr/data_type_util.h" +#include "paddle/ap/include/axpr/dim_expr.h" #include "paddle/ap/include/paddle/pir/type_adt_type_id.h" #include "paddle/ap/include/paddle/pir/type_method_class.h" namespace ap::paddle { +axpr::TypeImpl> +GetPirShapeOrDataClass(); + adt::Result PirShapeOrDataString( const axpr::Value& self_val, const std::vector& args) { ADT_LET_CONST_REF(self, @@ -28,12 +34,196 @@ adt::Result PirShapeOrDataString( return ss.str(); } +adt::Result> GetConstructorArgsImpl( + const symbol::NullShapeOrDataDimExpr& impl) { + return adt::List{}; +} + +adt::Result> GetConstructorArgsImpl( + const symbol::TensorShapeOrDataDimExprs& impl) { + adt::List shape{}; + shape->reserve(impl.shape().size()); + for (const auto& dim_expr : impl.shape()) { + shape->emplace_back(axpr::GetDimExprClass().New(dim_expr)); + } + if (impl.data().has_value()) { + adt::List data{}; + data->reserve(impl.data().value().size()); + for (const auto& dim_expr : impl.data().value()) { + data->emplace_back(axpr::GetDimExprClass().New(dim_expr)); + } + return adt::List{axpr::Value{shape}, axpr::Value{data}}; + } else { + return adt::List{axpr::Value{shape}, + axpr::Value{adt::Nothing{}}}; + } +} + +adt::Result> GetConstructorArgsImpl( + const symbol::TensorListShapeOrDataDimExprs& impl) { + adt::List lst{}; + lst->reserve(impl.size()); + for (const auto& shape_or_data : impl) { + lst->push_back(GetPirShapeOrDataClass().New(shape_or_data)); + } + return adt::List{axpr::Value{lst}}; +} + +adt::Result> GetConstructorArgsImpl( + const symbol::RankedTensorArrayShapeOrDataDimExprs& impl) { + return adt::errors::NotImplementedError{ + "pir.s_ranked_tensor_array_shape_or_data not implemented"}; +} + +std::string PirShapeOrDataGetTypeNameImpl( + const symbol::ShapeOrDataDimExprs& self) { + return self.Match( + [](const symbol::NullShapeOrDataDimExpr&) -> std::string { + return "s_null"; + }, + [](const symbol::TensorShapeOrDataDimExprs&) -> std::string { + return "s_tensor_shape_or_data"; + }, + [](const symbol::TensorListShapeOrDataDimExprs&) -> std::string { + return "s_tensor_list_shape_or_data"; + }, + [](const symbol::RankedTensorArrayShapeOrDataDimExprs&) -> std::string { + return "s_ranked_tensor_array_shape_or_data"; + }); +} + +adt::Result PirShapeOrDataGetTypeName( + const axpr::Value& self_val, const std::vector& args) { + ADT_CHECK(args.size() == 0); + ADT_LET_CONST_REF(self, + self_val.template CastTo()); + return PirShapeOrDataGetTypeNameImpl(self); +} + +adt::Result PirShapeOrDataMatch( + axpr::InterpreterBase* interpreter, + const axpr::Value& self_val, + const std::vector& packed_args_val) { + ADT_LET_CONST_REF(self, + self_val.template CastTo()); + const auto& packed_args = + axpr::CastToPackedArgs(packed_args_val); + const auto& type_name = PirShapeOrDataGetTypeNameImpl(self); + const auto& [args, kwargs] = *packed_args; + ADT_CHECK(args->size() == 0) << adt::errors::TypeError{ + std::string() + + "PirShapeOrData.match() supports keyword arguments only, but " + + std::to_string(args->size()) + " positional arguments were given"}; + + std::string key = type_name; + if (!kwargs->Has(type_name)) { + if (!kwargs->Has("_")) { + return adt::errors::TypeError{ + std::string() + "PirShapeOrData.match() failed. no keyword '" + + type_name + "' or '_' provided"}; + } + key = "_"; + } + ADT_LET_CONST_REF(func, kwargs->Get(key)); + auto GetConstructorArgs = + [&](const auto& impl) -> adt::Result> { + return GetConstructorArgsImpl(impl); + }; + ADT_LET_CONST_REF(shape_or_data_constructor_args, + self.Match(GetConstructorArgs)); + ADT_CHECK(axpr::CallableHelper{}.IsCallable(func)) << adt::errors::TypeError{ + std::string() + + "the arguments of PirShapeOrData.match() should be callable"}; + if (key == "_") { + return interpreter->InterpretCall(func, {}); + } else { + return interpreter->InterpretCall(func, + shape_or_data_constructor_args.vector()); + } +} + axpr::TypeImpl> GetPirShapeOrDataClass() { static auto cls(axpr::MakeBuiltinClass( - "PirShapeOrData", - [&](const auto& Yield) { Yield("__str__", &PirShapeOrDataString); })); + "PirShapeOrData", [&](const auto& Yield) { + Yield("__str__", &PirShapeOrDataString); + Yield("get_type_name", &PirShapeOrDataGetTypeName); + Yield("match", &PirShapeOrDataMatch); + })); return axpr::MakeGlobalNaiveClassOps(cls); } +adt::Result MakeNullShapeOrDataDimExpr( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 0); + symbol::ShapeOrDataDimExprs shape_or_data{symbol::NullShapeOrDataDimExpr{}}; + return GetPirShapeOrDataClass().New(shape_or_data); +} + +adt::Result> GetDimExprs( + const axpr::Value& dim_exprs) { + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(dim_exprs)); + std::vector ret; + ADT_LET_CONST_REF(lst_size, lst.size()); + ret.reserve(lst_size); + for (int i = 0; i < lst_size; ++i) { + ADT_LET_CONST_REF(elt_val, lst.at(i)); + ADT_LET_CONST_REF(elt, elt_val.template CastTo()); + ret.emplace_back(elt); + } + return ret; +} + +adt::Result>> GetDataByArgVec( + const std::vector& args) { + if (args.size() != 2) { + return std::nullopt; + } + if (args.at(1).template CastableTo()) { + return std::nullopt; + } + ADT_LET_CONST_REF(dim_exprs, GetDimExprs(args.at(1))); + return dim_exprs; +} + +adt::Result MakeTensorShapeOrDataDimExprs( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 1 || args.size() == 2); + ADT_LET_CONST_REF(shape, GetDimExprs(args.at(0))); + ADT_LET_CONST_REF(opt_data, GetDataByArgVec(args)); + if (opt_data.has_value()) { + symbol::TensorShapeOrDataDimExprs tensor_shape_or_data{shape, + opt_data.value()}; + symbol::ShapeOrDataDimExprs shape_or_data{tensor_shape_or_data}; + return GetPirShapeOrDataClass().New(shape_or_data); + } else { + symbol::TensorShapeOrDataDimExprs tensor_shape_or_data{shape}; + symbol::ShapeOrDataDimExprs shape_or_data{tensor_shape_or_data}; + return GetPirShapeOrDataClass().New(shape_or_data); + } +} + +adt::Result MakeTensorListShapeOrDataDimExprs( + const axpr::Value&, const std::vector& args) { + ADT_CHECK(args.size() == 1); + ADT_LET_CONST_REF(lst, axpr::AbstractList::CastFrom(args.at(0))); + std::vector elts; + ADT_LET_CONST_REF(lst_size, lst.size()); + elts.reserve(lst_size); + for (int i = 0; i < lst_size; ++i) { + ADT_LET_CONST_REF(elt_val, lst.at(i)); + ADT_LET_CONST_REF( + elt, elt_val.template CastTo()); + elts.emplace_back(elt); + } + symbol::ShapeOrDataDimExprs shape_or_data{elts}; + return GetPirShapeOrDataClass().New(shape_or_data); +} + +adt::Result MakeRankedTensorArrayShapeOrDataDimExprs( + const axpr::Value&, const std::vector& args) { + return adt::errors::NotImplementedError{ + "pir.s_ranked_tensor_array_shape_or_data not implemented"}; +} + } // namespace ap::paddle diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 12bae7211efa08..d86ef006585806 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -29,6 +29,7 @@ #include "paddle/ap/include/memory/guard.h" #include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h" +#include "paddle/ap/include/paddle/pass/convert_pd_facade_to_ap_facade.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.h" @@ -227,9 +228,19 @@ void ApplyApGenericDrrPass( ::pir::Program* program, const std::function()>& CreatePassManager) { - std::shared_ptr pass_manager = CreatePassManager(); + { + pir::IrPrinter(LOG(ERROR) << "before ConvertPdFacadeToApFacadePass:\n") + .PrintProgram(program); + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(CreateConvertPdFacadeToApFacadePass()); + pass_manager->Run(program); + pir::IrPrinter(LOG(ERROR) << "after ConvertPdFacadeToApFacadePass:\n") + .PrintProgram(program); + } ap::memory::Guard guard{}; if (auto pass = CreateApGenericClassicDrrPass(guard.circlable_ref_list())) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(CreateConvertPdFacadeToApFacadePass()); pass_manager->AddPass(std::move(pass.value())); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); pir::IrPrinter(LOG(ERROR) << "before ApGenericClassicDrrPass:\n") @@ -239,6 +250,7 @@ void ApplyApGenericDrrPass( .PrintProgram(program); } if (auto pass = CreateApGenericAbstractDrrPass(guard.circlable_ref_list())) { + std::shared_ptr pass_manager = CreatePassManager(); pass_manager->AddPass(std::move(pass.value())); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); pir::IrPrinter(LOG(ERROR) << "before ApGenericAbstractDrrPass:\n") diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt index f8e7598e48a0b1..31d7611f88c789 100644 --- a/paddle/fluid/pir/dialect/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/CMakeLists.txt @@ -304,6 +304,7 @@ file(GLOB_RECURSE dist_dialect_srcs # if(WITH_DISTRIBUTE) FIXME in next PR set(op_dialect_srcs ${op_dialect_srcs} ${dist_dialect_srcs}) # endif() + set(op_dialect_deps phi common @@ -312,6 +313,9 @@ set(op_dialect_deps string_helper global_utils amp) +if(WITH_CINN) + set(op_dialect_deps ${op_dialect_deps} ap_pir) +endif() if(WITH_ROCM) set(op_dialect_deps ${op_dialect_deps} global_utils) endif() diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index f505cf0841d8ff..9b29cdadcabe54 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -70,6 +70,8 @@ 'ResnetBasicBlockGradInferMeta', # multiary.h 'AddNInferMeta', + 'ApVariadicInferMeta', + 'ApFacadeInferMeta', 'AddNTensorArrayInferMeta', 'AttentionLstmInferMeta', 'AucInferMeta', diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc new file mode 100644 index 00000000000000..ce4702ec343065 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h" + +#include "paddle/common/ddim.h" +#include "paddle/common/enforce.h" +#include "paddle/common/layout.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" + +#ifdef PADDLE_WITH_CINN +#include "paddle/ap/include/paddle/pir/infer_symbolic_shape_util.h" +#endif + +namespace paddle::dialect { + +bool ApFacadeOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { +#ifdef PADDLE_WITH_CINN + return ap::dialect::PdOpApFacadeOpInferSymbolicShape(op, infer_context); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "ap_facade is not implemented when cinn is not enabled.")); + return false; +#endif +} + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h new file mode 100644 index 00000000000000..65f76357ede436 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h @@ -0,0 +1,23 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { + +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ApFacade) + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h index b45883428e4bff..5588668e1b0606 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/ap_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h" diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index d79fe55f99321d..aa75119ba556a0 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -90,6 +90,7 @@ #include "pybind11/stl.h" #ifdef PADDLE_WITH_CINN +#include "paddle/ap/include/paddle/hlir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.h" @@ -2754,6 +2755,7 @@ void ApplyCinnPass(Program &program) { // NOLINT pir::IrContext *ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); auto pass_manager = std::make_shared(ctx); if (FLAGS_print_ir && VLOG_IS_ON(4)) { diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index bf08f04af8e77e..ae48a2d44e66a6 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -482,12 +482,44 @@ void ApVariadicInferMeta(const std::vector& xs, #ifdef PADDLE_WITH_CINN ApInferMetaHelper helper{}; const auto& ret = helper.InferMeta(infer_meta_lambda, &xs, &outs); + PADDLE_ENFORCE_EQ( + ret.HasError(), + false, + phi::errors::Fatal( + "ApVariadicInferMeta failed. \nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "ap_variadic is not implemented when cinn is not enabled.")); +#endif +} + +void ApFacadeInferMeta( + const paddle::optional>& xs, + int64_t num_outputs, + const std::string& custom_op_name, + const std::string& infer_meta_func_name, + const std::string& infer_symbolic_func_name, + const std::string& serialized_attributes, + std::vector outs, + MetaConfig config) { +#ifdef PADDLE_WITH_CINN + ApInferMetaHelper helper{}; + const auto& ret = helper.InferMetaByAxprHook( + xs, infer_meta_func_name, serialized_attributes, outs); PADDLE_ENFORCE(!ret.HasError(), - "ApVariadicInferMeta failed. \nTraceback (most recent call " - "last):\n%s\n%s: %s. ", - ret.GetError().CallStackToString(), - ret.GetError().class_name(), - ret.GetError().msg()); + phi::errors::Fatal( + "ApFacadeInferMeta failed. \nTraceback (most recent call " + "last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "ap_facade is not implemented when cinn is not enabled.")); #endif } diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 618287989180f2..435bc8f5ece1a9 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -140,6 +140,16 @@ void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void ApFacadeInferMeta( + const paddle::optional>& xs, + int64_t num_outputs, + const std::string& custom_op_name, + const std::string& infer_meta_func_name, + const std::string& infer_symbolic_func_name, + const std::string& serialized_attributes, + std::vector outs, + MetaConfig config = MetaConfig()); + void ApVariadicInferMeta(const std::vector& xs, int num_outputs, const std::string& code_module_lambda, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 45dcbc36399577..19b9f35bbf5a2d 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -241,8 +241,10 @@ endif() # Remove AP kernel when CINN is not enabled. if(NOT WITH_CINN) - list(REMOVE_ITEM kernel_cu "gpu/ap_variadic_kernel.cu") - list(REMOVE_ITEM kernel_gpu "gpu/ap_variadic_kernel.cu") + list(REMOVE_ITEM kernel_cu "gpu/ap_facade_kernel.cu" + "gpu/ap_variadic_kernel.cu") + list(REMOVE_ITEM kernel_gpu "gpu/ap_facade_kernel.cu" + "gpu/ap_variadic_kernel.cu") endif() set(cc_search_pattern diff --git a/paddle/phi/kernels/gpu/ap_facade_kernel.cu b/paddle/phi/kernels/gpu/ap_facade_kernel.cu new file mode 100644 index 00000000000000..1d57345118480b --- /dev/null +++ b/paddle/phi/kernels/gpu/ap_facade_kernel.cu @@ -0,0 +1,48 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/common/enforce.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ApFacadeKernel(const Context& dev_ctx, + const paddle::optional>& xs, + int64_t num_outputs, + const std::string& custom_op_name, + const std::string& infer_meta_func_name, + const std::string& infer_symbolic_func_name, + const std::string& serialized_attributes, + std::vector outs) { + PADDLE_THROW( + common::errors::Unimplemented("ap_facade has no kernel registered.")); +} + +} // namespace phi + +PD_REGISTER_KERNEL(ap_facade, + GPU, + ALL_LAYOUT, + phi::ApFacadeKernel, + float, + double, + int, + phi::dtype::bfloat16, + phi::dtype::float16, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/ap_variadic_kernel.cu b/paddle/phi/kernels/gpu/ap_variadic_kernel.cu index 340ed89cd52caf..19c45bbe0a27a7 100644 --- a/paddle/phi/kernels/gpu/ap_variadic_kernel.cu +++ b/paddle/phi/kernels/gpu/ap_variadic_kernel.cu @@ -12,18 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include "glog/logging.h" -#include "jitify.hpp" // NOLINT #include "paddle/common/enforce.h" - #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/impl/activation_grad_impl.h" -#include "paddle/phi/kernels/impl/activation_impl.h" #include "paddle/ap/include/axpr/data_type_util.h" #include "paddle/ap/include/kernel_dispatch/ap_variadic_kernel.h" @@ -101,12 +93,13 @@ void ApVariadicKernel(const Context& dev_ctx, kernel_dispatch_lambda, kernel_dispatch_const_data_lambda, outs); - PADDLE_ENFORCE( - !ret.HasError(), - "ap_kernel failed. \nTraceback (most recent call last):\n%s\n%s: %s. ", - ret.GetError().CallStackToString(), - ret.GetError().class_name(), - ret.GetError().msg()); + PADDLE_ENFORCE_EQ(ret.HasError(), + false, + phi::errors::Fatal("ap_variadic failed. \nTraceback (most " + "recent call last):\n%s\n%s: %s. ", + ret.GetError().CallStackToString(), + ret.GetError().class_name(), + ret.GetError().msg())); } } // namespace phi diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 24e1cb816cff6a..7b4044b5aa1c8a 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -278,6 +278,17 @@ traits : paddle::dialect::ForwardOnlyTrait interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface +- op : ap_facade + args : (Tensor[] xs, int64_t num_outputs, str custom_op_name, str infer_meta_func_name, str infer_symbolic_func_name, str serialized_attributes) + output : Tensor[](out){num_outputs} + optional : xs + infer_meta : + func : ApFacadeInferMeta + interfaces : paddle::dialect::InferSymbolicShapeInterface + kernel : + func : ap_facade + traits : paddle::dialect::ForwardOnlyTrait + - op : ap_variadic args : (Tensor[] xs, int num_outputs, str code_module_lambda, str infer_meta_lambda, str rnel_dispatch_lambda, str kernel_dispatch_const_data_lambda) output : Tensor[](out){num_outputs} @@ -285,6 +296,7 @@ func : ApVariadicInferMeta kernel : func : ap_variadic + traits : paddle::dialect::ForwardOnlyTrait - op : apply_per_channel_scale args: (Tensor x, Tensor scales) diff --git a/python/paddle/incubate/cc/__init__.py b/python/paddle/incubate/cc/__init__.py index 930555f7ff3586..a9db10e531636d 100644 --- a/python/paddle/incubate/cc/__init__.py +++ b/python/paddle/incubate/cc/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import fuse +from . import ap as ap, fuse as fuse from .compiler import compile __all__ = ['fuse', 'compile'] diff --git a/python/paddle/incubate/cc/ap/__init__.py b/python/paddle/incubate/cc/ap/__init__.py new file mode 100644 index 00000000000000..9581d32e5c1dfe --- /dev/null +++ b/python/paddle/incubate/cc/ap/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .facade_op import FacadeOp as FacadeOp diff --git a/python/paddle/incubate/cc/ap/facade_op.py b/python/paddle/incubate/cc/ap/facade_op.py new file mode 100644 index 00000000000000..39ef8464c8286a --- /dev/null +++ b/python/paddle/incubate/cc/ap/facade_op.py @@ -0,0 +1,94 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import paddle + +from .pir_attrs_serializer import PirAttrsSerializer + + +class FacadeOp: + + def __init__(self): + self.custom_op_name_ = self.custom_op_name() + self.infer_meta_ = self._check_to_str_pair(self.infer_meta()) + self.infer_symbolic_ = self._check_to_str_pair(self.infer_symbolic()) + self.num_inputs_ = self.num_inputs() + self.attrs_serializer_ = PirAttrsSerializer(self.attributes_schema) + + def custom_op_name(self) -> str: + raise NotImplementedError( + "static method custom_op_name() is not overwritten" + ) + + def infer_meta(self) -> str: + raise NotImplementedError( + "static method infer_meta() is not overwritten" + ) + + def infer_symbolic(self) -> str: + raise NotImplementedError( + "static method infer_symbolic() is not overwritten" + ) + + def num_inputs(self) -> int: + raise NotImplementedError( + "static method num_inputs() is not overwritten" + ) + + def num_outputs(self, args) -> int: + raise NotImplementedError( + "static method num_outputs() is not overwritten" + ) + + def attributes_schema(self): + # annotations matter. + raise NotImplementedError( + "static method attributes_schema() is not overwritten" + ) + + def __call__(self, args, **kwargs): + if paddle.in_dynamic_mode(): + warnings.warn("ap FacadeOp should not run in dynamic mode") + assert isinstance(args, (tuple, list)) + self._check_num_inputs(len(args)) + serialized_attrs = self.attrs_serializer_(**kwargs) + ret = paddle._C_ops.ap_facade( + args if len(args) > 0 else None, + self.num_outputs(args), + self.custom_op_name_, + self.infer_meta_, + self.infer_symbolic_, + serialized_attrs, + ) + self._check_num_outputs(args, len(ret)) + return ret + + def _check_num_inputs(self, num_args): + if self.num_inputs_ >= 0: + assert self.num_inputs_ == num_args + + def _check_num_outputs(self, args, num_rets): + num_outputs = self.num_outputs(args) + if num_outputs >= 0: + assert num_outputs == num_rets + + def _check_to_str_pair(self, pair_str): + assert isinstance(pair_str, str) + pair = pair_str.split(".") + assert len(pair) == 2 + assert pair[0] not in (None, "") + assert pair[1] not in (None, "") + return pair_str diff --git a/python/paddle/incubate/cc/ap/pir_attrs_serializer.py b/python/paddle/incubate/cc/ap/pir_attrs_serializer.py new file mode 100644 index 00000000000000..62b8bfea583f96 --- /dev/null +++ b/python/paddle/incubate/cc/ap/pir_attrs_serializer.py @@ -0,0 +1,246 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect + +import paddle + +from ..data_type_util import get_dtype_lower_case_name +from ..typing import DType +from .py_to_axpr_json import convert_python_stmts_to_axpr_json + + +class PirAttrsSerializer: + + def __init__(self, func): + self.attributes_schema = self._get_attributes_schema(func) + self._check_attributes_schema(self.attributes_schema) + self.attr_name2serializer = { + attr_name: serializer + for attr_name, schema_item in self.attributes_schema + for serializer in [self._get_serializer(attr_name, schema_item)] + } + + def __call__(self, **attributes): + print(attributes) + attributes_names = {name for name, _ in attributes.items()} + attr_names = {name for name, _ in self.attributes_schema} + assert ( + attributes_names == attr_names + ), f"expected attr_names: {attr_names}, but actual attr_names are {attributes_names}" + py_assigns = "\n".join( + py_stmt + for attr_name, attr_val in attributes.items() + for py_stmt in self.attr_name2serializer[attr_name](attr_val) + ) + py_stmts_str = f"{py_assigns}\n{self._get_attr_map_ctor_str(self.attributes_schema)}" + return convert_python_stmts_to_axpr_json(py_stmts_str) + + def _get_attr_map_ctor_str(self, attributes_schema): + kwargs = ", ".join(f"{name}={name}" for name, _ in attributes_schema) + return f"__builtin__AttrMap({kwargs})" + + def _get_attributes_schema(self, obj): + if isinstance(obj, (list, tuple)): + return obj + func = obj + assert inspect.isfunction(func) or inspect.ismethod(func) + full_arg_spec = inspect.getfullargspec(func) + args = ( + full_arg_spec.args[1:] + if inspect.ismethod(func) + else full_arg_spec.args + ) + return [ + (arg_name, annotation) + for arg_name in args + for annotation in [full_arg_spec.annotations[arg_name]] + ] + + def _check_attributes_schema(self, attributes_schema): + for _, attr_type in attributes_schema: + self._check_attributes_schema_item_is_valid(attr_type) + + def _check_attributes_schema_item_is_valid(self, attr_type): + if attr_type in self._supported_basic_types(): + return + assert isinstance( + attr_type, list + ), f"attribute type {attr_type} is not supported." + assert ( + len(attr_type) == 1 + ), "only syntax like [bool], [int], [float], [str] supported." + assert ( + attr_type[0] in self._supported_basic_types() + ), f"supported list element types are bool/int/float/str, not include {attr_type[0]}." + + def _supported_basic_types(self): + return (bool, int, float, str, DType) + + def _get_serializer(self, attr_name, schema_item): + assert attr_name not in ( + "custom_op_name", + "infer_meta_func_name", + "infer_symbolic_func_name", + ) + schema_item_as_key = self._get_schema_item_as_key(schema_item) + return _get_serializer_factory[schema_item_as_key](attr_name) + + def _get_schema_item_as_key(self, schema_item): + if schema_item in self._supported_basic_types(): + return schema_item + assert isinstance(schema_item, list) + return tuple(schema_item) + + +class PirAttributeSerializer: + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + yield from [] + raise NotImplementedError + + +class BoolAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, bool) + yield f"{self.attr_name} = {value}" + + +class IntAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, int) + yield f"{self.attr_name} = {value}" + + +class FloatAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, float) + yield f"{self.attr_name} = {value}" + + +class StrAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, str) + yield f"{self.attr_name} = {value}" + + +class DTypeAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, paddle.dtype) + name = get_dtype_lower_case_name(value) + yield f"{self.attr_name} = __builtin__DataType.{name}" + + +class BoolArrayAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, list) + for elt in value: + assert isinstance(elt, bool) + yield f"{self.attr_name} = {value}" + + +class IntArrayAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, list) + for elt in value: + assert isinstance(elt, int) + yield f"{self.attr_name} = {value}" + + +class FloatArrayAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, list) + for elt in value: + assert isinstance(elt, float) + yield f"{self.attr_name} = {value}" + + +class StrArrayAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, list) + for elt in value: + assert isinstance(elt, str) + yield f"{self.attr_name} = {value}" + + +class DTypeArrayAttributeSerializer(PirAttributeSerializer): + + def __init__(self, attr_name): + self.attr_name = attr_name + + def __call__(self, value): + assert isinstance(value, list) + for elt in value: + assert isinstance(elt, paddle.dtype) + value_str = ", ".join( + f"__builtin__DataType.{name}" + for dtype in value + for name in [get_dtype_lower_case_name(dtype)] + ) + yield f"{self.attr_name} = [{value_str}]" + + +_get_serializer_factory = { + bool: BoolAttributeSerializer, + int: IntAttributeSerializer, + float: FloatAttributeSerializer, + str: StrAttributeSerializer, + DType: DTypeAttributeSerializer, + (bool,): BoolArrayAttributeSerializer, + (int,): IntArrayAttributeSerializer, + (float,): FloatArrayAttributeSerializer, + (str,): StrArrayAttributeSerializer, + (DType,): DTypeArrayAttributeSerializer, +} diff --git a/python/paddle/incubate/cc/ap/py_to_axpr_json.py b/python/paddle/incubate/cc/ap/py_to_axpr_json.py new file mode 100644 index 00000000000000..482072e90b8c63 --- /dev/null +++ b/python/paddle/incubate/cc/ap/py_to_axpr_json.py @@ -0,0 +1,567 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import ast +import functools +import itertools +import json +import operator +import sys +import typing as t +from dataclasses import dataclass + + +def convert_python_stmts_to_axpr_json(python_code_stmts_str): + tree = ast.parse(python_code_stmts_str) + parser = PyToAnfParser() + return parser(tree).ConvertToAnfExpr().JsonDump() + + +@dataclass +class AnfExpr: + + def DumpToFileAsJson(self, file_name): + with open(file_name, "w") as f: + json.dump(self.value, f, indent=2) + + def JsonDump(self): + return json.dumps(self.value) + + +@dataclass +class AtomicAnfExpr(AnfExpr): + value: t.Any + + +@dataclass +class CombinedAnfExpr(AnfExpr): + value: t.Any + + +@dataclass +class AnfParseResult: + bindings: list[str] + body_atomic_anf_expr: AtomicAnfExpr + + def __add__(self, other): + return AnfParseResult( + bindings=[*self.bindings, *other.bindings], + body_atomic_anf_expr=other.body_atomic_anf_expr, + ) + + def ConvertToAnfExpr(self): + ret = self.body_atomic_anf_expr + if len(self.bindings) == 0: + return ret + assert isinstance(ret, AtomicAnfExpr) + ret = CombinedAnfExpr( + ["__builtin_identity__", self.body_atomic_anf_expr.value] + ) + return CombinedAnfExpr(["__builtin_let__", self.bindings, ret.value]) + + +class PyToAnfParser: + def __init__(self, seq_no_counter=None, return_count_constraint=None): + self.bindings = [] + self.seq_no_counter = ( + seq_no_counter if seq_no_counter is not None else itertools.count() + ) + self.return_count_constraint = ( + return_count_constraint + if return_count_constraint is not None + else ReturnCounterConstraint(limits=1) + ) + + def __call__(self, tree): + ret = self.Parse(tree) + return AnfParseResult(bindings=self.bindings, body_atomic_anf_expr=ret) + + def Parse(self, tree): + method_name = f"Parse{type(tree).__name__}" + return getattr(self, method_name)(tree) + + def ParseImport(self, tree): + for alias in tree.names: + assert isinstance(alias, ast.alias) + name = alias.name + asname = alias.asname if alias.asname is not None else name + self.Bind(asname, ["import", {"str": name}]) + return AtomicAnfExpr(None) + + def ParseClassDef(self, tree: ast.ClassDef): + assert len(tree.keywords) == 0 + class_name = tree.name + + def GetBases(): + bases = [self.Parse(base) for base in tree.bases] + return self.BindToTmpVar( + ['__builtin_list__', *[x.value for x in bases]] + ) + + def GetFunctions(): + body_name_and_method_pair = [] + for func_def in tree.body: + if isinstance(func_def, ast.Pass): + continue + assert isinstance( + func_def, ast.FunctionDef + ), f"only method supported in class definition, {type(func_def)} were given." + func_code = self.BindToTmpVar( + [ + '__builtin_getattr__', + self.Parse(func_def).value, + {"str": '__function__'}, + ] + ) + pair = self.BindToTmpVar( + [ + "__builtin_list__", + {"str": func_def.name}, + func_code.value, + ] + ) + body_name_and_method_pair.append(pair) + positional_args = self.BindToTmpVar(['__builtin_list__']) + keyword_args = self.BindToTmpVar( + [ + '__builtin_list__', + *[x.value for x in body_name_and_method_pair], + ] + ) + packed_args = self.BindToTmpVar( + [ + '__builtin_PackedArgs__', + positional_args.value, + keyword_args.value, + ] + ) + return self.BindToTmpVar( + ['BuiltinSerializableAttrMap', packed_args.value] + ) + + class_anf_expr = self.BindToTmpVar( + [ + 'type', + {"str": class_name}, + GetBases().value, + GetFunctions().value, + ] + ) + for elt in reversed(tree.decorator_list): + decorator = self.Parse(elt) + class_anf_expr = self.BindToTmpVar( + [decorator.value, class_anf_expr.value] + ) + self.Bind(class_name, class_anf_expr) + return class_anf_expr + + def Parsekeyword(self, tree): + value = self.Parse(tree.value) + return self.BindToTmpVar( + ["__builtin_list__", {"str": tree.arg}, value.value] + ) + + def ParseBinOp(self, tree): + left = self.Parse(tree.left) + op = self.Parse(tree.op) + right = self.Parse(tree.right) + return self.BindToTmpVar([op.value, left.value, right.value]) + + def ParseUnaryOp(self, tree): + op = self.Parse(tree.op) + operand = self.Parse(tree.operand) + return self.BindToTmpVar([op.value, operand.value]) + + def ParseCompare(self, tree): + assert len(tree.ops) == 1 + op = self.Parse(tree.ops[0]) + left = self.Parse(tree.left) + assert len(tree.comparators) == 1 + right = self.Parse(tree.comparators[0]) + return self.BindToTmpVar([op.value, left.value, right.value]) + + def ParseAdd(self, tree): + return AtomicAnfExpr("__builtin_Add__") + + def ParseSub(self, tree): + return AtomicAnfExpr("__builtin_Sub__") + + def ParseMult(self, tree): + return AtomicAnfExpr("__builtin_Mul__") + + def ParseDiv(self, tree): + return AtomicAnfExpr("__builtin_Div__") + + def ParseFloorDiv(self, tree): + return AtomicAnfExpr("__builtin_FloorDiv__") + + def ParseMod(self, tree): + return AtomicAnfExpr("__builtin_Mod__") + + def ParseUSub(self, tree): + return AtomicAnfExpr("__builtin_Neg__") + + def ParseEq(self, tree): + return AtomicAnfExpr("__builtin_EQ__") + + def ParseNotEq(self, tree): + return AtomicAnfExpr("__builtin_NE__") + + def ParseGt(self, tree): + return AtomicAnfExpr("__builtin_GT__") + + def ParseGtE(self, tree): + return AtomicAnfExpr("__builtin_GE__") + + def ParseLt(self, tree): + return AtomicAnfExpr("__builtin_LT__") + + def ParseLtE(self, tree): + return AtomicAnfExpr("__builtin_LE__") + + def ParseModule(self, module: ast.Module): + parse_result = AnfParseResult( + bindings=[], body_atomic_anf_expr=AtomicAnfExpr(None) + ) + if len(module.body) > 0: + seq_no_counter = itertools.count() + return_count_constraint = ReturnCounterConstraint(limits=0) + parse_result = functools.reduce( + operator.add, + ( + PyToAnfParser(seq_no_counter, return_count_constraint)(tree) + for tree in module.body + ), + ) + return parse_result.ConvertToAnfExpr() + + def ParseFunctionDef(self, function_def: ast.FunctionDef): + if len(function_def.body) > 0: + return_count_constraint = ReturnCounterConstraint(limits=1) + return_stmt_idx = self.GetStmtSizeUntilReturn(function_def.body) + parse_result = functools.reduce( + operator.add, + [ + PyToAnfParser(self.seq_no_counter, return_count_constraint)( + tree + ) + for tree in function_def.body[0:return_stmt_idx] + if not isinstance(tree, ast.Pass) + ] + + [ + AnfParseResult( + bindings=[], body_atomic_anf_expr=AtomicAnfExpr(None) + ) + ], + ) + else: + parse_result = AnfParseResult( + bindings=[], body_atomic_anf_expr=AtomicAnfExpr(None) + ) + args = [arg.arg for arg in function_def.args.args] + lmbd = AtomicAnfExpr( + ['lambda', args, parse_result.ConvertToAnfExpr().value] + ) + for elt in reversed(function_def.decorator_list): + decorator = self.Parse(elt) + lmbd = self.BindToTmpVar([decorator.value, lmbd.value]) + func_name = function_def.name + self.Bind(func_name, lmbd) + return AtomicAnfExpr(func_name) + + def ParseLambda(self, function_def: ast.Lambda): + return_count_constraint = ReturnCounterConstraint(limits=0) + parser = PyToAnfParser(self.seq_no_counter, return_count_constraint) + parse_result = parser(function_def.body) + args = [arg.arg for arg in function_def.args.args] + return AtomicAnfExpr( + ['lambda', args, parse_result.ConvertToAnfExpr().value] + ) + + def ParseIfExp(self, if_expr: ast.IfExp): + test_value = self.Parse(if_expr.test) + true_value = self.ParseExprTo0ArgLambda(if_expr.body) + false_value = self.ParseExprTo0ArgLambda(if_expr.orelse) + ret = self.BindToTmpVar( + [ + '__builtin_if__', + test_value.value, + true_value.value, + false_value.value, + ] + ) + return ret + + def ParseBoolOp(self, bool_op: ast.BoolOp): + name = type(bool_op.op).__name__ + method = f"Parse{name}" + return getattr(self, method)(bool_op) + + def ParseOr(self, bool_op: ast.BoolOp): + assert len(bool_op.values) == 2 + test_value = self.Parse(bool_op.values[0]) + true_value = AtomicAnfExpr(['lambda', [], AtomicAnfExpr(True).value]) + false_value = self.ParseExprTo0ArgLambda(bool_op.values[1]) + ret = self.BindToTmpVar( + [ + '__builtin_if__', + test_value.value, + true_value.value, + false_value.value, + ] + ) + return ret + + def ParseAnd(self, bool_op: ast.BoolOp): + assert len(bool_op.values) == 2 + test_value = self.Parse(bool_op.values[0]) + true_value = self.ParseExprTo0ArgLambda(bool_op.values[1]) + false_value = AtomicAnfExpr(['lambda', [], AtomicAnfExpr(False).value]) + ret = self.BindToTmpVar( + [ + '__builtin_if__', + test_value.value, + true_value.value, + false_value.value, + ] + ) + return ret + + def ParseNot(self, unary_op: ast.UnaryOp): + return AtomicAnfExpr('__builtin_not__') + + def ParseExprTo0ArgLambda(self, expr): + return_count_constraint = ReturnCounterConstraint(limits=0) + parser = PyToAnfParser(self.seq_no_counter, return_count_constraint) + parse_result = parser(expr) + return AtomicAnfExpr( + ['lambda', [], parse_result.ConvertToAnfExpr().value] + ) + + def ParseAssert(self, expr: ast.Assert): + test_value = self.Parse(expr.test) + true_value = AtomicAnfExpr(['lambda', [], AtomicAnfExpr(None).value]) + # handle lambda: rase(msg) + return_count_constraint = ReturnCounterConstraint(limits=0) + parser = PyToAnfParser(self.seq_no_counter, return_count_constraint) + if expr.msg is None: + msg = parser.BindToTmpVar(AtomicAnfExpr({"str": ""})) + else: + msg = parser.Parse(expr.msg) + exception = parser.BindToTmpVar(['AssertionError', msg.value]) + raise_ret = parser.BindToTmpVar(['raise', exception.value]) + false_value = AtomicAnfExpr( + [ + 'lambda', + [], + AnfParseResult( + bindings=parser.bindings, body_atomic_anf_expr=raise_ret + ) + .ConvertToAnfExpr() + .value, + ] + ) + ret = self.BindToTmpVar( + [ + '__builtin_if__', + test_value.value, + true_value.value, + false_value.value, + ] + ) + return ret + + def ParseAssign(self, tree): + assert len(tree.targets) == 1 + if isinstance(tree.targets[0], ast.Name): + val = self.Parse(tree.value) + var = tree.targets[0].id + self.Bind(var, val) + return AtomicAnfExpr(var) + elif isinstance(tree.targets[0], ast.Attribute): + val = self.Parse(tree.value) + attr = tree.targets[0] + f = self.BindToTmpVar( + [ + '__builtin_setattr__', + self.Parse(attr.value).value, + {"str": attr.attr}, + ] + ) + return self.BindToTmpVar([f.value, {"str": attr.attr}, val.value]) + elif isinstance(tree.targets[0], ast.Subscript): + val = self.Parse(tree.value) + subscript = tree.targets[0] + slice_val = self.Parse(subscript.slice).value + f = self.BindToTmpVar( + [ + '__builtin_setitem__', + self.Parse(subscript.value).value, + slice_val, + ] + ) + return self.BindToTmpVar([f.value, slice_val, val.value]) + else: + raise NotImplementedError(tree.targets) + + def ParseSubscript(self, tree): + val = self.Parse(tree.value) + slc = self.Parse(tree.slice) + return self.BindToTmpVar(["__builtin_getitem__", val.value, slc.value]) + + def ParseExpr(self, tree): + return self.BindToTmpVar(self.Parse(tree.value)) + + def BindToTmpVar(self, value): + tmp_var = self.get_tmp_var() + self.Bind(tmp_var, value) + return AtomicAnfExpr(tmp_var) + + def GetStmtSizeUntilReturn(self, stmts): + for idx, stmt in enumerate(stmts): + if isinstance(stmt, ast.Return): + return idx + 1 + return len(stmts) + + def ParseReturn(self, tree: ast.Return): + self.return_count_constraint.CountAndCheck() + value = self.Parse(tree.value) + return self.BindToTmpVar(["__builtin_return__", value.value]) + + def ParseStarred(self, tree: ast.Starred): + value = self.Parse(tree.value) + return self.BindToTmpVar(["__builtin_starred__", value.value]) + + def ParseCall(self, tree: ast.Call): + func = self.Parse(tree.func) + assert isinstance(func, AtomicAnfExpr) + + def ParseArg(arg): + parsed_arg = self.Parse(arg) + assert isinstance(parsed_arg, AtomicAnfExpr) + return parsed_arg + + args = [ParseArg(arg).value for arg in tree.args] + kwargs = None + if len(tree.keywords) > 0: + keywords = [ParseArg(arg).value for arg in tree.keywords] + kwargs = self.BindToTmpVar(["__builtin_list__", *keywords]) + if kwargs is None: + if any(isinstance(arg, ast.Starred) for arg in tree.args): + l = self.BindToTmpVar(["__builtin_list__", *args]) + return self.BindToTmpVar( + ["__builtin_apply__", func.value, l.value] + ) + else: + return self.BindToTmpVar([func.value, *args]) + else: + args = self.BindToTmpVar(["__builtin_list__", *args]) + packed_args = self.BindToTmpVar( + ["__builtin_PackedArgs__", args.value, kwargs.value] + ) + return self.BindToTmpVar([func.value, packed_args.value]) + + def ParseList(self, lst: ast.List): + return self._ParseCall('__builtin_list__', lst.elts) + + def _ParseCall(self, func, ast_args): + def ParseArg(arg): + parsed_arg = self.Parse(arg) + assert isinstance(parsed_arg, AtomicAnfExpr) + return parsed_arg + + args = [ParseArg(arg).value for arg in ast_args] + ret_var = self.get_tmp_var() + self.Bind(ret_var, [func, *args]) + return AtomicAnfExpr(ret_var) + + def ParseAttribute(self, attr: ast.Attribute): + ret_var = self.get_tmp_var() + self.Bind( + ret_var, + [ + '__builtin_getattr__', + self.Parse(attr.value).value, + {"str": attr.attr}, + ], + ) + return AtomicAnfExpr(ret_var) + + def ParseName(self, name: ast.Name): + return AtomicAnfExpr(name.id) + + def ParseConstant(self, constant: ast.Constant): + if isinstance(constant.value, str): + return AtomicAnfExpr({"str": constant.value}) + if isinstance(constant.value, (bool, int, float)): + return AtomicAnfExpr(constant.value) + if constant.value is None: + return AtomicAnfExpr(None) + raise NotImplementedError(f"{constant} not supported by anf_expr") + + def ParseJoinedStr(self, tree: ast.JoinedStr): + if len(tree.values) == 0: + return AtomicAnfExpr({"str": ""}) + + def ToString(elt): + parsed_elt = self.Parse(elt) + parsed_elt = self.BindToTmpVar( + ['__builtin_ToString__', parsed_elt.value] + ) + return parsed_elt + + ret = ToString(tree.values[0]) + for elt in tree.values[1:]: + parsed_elt = ToString(elt) + ret = self.BindToTmpVar( + ['__builtin_Add__', ret.value, parsed_elt.value] + ) + return ret + + def ParseFormattedValue(self, tree: ast.FormattedValue): + return self.Parse(tree.value) + + def Bind(self, var_name, anf_expr): + return getattr(self, f"Bind{type(anf_expr).__name__}")( + var_name, anf_expr + ) + + def BindAtomicAnfExpr(self, var_name, anf_expr): + self.bindings.append( + [var_name, ["__builtin_identity__", anf_expr.value]] + ) + + def Bindlist(self, var_name, anf_expr): + self.bindings.append([var_name, anf_expr]) + + def get_tmp_var(self): + return f"___{next(self.seq_no_counter)}" + + +class ReturnCounterConstraint: + def __init__(self, limits): + self.counter = itertools.count() + self.limits = limits + + def CountAndCheck(self): + return_stmt_id = next(self.counter) + assert return_stmt_id < self.limits + + +if __name__ == "__main__": + tree = ast.parse(open(sys.argv[1]).read()) + parser = PyToAnfParser() + parser(tree).ConvertToAnfExpr().DumpToFileAsJson(sys.argv[2]) diff --git a/python/paddle/incubate/cc/data_type_util.py b/python/paddle/incubate/cc/data_type_util.py new file mode 100644 index 00000000000000..77fc2b10d16e80 --- /dev/null +++ b/python/paddle/incubate/cc/data_type_util.py @@ -0,0 +1,40 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_dtype_lower_case_name(dtype): + return _up_case_name2lower_case_name[dtype.name] + + +_up_case_name2lower_case_name = { + "UNDEFINED": "void", + "BOOL": "bool", + "INT8": "int8", + "UINT8": "uint8", + "INT16": "int16", + "UINT16": "uint16", + "INT32": "int32", + "UINT32": "uint32", + "INT64": "int64", + "UINT64": "uint64", + "FLOAT8_E4M3FN": "float8_e4m3fn", + "FLOAT8_E5M2": "float8_e5m2", + "BFLOAT16": "bfloat16", + "FLOAT16": "float16", + "FLOAT32": "float32", + "FLOAT64": "float64", + "COMPLEX64": "complex64", + "COMPLEX128": "complex128", + "PSTRING": "pstring", +} diff --git a/python/paddle/incubate/cc/typing.py b/python/paddle/incubate/cc/typing.py index 6a0f1421d8dd84..8abaf3fef6a292 100644 --- a/python/paddle/incubate/cc/typing.py +++ b/python/paddle/incubate/cc/typing.py @@ -36,6 +36,10 @@ def __init__( self.max = max +# alias +Dim = DimVar + + # Usage: # T = paddle.incubate.cc.typing.DTypeVar("T", "bfloat16", "float32") class DTypeVar: @@ -48,6 +52,10 @@ def __init__(self, name: str, *candidates): self.candidates = candidates +# alias +DType = DTypeVar + + # Usage: # # import paddle.incubate.cc.typing as pct diff --git a/python/setup.py.in b/python/setup.py.in index 4639dc731f6375..866155c3d91e8e 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -828,6 +828,7 @@ packages=['paddle', 'paddle.text.datasets', 'paddle.incubate', 'paddle.incubate.cc', + 'paddle.incubate.cc.ap', 'paddle.incubate.jit', 'paddle.incubate.nn', 'paddle.incubate.nn.functional', diff --git a/setup.py b/setup.py index c6dd56f694b770..af0f8d547faf17 100644 --- a/setup.py +++ b/setup.py @@ -2219,6 +2219,7 @@ def get_setup_parameters(): 'paddle.text.datasets', 'paddle.incubate', 'paddle.incubate.cc', + 'paddle.incubate.cc.ap', 'paddle.incubate.nn', 'paddle.incubate.jit', 'paddle.incubate.nn.functional',