diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index bebad4e0ea34..297e05f69b8c 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -27,10 +27,6 @@ def __init__(self, repo_map: Dict[str, str]): f"{iree_core_repo}//compiler/src/iree/compiler/API:CAPI": [ "IREECompilerCAPILib" ], - # IREE llvm-external-projects - f"{iree_core_repo}//llvm-external-projects/iree-dialects:CAPI": [ - "IREEDialectsCAPI" - ], # Disable all hard-coded codegen targets (they are expanded dynamically # in CMake). "@llvm-project//llvm:AArch64AsmParser": ["IREELLVMCPUTargetDeps"], diff --git a/compiler/plugins/target/LLVMCPU/BUILD.bazel b/compiler/plugins/target/LLVMCPU/BUILD.bazel index 2b931e7df565..8e25d6be7a3a 100644 --- a/compiler/plugins/target/LLVMCPU/BUILD.bazel +++ b/compiler/plugins/target/LLVMCPU/BUILD.bazel @@ -47,7 +47,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:AArch64AsmParser", "@llvm-project//llvm:AArch64CodeGen", "@llvm-project//llvm:ARMAsmParser", diff --git a/compiler/plugins/target/LLVMCPU/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/CMakeLists.txt index b0cb849fa713..a70f61efd8a9 100644 --- a/compiler/plugins/target/LLVMCPU/CMakeLists.txt +++ b/compiler/plugins/target/LLVMCPU/CMakeLists.txt @@ -31,7 +31,6 @@ iree_cc_library( ::LinkerTool ::StaticLibraryGenerator IREELLVMCPUTargetDeps - IREELinalgTransformDialect LLVMAnalysis LLVMBitReader LLVMBitWriter diff --git a/compiler/src/iree/compiler/API/BUILD.bazel b/compiler/src/iree/compiler/API/BUILD.bazel index 9ca3ea751001..f10c9226d0ed 100644 --- a/compiler/src/iree/compiler/API/BUILD.bazel +++ b/compiler/src/iree/compiler/API/BUILD.bazel @@ -37,7 +37,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/API/Internal:IREEOptToolEntryPoint", "//compiler/src/iree/compiler/API/Internal:IREEReduceToolEntryPoint", "//compiler/src/iree/compiler/API/Internal:LLDToolEntryPoint", - "//llvm-external-projects/iree-dialects:CAPI", "@llvm-project//mlir:CAPIAMDGPU", "@llvm-project//mlir:CAPIDebug", "@llvm-project//mlir:CAPIGPU", diff --git a/compiler/src/iree/compiler/API/CMakeLists.txt b/compiler/src/iree/compiler/API/CMakeLists.txt index 8d76da8211cb..d0a0a697a71f 100644 --- a/compiler/src/iree/compiler/API/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/CMakeLists.txt @@ -14,7 +14,6 @@ iree_cc_library( NAME StaticImpl DEPS - IREEDialectsCAPI MLIRCAPIAMDGPU MLIRCAPIDebug MLIRCAPIExportSMTLIB @@ -82,7 +81,6 @@ set(_EXPORT_OBJECT_LIBS iree_compiler_API_Internal_IREEOptToolEntryPoint.objects iree_compiler_API_Internal_IREEReduceToolEntryPoint.objects iree_compiler_API_Internal_LLDToolEntryPoint.objects - obj.IREEDialectsCAPI obj.MLIRCAPIAMDGPU obj.MLIRCAPIDebug obj.MLIRCAPIExportSMTLIB diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index 31a69d22d4a9..731a366c9faf 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -154,7 +154,6 @@ extern void ireeLinkRunMain(); extern void ireeMlirLspServerRunMain(); extern void ireeOptRunMain(); extern void ireeReduceRunMain(); -extern void ireeRegisterTransformExtensions(); extern void mlirAffineAddExprGet(); extern void mlirAffineBinaryOpExprGetLHS(); extern void mlirAffineBinaryOpExprGetRHS(); @@ -1330,7 +1329,6 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&ireeMlirLspServerRunMain; x += (uintptr_t)&ireeOptRunMain; x += (uintptr_t)&ireeReduceRunMain; - x += (uintptr_t)&ireeRegisterTransformExtensions; x += (uintptr_t)&mlirAffineAddExprGet; x += (uintptr_t)&mlirAffineBinaryOpExprGetLHS; x += (uintptr_t)&mlirAffineBinaryOpExprGetRHS; diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index c62f0082631f..a1eb623ceae8 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -144,7 +144,6 @@ EXPORTS ireeMlirLspServerRunMain ireeOptRunMain ireeReduceRunMain - ireeRegisterTransformExtensions mlirAffineAddExprGet mlirAffineBinaryOpExprGetLHS mlirAffineBinaryOpExprGetRHS diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index 902f3c9c7986..8916a9d471b4 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld +++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -145,7 +145,6 @@ VER_0 { ireeMlirLspServerRunMain; ireeOptRunMain; ireeReduceRunMain; - ireeRegisterTransformExtensions; mlirAffineAddExprGet; mlirAffineBinaryOpExprGetLHS; mlirAffineBinaryOpExprGetRHS; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index fc7d0fa15962..3daa67377cac 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -143,7 +143,6 @@ _ireeLinkRunMain _ireeMlirLspServerRunMain _ireeOptRunMain _ireeReduceRunMain -_ireeRegisterTransformExtensions _mlirAffineAddExprGet _mlirAffineBinaryOpExprGetLHS _mlirAffineBinaryOpExprGetRHS diff --git a/compiler/src/iree/compiler/API/generate_exports.py b/compiler/src/iree/compiler/API/generate_exports.py index 31b6bd0e741a..05d79fef99e3 100755 --- a/compiler/src/iree/compiler/API/generate_exports.py +++ b/compiler/src/iree/compiler/API/generate_exports.py @@ -67,10 +67,6 @@ "Dialect/PDL.h", ] -IREE_DIALECTS_HEADER_FILES = [ - "Dialects.h", -] - IREE_COMPILER_DIALECTS_HEADER_FILES = [ "iree_codegen.h", "iree_gpu.h", @@ -101,16 +97,6 @@ def main(repo_root: Path, api_root: Path): for local_name in LOCAL_HEADER_FILES: export_symbols.extend(collect_header_exports(api_root / local_name)) - # Collect symbols from iree-dialects header files. - for local_name in IREE_DIALECTS_HEADER_FILES: - export_symbols.extend( - collect_header_exports( - repo_root - / "llvm-external-projects/iree-dialects/include/iree-dialects-c" - / local_name - ) - ) - # Collect symbols from iree compiler dialect header files. for local_name in IREE_COMPILER_DIALECTS_HEADER_FILES: export_symbols.extend( diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 81167e58463b..fa6dd7e8158d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -103,6 +103,7 @@ iree_compiler_cc_library( "EncodingUtils.cpp", "EraseDeadAllocAndStores.cpp", "EraseHALDescriptorTypeFromMemRef.cpp", + "ErrorCheckingTrackingListener.cpp", "FastMathPatterns.cpp", "FissionTransferOpsInControlFlow.cpp", "FlattenMemRefSubspan.cpp", @@ -190,6 +191,7 @@ iree_compiler_cc_library( "CombineLayoutTransformation.h", "EmulateNarrowType.h", "EncodingUtils.h", + "ErrorCheckingTrackingListener.h", "FastMathPatterns.h", "PassUtils.h", "Passes.h", @@ -234,7 +236,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", @@ -332,7 +333,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", @@ -377,7 +377,6 @@ iree_compiler_cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:DialectUtils", # TransformExtensions (needed for registration in the pass) - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", "//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions", "//compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions:LLVMCPUExtensions", "//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index a3f5d46293dd..0b38a523d4c4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -56,6 +56,7 @@ iree_cc_library( "CombineLayoutTransformation.h" "EmulateNarrowType.h" "EncodingUtils.h" + "ErrorCheckingTrackingListener.h" "FastMathPatterns.h" "PassUtils.h" "Passes.h" @@ -96,6 +97,7 @@ iree_cc_library( "EncodingUtils.cpp" "EraseDeadAllocAndStores.cpp" "EraseHALDescriptorTypeFromMemRef.cpp" + "ErrorCheckingTrackingListener.cpp" "FastMathPatterns.cpp" "FissionTransferOpsInControlFlow.cpp" "FlattenMemRefSubspan.cpp" @@ -180,7 +182,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREELinalgTransformDialect LLVMCore LLVMSupport MLIRAMDGPUDialect @@ -287,8 +288,6 @@ iree_cc_library( ::Common ::PassHeaders ::PassesIncGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAMDGPUTransforms MLIRAffineDialect diff --git a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp index d7187490c02c..69733a6d8b3b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" @@ -97,8 +96,7 @@ void registerTransformDialectTranslationDependentDialects( vector::registerBufferizableOpInterfaceExternalModels(registry); registry.addExtensions< - mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension, - transform_ext::StructuredTransformOpsExtension>(); + mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension>(); iree_compiler::registerTransformDialectCommonExtension(registry); iree_compiler::registerTransformDialectLLVMCPUExtension(registry); iree_compiler::registerTransformDialectLLVMGPUExtension(registry); diff --git a/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp new file mode 100644 index 000000000000..91285ec78cd3 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp @@ -0,0 +1,40 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" + +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#define DEBUG_TYPE "iree-codegen-error-checking-tracking-listener" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir::iree_compiler { + +void ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( + Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { + // Certain ops can dropped safely. + if (isa(op)) { + LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n"); + return; + } + + SmallVector diags; + diag.takeDiagnostics(diags); + if (!status.succeeded()) { + status.takeDiagnostics(diags); + } + status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); + + status = emitSilenceableFailure( + getTransformOp(), "!!! tracking listener failed to find replacement op"); + status.attachNote(op->getLoc()) << "replaced op"; + for (Value v : values) { + status.attachNote(v.getLoc()) << "replacement value"; + } +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h new file mode 100644 index 000000000000..35ea97d88136 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h @@ -0,0 +1,48 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_COMMON_ERRORCHECKINGTRACKINGLISTENER_H_ +#define IREE_COMPILER_CODEGEN_COMMON_ERRORCHECKINGTRACKINGLISTENER_H_ + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +namespace mlir::iree_compiler { + +/// A tracking listener for tensor IR that checks for payload replacement +/// errors. +class ErrorCheckingTrackingListener : public transform::TrackingListener { +public: + using transform::TrackingListener::TrackingListener; + + ~ErrorCheckingTrackingListener() override { + assert(status.succeeded() && "must check listener error state"); + } + + /// Return "true" if this tracking listener had a failure. + bool failed() const { return !status.succeeded(); } + + /// Check and return the current error state of this listener. In case of a + /// failure state, only the most recent error is returned. Afterwards, resets + /// the error state. + DiagnosedSilenceableFailure checkAndResetError() { + DiagnosedSilenceableFailure result(std::move(status)); + status = DiagnosedSilenceableFailure::success(); + return result; + } + +private: + void + notifyPayloadReplacementNotFound(Operation *op, ValueRange values, + DiagnosedSilenceableFailure &&diag) override; + + /// The error state of this listener. "Success" indicates that no error + /// happened so far. Otherwise, the status contains the most recent error. + DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); +}; + +} // namespace mlir::iree_compiler + +#endif // IREE_COMPILER_CODEGEN_COMMON_ERRORCHECKINGTRACKINGLISTENER_H_ diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel index 296f03528d75..6601bda95e5e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel @@ -75,8 +75,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt index 2f91228d2fde..a98181a05315 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "CommonExtensionsOps.cpp.inc" DEPS ::CommonExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAffineDialect MLIRAffineUtils diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index b630323b05b1..8113593b4b6e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -6,8 +6,7 @@ #include "CommonExtensions.h" -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" -#include "iree-dialects/Transforms/TransformMatchers.h" +#include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" @@ -48,6 +47,7 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel index bcba86ba95cd..7ff9311888eb 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel @@ -61,7 +61,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces:ExternalModels", "//compiler/src/iree/compiler/Codegen/ExternalInterfaces:ExternalModels", "//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:SideEffectInterfaces", diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt index c22bc95db95d..646bba0bd068 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt @@ -25,7 +25,6 @@ iree_cc_library( ::TensorMaskingOpInterface ::UKernelOpInterface ::VectorizableOpInterface - IREELinalgTransformDialect MLIRAMDGPUDialect MLIRAffineDialect MLIRAffineTransformOps diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp index e21cde562ec4..69483e8aa7f0 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Interfaces/Interfaces.h" +#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/Interfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h" #include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/Interfaces.h" @@ -15,14 +16,8 @@ #include "iree/compiler/Codegen/Interfaces/HoistableRegionOpInterface.h" #include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h" #include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h" -#include "iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -// TODO: Remove this dependency once the transform dialect extensions -// have a better registration mechanism. -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" -#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Interfaces/TensorMaskingOpInterface.h" +#include "iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h" #include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h" #include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h" #include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.h" @@ -33,6 +28,7 @@ #include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" #include "mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" @@ -50,6 +46,7 @@ #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h" #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir::iree_compiler { @@ -90,11 +87,8 @@ void registerCodegenInterfaces(DialectRegistry ®istry) { registerPCFExternalInterfaces(registry); registerBufferizationInterfaces(registry); registerTensorMaskingOpInterface(registry); - // TODO: Remove this dependency once the transform dialect extensions - // have a better registration mechanism. // TODO: when warranted, move to its own file. - registry.addExtensions(); + registry.addExtensions(); registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { linalg::GenericOp::attachInterface(*ctx); }); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 2e82813eb75b..10537c9023df 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -117,7 +117,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "//runtime/src/iree/schemas:cpu_data", "//runtime/src/iree/schemas/instruments", "@llvm-project//llvm:BinaryFormat", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 21ee96022063..b49e7dbce7a7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -82,7 +82,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREELinalgTransformDialect LLVMBinaryFormat LLVMSupport LLVMTargetParser diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel index 522ae9359b0e..367ec8a8ea34 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel @@ -58,8 +58,6 @@ iree_compiler_cc_library( ], deps = [ ":LLVMCPUExtensionsOpGen", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:TransformDialect", ], diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt index fa06b1419552..10f5eae8ce47 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "LLVMCPUExtensionsOps.cpp.inc" DEPS ::LLVMCPUExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect MLIRIR MLIRTransformDialect PUBLIC diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index e8e27a7cb247..dc6a9d4f8fa8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -193,7 +193,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AMDGPUToROCDL", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 8897d7253611..a38916616b89 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -129,7 +129,6 @@ iree_cc_library( ::PassesIncGen ::ROCDLPassHeaders ::ROCDLPassesIncGen - IREELinalgTransformDialect LLVMSupport MLIRAMDGPUDialect MLIRAMDGPUToROCDL diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel index df19585b67ab..3ee6e782ff41 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel @@ -66,8 +66,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms", "//compiler/src/iree/compiler/Codegen/LLVMGPU/Utils", "//compiler/src/iree/compiler/Codegen/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AffineDialect", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt index e7ff813e374e..a98db3845cf4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "LLVMGPUExtensionsOps.cpp.inc" DEPS ::LLVMGPUExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAMDGPUDialect MLIRAffineDialect diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 36a59eca50a7..04387136a2bc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -6,7 +6,7 @@ #include "LLVMGPUExtensions.h" -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" @@ -29,6 +29,7 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index 3f0cb389343a..201158fcfff9 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -110,7 +110,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:AffineDialect", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 3a72e9982213..ec2e172ef876 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -79,7 +79,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREELinalgTransformDialect LLVMSupport MLIRAffineAnalysis MLIRAffineDialect diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel index b528e6b198d2..98eb5921c6c9 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel @@ -61,8 +61,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt index e698587c89b7..66fe9c2cb575 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "FlowExtensionsOps.cpp.inc" DEPS ::FlowExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAnalysis MLIRArithDialect diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel index 3b367579d3ec..da95c44b97ee 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel @@ -84,8 +84,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 136de136cb2b..69c8b8185038 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -53,8 +53,6 @@ iree_cc_library( "VerifyInputLegality.cpp" DEPS ::PassesIncGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAffineDialect MLIRAffineUtils diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel index 71ccf89cd779..11a03e6a7151 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel @@ -63,8 +63,6 @@ iree_compiler_cc_library( ":LinalgExtExtensionsOpGen", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt index a634d0e869b5..0e2dd24b5cc9 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "LinalgExtExtensionsOps.cpp.inc" DEPS ::LinalgExtExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRIR MLIRLinalgDialect diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index 37f58d4280a1..03dd920131e8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -103,8 +103,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms", "//compiler/src/iree/compiler/Pipelines:Options", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 0b5cfb188450..050b3c950fa7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -67,8 +67,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAffineDialect MLIRArithDialect diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index b4e177eddfe4..4077d7604254 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects/Transforms/TransformMatchers.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" @@ -19,6 +18,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinAttributes.h" @@ -28,7 +28,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; -using transform_ext::StructuredOpMatcher; namespace mlir::iree_compiler::GlobalOptimization { @@ -374,6 +373,293 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { // Softmax Raising //===----------------------------------------------------------------------===// +// Recognizing a numerically-stabilized softmax means walking a small graph of +// `linalg.generic` ops. The helpers below match the individual nodes of that +// graph directly, replacing the generic `transform_ext` matcher framework that +// used to live in iree-dialects. The matched dataflow, working back from the +// root op, is: +// +// max = reduce_max(src) (reduction over the innermost dim) +// sub = src - broadcast(max) +// exp = exp(sub) +// sum = reduce_add(exp) (reduction over the innermost dim) +// root = exp * reciprocal(sum) OR exp / broadcast(sum) +// +// where each broadcast is either implicit (the consumer indexing map drops the +// innermost dim) or explicit (a pass-through `linalg.generic` that broadcasts +// along the innermost dim). + +// Returns true if every iterator of `op` is parallel. +static bool isAllParallel(linalg::GenericOp op) { + return llvm::all_of(op.getIteratorTypesArray(), [](utils::IteratorType t) { + return t == utils::IteratorType::parallel; + }); +} + +// Returns true if `operand`'s indexing map is the identity. +static bool isIdentityOperand(linalg::GenericOp op, OpOperand *operand) { + return op.getMatchingIndexingMap(operand).isIdentity(); +} + +// Returns true if `operand`'s indexing map is the projection that drops the +// innermost (last) loop dim, e.g. (d0, d1, d2) -> (d0, d1). +static bool dropsInnermostDim(linalg::GenericOp op, OpOperand *operand) { + AffineMap map = op.getMatchingIndexingMap(operand); + unsigned numLoops = op.getNumLoops(); + if (!map.isProjectedPermutation() || map.getNumResults() + 1 != numLoops) { + return false; + } + for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { + if (map.getDimPosition(i) != i) { + return false; + } + } + return true; +} + +// Returns the single compute op of `op`'s body when the body is exactly that op +// followed by a `linalg.yield` of its single result; otherwise nullptr. +static Operation *getSingleComputeOp(linalg::GenericOp op) { + Block *body = op.getBlock(); + if (body->getOperations().size() != 2) { + return nullptr; + } + Operation *computeOp = &body->front(); + auto yieldOp = cast(body->getTerminator()); + if (computeOp->getNumResults() != 1 || yieldOp.getNumOperands() != 1 || + yieldOp.getOperand(0).getDefiningOp() != computeOp) { + return nullptr; + } + return computeOp; +} + +// Returns true if `v` references block argument number `argNumber` of `block`. +static bool isBlockArg(Value v, Block *block, unsigned argNumber) { + auto arg = dyn_cast(v); + return arg && arg.getOwner() == block && arg.getArgNumber() == argNumber; +} + +// Returns true if `op`'s body is a single `OpTy` whose operands are the leading +// block arguments in order. When `commutative`, a two-operand body may also use +// the arguments in swapped order. +template +static bool hasSingleOpBody(linalg::GenericOp op, bool commutative = false) { + Operation *computeOp = getSingleComputeOp(op); + if (!isa_and_nonnull(computeOp)) { + return false; + } + Block *body = op.getBlock(); + if (commutative && computeOp->getNumOperands() == 2) { + return (isBlockArg(computeOp->getOperand(0), body, 0) && + isBlockArg(computeOp->getOperand(1), body, 1)) || + (isBlockArg(computeOp->getOperand(0), body, 1) && + isBlockArg(computeOp->getOperand(1), body, 0)); + } + for (auto [index, operand] : llvm::enumerate(computeOp->getOperands())) { + if (!isBlockArg(operand, body, index)) { + return false; + } + } + return true; +} + +// Returns true if `op`'s body just yields its (single) block argument, i.e. it +// is a pure broadcast/transpose of the input. +static bool isPassThroughBody(linalg::GenericOp op) { + Block *body = op.getBlock(); + if (body->getOperations().size() != 1) { + return false; + } + auto yieldOp = cast(body->getTerminator()); + for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) { + if (!isBlockArg(operand, body, index)) { + return false; + } + } + return true; +} + +// Returns true if `v` is a float constant satisfying `pred`. +static bool isFloatConstant(Value v, llvm::function_ref pred) { + auto cstOp = v.getDefiningOp(); + return cstOp && pred(cstOp.value()); +} + +// Returns true if `v` is a `linalg.fill` whose scalar value is a float constant +// satisfying `pred`. +static bool isFilledWith(Value v, llvm::function_ref pred) { + auto fillOp = v.getDefiningOp(); + return fillOp && isFloatConstant(fillOp.getInputs()[0], pred); +} + +// Matches a single-input `linalg.generic` reduction over the innermost dim with +// a commutative `OpTy` body and an init produced by a `linalg.fill` of a +// constant satisfying `fillPred`. Returns the reduced value on success. +template +static Value +matchInnermostReduction(Value v, llvm::function_ref fillPred) { + auto op = v.getDefiningOp(); + if (!op || op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1) { + return {}; + } + SmallVector iterators = op.getIteratorTypesArray(); + if (iterators.empty() || iterators.back() != utils::IteratorType::reduction) { + return {}; + } + for (unsigned i = 0, e = iterators.size() - 1; i < e; ++i) { + if (iterators[i] != utils::IteratorType::parallel) { + return {}; + } + } + OpOperand *input = op.getDpsInputOperand(0); + OpOperand *init = op.getDpsInitOperand(0); + if (!isIdentityOperand(op, input) || !dropsInnermostDim(op, init) || + !hasSingleOpBody(op, /*commutative=*/true) || + !isFilledWith(init->get(), fillPred)) { + return {}; + } + return input->get(); +} + +// Matches a single-input all-parallel `linalg.generic` with identity-mapped +// input and output and an `OpTy` body. Returns the input value on success. +template +static Value matchUnaryElementwise(Value v) { + auto op = v.getDefiningOp(); + if (!op || op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1 || + !isAllParallel(op)) { + return {}; + } + if (!isIdentityOperand(op, op.getDpsInputOperand(0)) || + !isIdentityOperand(op, op.getDpsInitOperand(0)) || + !hasSingleOpBody(op)) { + return {}; + } + return op.getDpsInputOperand(0)->get(); +} + +// Matches a reciprocal `linalg.generic` (body `divf 1.0, %arg0`) that is +// single-input, all-parallel and identity-mapped. Returns the input value. +static Value matchReciprocal(Value v) { + auto op = v.getDefiningOp(); + if (!op || op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1 || + !isAllParallel(op)) { + return {}; + } + if (!isIdentityOperand(op, op.getDpsInputOperand(0)) || + !isIdentityOperand(op, op.getDpsInitOperand(0))) { + return {}; + } + auto divOp = dyn_cast_or_null(getSingleComputeOp(op)); + if (!divOp || + !isFloatConstant(divOp.getOperand(0), + [](APFloat f) { return f.convertToDouble() == 1.0; }) || + !isBlockArg(divOp.getOperand(1), op.getBlock(), 0)) { + return {}; + } + return op.getDpsInputOperand(0)->get(); +} + +// Resolves `operand` of `consumer`, which is expected to carry the broadcast of +// an innermost-reduced value, to that pre-broadcast value. Handles both the +// implicit broadcast (the operand's map drops the innermost dim) and the +// explicit broadcast (the operand is produced by a pass-through generic that +// broadcasts along the innermost dim). +static Value resolveBroadcastedReduction(linalg::GenericOp consumer, + OpOperand *operand) { + if (dropsInnermostDim(consumer, operand)) { + return operand->get(); + } + if (!isIdentityOperand(consumer, operand)) { + return {}; + } + auto bcast = operand->get().getDefiningOp(); + if (!bcast || bcast.getNumDpsInputs() != 1 || bcast.getNumDpsInits() != 1 || + !isAllParallel(bcast) || !isPassThroughBody(bcast) || + !dropsInnermostDim(bcast, bcast.getDpsInputOperand(0)) || + !isIdentityOperand(bcast, bcast.getDpsInitOperand(0))) { + return {}; + } + return bcast.getDpsInputOperand(0)->get(); +} + +// Recognizes the softmax graph rooted at `rootOp` and returns the softmax +// source value on success. +static FailureOr matchSoftmax(linalg::LinalgOp rootOp) { + auto root = dyn_cast(rootOp.getOperation()); + if (!root || root.getNumDpsInputs() != 2 || root.getNumDpsInits() != 1 || + !isAllParallel(root)) { + return failure(); + } + // The numerator (exp(...)) always flows in as input 0 with an identity map. + OpOperand *numerator = root.getDpsInputOperand(0); + OpOperand *denominator = root.getDpsInputOperand(1); + if (!isIdentityOperand(root, numerator)) { + return failure(); + } + + // The root normalizes either as `exp * reciprocal(sum)` or `exp / sum`. + Value sumValue; + if (hasSingleOpBody(root, /*commutative=*/true)) { + if (!dropsInnermostDim(root, denominator)) { + return failure(); + } + sumValue = matchReciprocal(denominator->get()); + } else if (hasSingleOpBody(root)) { + sumValue = resolveBroadcastedReduction(root, denominator); + } + if (!sumValue) { + return failure(); + } + + // exp = exp(sub); sum = reduce_add(exp). + Value expValue = numerator->get(); + Value subValue = matchUnaryElementwise(expValue); + Value summedValue = matchInnermostReduction( + sumValue, [](APFloat f) { return f.isZero(); }); + if (!subValue || summedValue != expValue) { + return failure(); + } + + // sub = src - broadcast(max). + auto subOp = subValue.getDefiningOp(); + if (!subOp || subOp.getNumDpsInputs() != 2 || subOp.getNumDpsInits() != 1 || + !isAllParallel(subOp) || + !isIdentityOperand(subOp, subOp.getDpsInitOperand(0)) || + !hasSingleOpBody(subOp)) { + return failure(); + } + OpOperand *subSource = subOp.getDpsInputOperand(0); + if (!isIdentityOperand(subOp, subSource)) { + return failure(); + } + Value maxValue = + resolveBroadcastedReduction(subOp, subOp.getDpsInputOperand(1)); + if (!maxValue) { + return failure(); + } + + // max = reduce_max(src), reducing the same source the subtraction reads. The + // init must be -inf or the lowest finite value. Accept both arith.maximumf + // (NaN-propagating, as emitted by e.g. StableHLO frontends) and arith.maxnumf + // (NaN-ignoring, which is what linalg.softmax itself decomposes to), since + // both denote the stabilizing max of a softmax. + auto isNegInfOrLowest = [](APFloat f) { + return (f.isLargest() || f.isInfinity()) && f.isNegative(); + }; + Value source = subSource->get(); + Value reducedValue = + matchInnermostReduction(maxValue, isNegInfOrLowest); + if (!reducedValue) { + reducedValue = + matchInnermostReduction(maxValue, isNegInfOrLowest); + } + if (reducedValue != source) { + return failure(); + } + return source; +} + class RaiseSoftmax : public OpInterfaceRewritePattern { public: using OpInterfaceRewritePattern::OpInterfaceRewritePattern; @@ -386,19 +672,15 @@ class RaiseSoftmax : public OpInterfaceRewritePattern { return failure(); } - transform_ext::MatcherContext matcherContext; - transform_ext::StructuredOpMatcher *maxReduction; - transform_ext::StructuredOpMatcher *softmaxroot; - makeSoftmaxMatcher(matcherContext, maxReduction, softmaxroot); - if (!matchPattern(linalgOp, *softmaxroot)) { + FailureOr src = matchSoftmax(linalgOp); + if (failed(src)) { return rewriter.notifyMatchFailure(linalgOp, "failed to match softmax root"); } - Value src = maxReduction->getCaptured()->getOperand(0); rewriter.setInsertionPoint(linalgOp); rewriter.replaceOpWithNewOp( - linalgOp, linalgOp->getResultTypes(), src, + linalgOp, linalgOp->getResultTypes(), *src, linalgOp.getDpsInitOperand(0)->get(), linalgOp.getNumLoops() - 1); return success(); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index fc1b4e23e8a2..877b169c70c5 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -168,6 +168,154 @@ util.func public @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x // ----- +// The stabilizing max may use arith.maxnumf (NaN-ignoring). +// CHECK-LABEL: @softmax_maxnumf +// CHECK-SAME: %[[ARG:.+]]: tensor<2x4xf32> +// CHECK: %[[S:.+]] = linalg.softmax dimension(1) ins(%[[ARG]] : tensor<2x4xf32>) +// CHECK: util.return %[[S]] +util.func public @softmax_maxnumf(%src : tensor<2x4xf32>) -> (tensor<2x4xf32>) { + %cst0 = arith.constant 0.000000e+00 : f32 + %cstlow = arith.constant -3.40282347E+38 : f32 + %e1 = tensor.empty() : tensor<2xf32> + %e2 = tensor.empty() : tensor<2x4xf32> + %fillmax = linalg.fill ins(%cstlow : f32) outs(%e1 : tensor<2xf32>) -> tensor<2xf32> + %max = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%src : tensor<2x4xf32>) outs(%fillmax : tensor<2xf32>) { + ^bb0(%a: f32, %b: f32): + %m = arith.maxnumf %a, %b : f32 + linalg.yield %m : f32 + } -> tensor<2xf32> + %sub = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%src, %max : tensor<2x4xf32>, tensor<2xf32>) outs(%e2 : tensor<2x4xf32>) { + ^bb0(%a: f32, %b: f32, %c: f32): + %s = arith.subf %a, %b : f32 + linalg.yield %s : f32 + } -> tensor<2x4xf32> + %exp = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%sub : tensor<2x4xf32>) outs(%e2 : tensor<2x4xf32>) { + ^bb0(%a: f32, %b: f32): + %e = math.exp %a : f32 + linalg.yield %e : f32 + } -> tensor<2x4xf32> + %fillsum = linalg.fill ins(%cst0 : f32) outs(%e1 : tensor<2xf32>) -> tensor<2xf32> + %sum = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%exp : tensor<2x4xf32>) outs(%fillsum : tensor<2xf32>) { + ^bb0(%a: f32, %b: f32): + %s = arith.addf %a, %b : f32 + linalg.yield %s : f32 + } -> tensor<2xf32> + %div = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%exp, %sum : tensor<2x4xf32>, tensor<2xf32>) outs(%e2 : tensor<2x4xf32>) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.divf %a, %b : f32 + linalg.yield %d : f32 + } -> tensor<2x4xf32> + util.return %div : tensor<2x4xf32> +} + +// ----- + +// Negative test: max init is 0.0, not -inf/lowest. +// CHECK-LABEL: @not_softmax_wrong_max_init +// CHECK-NOT: linalg.softmax +util.func public @not_softmax_wrong_max_init(%src : tensor) -> (tensor) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %c_0_index = arith.constant 0 : index + %c_1_index = arith.constant 1 : index + %c_2_index = arith.constant 2 : index + %dim_0 = tensor.dim %src, %c_0_index : tensor + %dim_1 = tensor.dim %src, %c_1_index : tensor + %dim_2 = tensor.dim %src, %c_2_index : tensor + %1 = tensor.empty(%dim_0, %dim_1) : tensor + // Wrong init: 0.0 rather than -inf / lowest. + %2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor) -> tensor + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor) outs(%2 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.maximumf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %4 = tensor.empty(%dim_0, %dim_1, %dim_2) : tensor + %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%src, %3 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.subf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = math.exp %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor) -> tensor + %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor) outs(%7 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.addf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor) outs(%1 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.divf %cst, %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.mulf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + util.return %10 : tensor +} + +// ----- + +// Negative test: max and subtraction read different sources. +// CHECK-LABEL: @not_softmax_mismatched_source +// CHECK-NOT: linalg.softmax +util.func public @not_softmax_mismatched_source(%src : tensor, %other : tensor) -> (tensor) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant -3.40282347E+38 : f32 + %c_0_index = arith.constant 0 : index + %c_1_index = arith.constant 1 : index + %c_2_index = arith.constant 2 : index + %dim_0 = tensor.dim %src, %c_0_index : tensor + %dim_1 = tensor.dim %src, %c_1_index : tensor + %dim_2 = tensor.dim %src, %c_2_index : tensor + %1 = tensor.empty(%dim_0, %dim_1) : tensor + %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor) -> tensor + // Reduces %src ... + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor) outs(%2 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.maximumf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %4 = tensor.empty(%dim_0, %dim_1, %dim_2) : tensor + // ... but subtracts from %other. + %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%other, %3 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.subf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = math.exp %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor) -> tensor + %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor) outs(%7 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.addf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor) outs(%1 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.divf %cst, %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.mulf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + util.return %10 : tensor +} + +// ----- + util.func public @aTransposeBMatmul(%arg0 : tensor<10x20xf32>, %arg1 : tensor<40x20xf32>) -> tensor<10x40xf32> { %0 = tensor.empty() : tensor<20x40xf32> diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel index 5ddf952d8fe9..7982c8feaf87 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel @@ -60,8 +60,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgInterfaces", diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt index 21977dd066e9..633f4e874ab1 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "PreprocessingExtensionsOps.cpp.inc" DEPS ::PreprocessingExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRIR MLIRLinalgInterfacesIncGenLib diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 3ddd96e03d39..5a8acd213473 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -76,7 +76,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Preprocessing:Passes", "//compiler/src/iree/compiler/Preprocessing/TransformExtensions:PreprocessingExtensions", "//compiler/src/iree/compiler/Transforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//mlir:IR", ], ) diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index b8be6ff36a3d..de2f6848f95c 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -24,7 +24,6 @@ iree_cc_library( "init_iree_dialects.h" "init_iree_passes.h" DEPS - IREELinalgTransformDialect MLIRIR iree::compiler::Bindings::Native::Transforms iree::compiler::Bindings::TFLite::Transforms diff --git a/docs/website/generate_extra_files.sh b/docs/website/generate_extra_files.sh index 25be565b4fb1..c1014c69724b 100755 --- a/docs/website/generate_extra_files.sh +++ b/docs/website/generate_extra_files.sh @@ -52,7 +52,6 @@ cp -r "${BUILD_PASSES_ORIGINAL_DIR}/." "${BUILD_PASSES_PROCESSED_DIR}" # Delete any dialect docs we don't want to publish (yet?). rm "${BUILD_DIALECTS_PROCESSED_DIR}/SimpleIODialect.md" # Sample dialect, just ignore -rm "${BUILD_DIALECTS_PROCESSED_DIR}/StructuredTransformOpsExt.md" # Dialect extensions # Trim "Dialect"/"Passes" suffix from file names e.g. FlowDialect.md -> Flow.md. for f in ${BUILD_DIALECTS_PROCESSED_DIR}/*Dialect.md; do diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel index c2dce5810f34..811772bff3fe 100644 --- a/llvm-external-projects/iree-dialects/BUILD.bazel +++ b/llvm-external-projects/iree-dialects/BUILD.bazel @@ -1,5 +1,4 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") +load("@rules_cc//cc:defs.bzl", "cc_binary") package( default_visibility = ["//visibility:public"], @@ -25,212 +24,6 @@ filegroup( ), ) -################################################################################ -# Tablegen exports -################################################################################ - -td_library( - name = "TdFiles", - srcs = glob( - [ - "include/iree-dialects/Dialect/Input/*.td", - "include/iree-dialects/Dialect/LinalgTransform/*.td", - "python/iree/compiler/dialects/*.td", - ], - allow_empty = True, - ), - includes = ["include"], - deps = [ - "@llvm-project//mlir:BuiltinDialectTdFiles", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:PDLDialectTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - "@llvm-project//mlir:TransformDialectTdFiles", - ], -) - -################################################################################ -# IREELinalgTransform Dialect -################################################################################ - -cc_library( - name = "IREEDialectsTransforms", - srcs = glob( - [ - "lib/Transforms/*.cpp", - ], - allow_empty = True, - ), - hdrs = glob( - [ - "include/iree-dialects/Transforms/*.h", - ], - allow_empty = True, - ), - includes = ["include"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FunctionInterfaces", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgInterfaces", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:TransformDialectInterfaces", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - ], -) - -gentbl_cc_library( - name = "IREELinalgTransformStructuredIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["--gen-op-decls"], - "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc", - ), - ( - ["--gen-op-defs"], - "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td", - deps = [ - ":TdFiles", - ], -) - -cc_library( - name = "IREELinalgTransformDialect", - srcs = glob( - [ - "lib/Dialect/LinalgTransform/IR/*.cpp", - "lib/Dialect/LinalgTransform/IR/*.h", - ], - allow_empty = True, - ), - hdrs = glob( - [ - "include/iree-dialects/Dialect/LinalgTransform/*.h", - ], - allow_empty = True, - ), - includes = ["include"], - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgTransformStructuredIncGen", - "@llvm-project//llvm:Support", - - # Dialects - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FunctionInterfaces", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:PDLDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TilingInterface", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:TransformDialectInterfaces", - "@llvm-project//mlir:TransformPDLExtension", - - # IR - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - - # Interfaces - "@llvm-project//mlir:ControlFlowInterfaces", - - # Transforms - "@llvm-project//mlir:AffineToStandard", - "@llvm-project//mlir:AsyncTransforms", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:TensorTransformOps", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:VectorToSCF", - - # Utils - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:DialectUtils", - - # Conversions - "@llvm-project//mlir:AsyncToLLVM", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:IndexToLLVM", - "@llvm-project//mlir:LinalgToStandard", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:VectorToLLVM", - ], -) - -################################################################################ -# CAPI -################################################################################ - -cc_library( - name = "CAPI", - srcs = glob( - ["lib/CAPI/*.cpp"], - allow_empty = True, - ), - hdrs = glob( - ["include/iree-dialects-c/*.h"], - allow_empty = True, - ), - includes = ["include"], - deps = [ - ":IREELinalgTransformDialect", - "@llvm-project//mlir:CAPIIR", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgTransformOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformDialect", - ], -) - -################################################################################ -# Test lib -################################################################################ - -cc_library( - name = "IREEDialectsTest", - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgTransformDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - ], -) - ################################################################################ # Tools ################################################################################ @@ -242,8 +35,6 @@ cc_binary( ], tags = ["hostonly"], deps = [ - ":IREEDialectsTest", - ":IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h deleted file mode 100644 index c729238d8f63..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_C_DIALECTS_H -#define IREE_DIALECTS_C_DIALECTS_H - -#include "mlir-c/IR.h" -#include "mlir-c/Pass.h" -#include "mlir-c/RegisterEverything.h" - -#ifdef __cplusplus -extern "C" { -#endif - -//===--------------------------------------------------------------------===// -// TransformDialect -//===--------------------------------------------------------------------===// - -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); - -MLIR_CAPI_EXPORTED void ireeRegisterTransformExtensions(MlirContext context); - -#ifdef __cplusplus -} -#endif - -#endif // IREE_DIALECTS_C_DIALECTS_H diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt index 0ca0f41c5af4..e69de29bb2d1 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt @@ -1 +0,0 @@ -add_subdirectory(Dialect) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt deleted file mode 100644 index 1da2860785ea..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(LinalgTransform) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index c0ebeb41f586..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -function(_add_transform_dialect_extension) - set(LLVM_TARGET_DEFINITIONS StructuredTransformOpsExt.td) - mlir_tablegen(StructuredTransformOpsExt.h.inc -gen-op-decls) - mlir_tablegen(StructuredTransformOpsExt.cpp.inc -gen-op-defs) - add_public_tablegen_target(IREELinalgTransformExtIncGen) - add_dependencies(mlir-headers IREELinalgTransformExtIncGen) -endfunction() - -function(_add_structured_transform_doc) - set(LLVM_TARGET_DEFINITIONS StructuredTransformOpsExt.td) - mlir_tablegen(StructuredTransformOpsExt.md -gen-dialect-doc) - set(GEN_DOC_FILE ${IREE_DIALECTS_BINARY_DIR}/docs/Dialects/StructuredTransformOpsExt.md) - add_custom_command( - OUTPUT ${GEN_DOC_FILE} - COMMAND ${CMAKE_COMMAND} -E copy - ${CMAKE_CURRENT_BINARY_DIR}/StructuredTransformOpsExt.md - ${GEN_DOC_FILE} - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/StructuredTransformOpsExt.md) - add_custom_target(StructuredTransformOpsExtDocGen DEPENDS ${GEN_DOC_FILE}) - add_dependencies(iree-dialects-doc StructuredTransformOpsExtDocGen) -endfunction() - -_add_transform_dialect_extension() -_add_structured_transform_doc() diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h deleted file mode 100644 index c7b6d0dc750b..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H -#define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H - -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/Dialect/Transform/IR/TransformTypes.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace linalg { -class LinalgOp; -} // namespace linalg -namespace scf { -class ForOp; -} // namespace scf -namespace transform_ext { -class MatchCallbackOp; -} // namespace transform_ext - -/// Matches a C++ callback previously registered under `callbackName` and -/// taking arguments `args`. -/// Unpacks a number of handles `N` (asserts there are exactly `N` matched -/// ops but this could be relaxed if needed). Returns the tuple of handles. -template -auto unpackRegisteredMatchCallback(ImplicitLocOpBuilder &b, - StringRef callbackName, - MatchingArgs... args) { - SmallVector matchedTypes(N, transform::AnyOpType::get(b.getContext())); - auto matchOp = b.create( - matchedTypes, callbackName, std::forward(args)...); - assert(matchOp->getNumResults() == N && "Unexpected number of results"); - std::array a; - for (int64_t i = 0; i < N; ++i) { - a[i] = matchOp->getResult(i); - } - return std::tuple_cat(a); -} - -/// A tracking listener for tensor IR that checks for payload replacement -/// errors. -class ErrorCheckingTrackingListener : public transform::TrackingListener { -public: - using transform::TrackingListener::TrackingListener; - - ~ErrorCheckingTrackingListener() override { - assert(status.succeeded() && "must check listener error state"); - } - - /// Return "true" if this tracking listener had a failure. - bool failed() const { return !status.succeeded(); } - - /// Check and return the current error state of this listener. In case of a - /// failure state, only the most recent error is returned. Afterwards, resets - /// the error state. - DiagnosedSilenceableFailure checkAndResetError() { - DiagnosedSilenceableFailure result(std::move(status)); - status = DiagnosedSilenceableFailure::success(); - return result; - } - -private: - void - notifyPayloadReplacementNotFound(Operation *op, ValueRange values, - DiagnosedSilenceableFailure &&diag) override; - - /// The error state of this listener. "Success" indicates that no error - /// happened so far. Otherwise, the status contains the most recent error. - DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); -}; - -} // namespace mlir - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc" - -namespace mlir { -namespace transform_ext { -class StructuredTransformOpsExtension - : public mlir::transform::TransformDialectExtension< - StructuredTransformOpsExtension> { -public: - StructuredTransformOpsExtension(); -}; - -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td deleted file mode 100644 index 81beb062615d..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef STRUCTURED_TRANSFORM_OPS_EXT -#define STRUCTURED_TRANSFORM_OPS_EXT - -include "mlir/Dialect/Transform/IR/TransformAttrs.td" -include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/OpBase.td" - -def RegisterMatchCallbacksOp : - Op, - DeclareOpInterfaceMethods, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Registers named structured op matcher callbacks specific for IREE to use - with `transform.iree.match_callback`. This should be called before first - `match_callback` may be executed following the transform dialect control - flow. - - The callbacks must have a unique name and a signature compatible with - `MatchCallbacksRegistry::MatchCallbackFn`, which currently means - `DiagnosedSilenceableFailure(MatchCallbackResult &, Location, - const TransformState &, ValueRange)`. The callback receives a "result", - followed by a location at which errors should be reported, a transform - state at the moment of the _match_ (not registration) and a list of - handle values passed as operands to the `match_callback` operation. - It is expected to populate the "result" object with lists of payload - operations that will be bound to the handles produced by the - `match_callback` operation. The callback may fail, at which point - it should produce a silenceable error. The callback currently is not - allowed to modify the payload IR (though this may be revised in the - future for the purpose of communicating the properties of the IR - captured by the match). Therefore, it should not have a reason to - produce a definite error. - }]; - - let arguments = (ins); - let results = (outs); - let assemblyFormat = "attr-dict"; - let cppNamespace = "mlir::transform_ext"; -} - -def MatchCallbackOp : - Op, - DeclareOpInterfaceMethods, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Performs payload IR matching using a C++ callback registered beforehand. - The callback is identified by name and is passed the current transform - state and the list of handle operands, along with information necessary - for error propagation. See `register_match_callbacks` for the description - of the callback contract. - - If `failure_propagation_mode` is set to `suppress`, any silenceable errors - in the callback (typically, "failure to match") will be ignored and the - resulting handles will be associated with empty lists of payload - operations. Otherwise, silenceable failures are propagated. - }]; - - let arguments = (ins StrAttr:$callback_name, - FailurePropagationMode:$failure_propagation_mode, - Variadic:$inputs); - let results = (outs Variadic:$outputs); - let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` " - "$callback_name `(` $inputs `)` attr-dict " - "`:` functional-type($inputs, $outputs)"; - let cppNamespace = "mlir::transform_ext"; -} - -def TakeFirstOp : - Op, - DeclareOpInterfaceMethods, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Given an arbitrary list of handles associated with potentially empty lists - of payload operations, produces two new handles: - - - a handle pointing to the same payload operations as the first operand - handle with a non-empty list of payload operations; - - a handle pointing to the concatenated list of payload operations - associated with any other handle. - - Note that this does not perform any deduplication. - - This operation is useful to select a single target after some potentially - unsuccessful matches. - }]; - - let arguments = (ins Variadic:$inputs); - let results = (outs TransformHandleTypeInterface:$first, - TransformHandleTypeInterface:$rest); - let assemblyFormat = - "$inputs attr-dict `:` functional-type($inputs, results)"; - let cppNamespace = "mlir::transform_ext"; -} - -def EmitRemarkOp : - Op, - TransformOpInterface, TransformEachOpTrait, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Emits a diagnostic remark with the given message located at payload ops - associated with the given handle. This can be used, e.g., for debugging. - }]; - - let arguments = (ins TransformHandleTypeInterface:$handle, - StrAttr:$message); - let assemblyFormat = "$message `at` $handle attr-dict `:` type($handle)"; - let cppNamespace = "mlir::transform_ext"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -#endif // STRUCTURED_TRANSFORM_OPS_EXT diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h deleted file mode 100644 index 4b7d7169bb41..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h +++ /dev/null @@ -1,1201 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ -#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ - -#include -#include -#include - -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/IR/Matchers.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/StringMap.h" - -namespace mlir { -namespace transform_ext { - -//===---------------------------------------------------------------------===// -// StructuredOpMatcher and predicates. -//===---------------------------------------------------------------------===// - -class StructuredOpMatcher; -class MatcherContext; -StructuredOpMatcher &m_StructuredOp(MatcherContext &); - -/// A tag indicating the shape being static or dynamic, for use with the -/// structured op matcher. -enum class ShapeKind { Static, Dynamic }; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all dimensions. -struct AllDims {}; - -/// A predicate indicating the structured op matcher to check the predicate for -/// all dimensions except the specified ones. -struct AllDimsExcept { - explicit AllDimsExcept(std::initializer_list range) { - llvm::append_range(exceptions, range); - } - ArrayRef getExcluded() const { return llvm::ArrayRef(exceptions); } - -private: - SmallVector exceptions; -}; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all operands of the relevant kind. -struct AllOperands {}; - -/// Base class for single-value captures. Concrete captures should inherit this -/// and forward the constructor via `using Base::Base`. -template -struct CaptureStaticValue { - using Base = CaptureStaticValue; - explicit CaptureStaticValue(T &value) : value(value) {} - T &value; -}; - -/// Captures the (static) size of the dimension. -struct CaptureDim : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the (static) sizes of multiple dimensions. -struct CaptureDims : public CaptureStaticValue> { - using Base::Base; -}; - -/// Captures the contraction dimensions of the target operation. -struct CaptureIndexingMaps : public CaptureStaticValue> { - using Base::Base; -}; - -/// Captures the contraction dimensions of the target operation. -struct CaptureContractionDims - : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the convolution dimensions of the target operation. -struct CaptureConvDims - : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the rank of the operation. -struct CaptureRank : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the bitwidth of an element type. -struct CaptureElementTypeBitWidth : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures element element type. -struct CaptureElementType : public CaptureStaticValue { - using Base::Base; -}; - -template -struct CaptureAttribute : public CaptureStaticValue { - static_assert(std::is_base_of_v, - "can only capture a subclass of Attribute"); - using CaptureStaticValue::CaptureStaticValue; -}; - -/// A tag indicating to look for any user of the operation's result that would -/// satisfy the predicate. -struct HasAnyUse {}; - -/// Base class for predicate parameters that can be described with the single -/// value. Concrete predicate parameters should inherit this and forward the -/// constructor via `using Base::Base`. -template -struct SingleValuePredicateParam { - using Base = SingleValuePredicateParam; - explicit SingleValuePredicateParam(T value) : value(value) {} - const T value; -}; - -/// Indicates that the dimension must be divisible by the given value. -struct DivisibleBy : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be equal to the given value. -struct NumEqualsTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be greater than the given value. -struct NumGreaterEqualTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be greater than the given value. -struct NumLowerEqualTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the bit width of the elemental type must be equal to the give -/// value. -struct ElementTypeBitWidth : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Predicate tag indicating that the affine map is a permutation. -struct IsPermutation {}; - -/// Predicate tag indicating that the affine map is a projected permutation. -struct IsProjectedPermutation {}; - -/// Predicate tag indicating that the affine map is a projection of given -/// dimension. -struct IsProjected : public SingleValuePredicateParam { - using Base::Base; -}; -/// Predicate tag indicating that the affine map is an identity. -struct IsIdentity {}; - -/// Predicate tag indicating that the operand is a special float constant. -struct ConstantFloatMinOrMinusInf {}; -struct ConstantFloatZero {}; -struct ConstantFloatOne {}; - -/// Indicates that the match optional. The matcher is still expected to run and -/// capture if successful. The parameter can be set to false -struct OptionalMatch : public SingleValuePredicateParam { - OptionalMatch() : Base(true) {} - explicit OptionalMatch(bool set) : Base(set) {} -}; - -/// Predicate tag indicating that the reduction is produced by a single combiner -/// operation. -struct SingleCombinerReduction {}; - -class CapturingOpMatcher; -class CapturingValueMatcher; - -/// Base class for capturing matchers that can be owned by the context. -class CapturingMatcherBase { -public: - // Virtual destructor so unique pointers are deallocated correctly. - // TODO: if efficiency is a problem, consider disallowing non-trivial - // destructors for subclasses. - virtual ~CapturingMatcherBase() = default; - -protected: - /// Informs the matcher that it has another, nested matcher. Derived classes - /// must call this to keep track of nested matchers for capture resetting - /// purposes. - template - void recordNestedMatcher(T &nested) { - if constexpr (std::is_base_of_v) { - nestedCapturingMatchers.push_back(&nested); - } - if constexpr (std::is_base_of_v) { - nestedCapturingValueMatchers.push_back(&nested); - } - } - - /// Appends all nested capturing matchers of a certain kind, excluding this - /// one, to `nested`. - void getAllNested(SmallVectorImpl &nested); - void - getAllNestedValueMatchers(SmallVectorImpl &nested); - - /// Resets nested capturing matchers but does NOT reset the current one. - void resetCapture(); - -private: - /// A list of (recursively) nested capturing matchers that should be reset - /// when the current matcher is. - SmallVector nestedCapturingMatchers; - SmallVector nestedCapturingValueMatchers; -}; - -/// A context object holding capturing matchers, must outlive any individual -/// matcher. When matching complex subgraphs, the caller often doesn't care -/// about all intermediate nodes (operations) in the graph and shouldn't need to -/// hold matcher objects for those. These matchers can be created in this -/// context. -class MatcherContext { -public: - /// Create a new matcher of the specified type owned by this context. - template - std::enable_if_t, T> & - allocate(Args &&...args) { - // Need to call "new" explicitly as make_unique wouldn't have access to the - // private constructor when this class would. - ownedMatchers.emplace_back( - std::unique_ptr(new T(std::forward(args)...))); - return *static_cast(ownedMatchers.back().get()); - } - -private: - /// Owning list of matchers. - // TODO: If this becomes inefficient, consider something like BumpPtrAllocator - // that derived classes can use to store their members as well. - SmallVector> ownedMatchers; -}; - -/// Base class for value matchers that capture the matched value. Stores a list -/// of predicates and requires all of them to match for the value to match. Once -/// a value matched, any repeated use just verifies that equality of the value. -class CapturingValueMatcher : public CapturingMatcherBase { - friend class CapturingMatcherBase; - friend class MatcherContext; - - using PredicateFn = std::function; - -public: - /// Resets the captured value to null. This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - CapturingMatcherBase::resetCapture(); - } - - /// Returns the matched value if the match was successful. - Value getCaptured() const { return captured; } - - /// Matches the given value, hook for `matchPattern`. - bool match(Value value); - -protected: - CapturingValueMatcher() = default; - - /// Adds a predicate to the end of the predicate list for this value matcher. - template - void addPredicate(Fn &&predicate) { - predicates.emplace_back(std::forward(predicate)); - } - - /// The captured value. - Value captured = nullptr; - -private: - /// Additional predicates to be checked on the value. - SmallVector predicates; -}; - -/// Creates a matcher of an arbitrary value. -inline CapturingValueMatcher &m_Value(MatcherContext &context) { - return context.allocate(); -} - -/// Matcher for typed values whose type implements the `ShapedType` interface. -/// Allows for matching the components of the shaped type such as rank and -/// dimensions. -class ShapedValueMatcher : public CapturingValueMatcher { - friend class MatcherContext; - - ShapedValueMatcher(); - -public: - /// Add an always-succeeding matcher predicate capturing the rank. - ShapedValueMatcher &rank(CaptureRank capture); - - /// Add an always-succeeding matcher predicate capturing the size of the - /// dimension identified by the first argument. - ShapedValueMatcher &dim(int64_t dimension, CaptureDim capture); - - /// Add an always-succeeding matcher predicate capturing the sizes of all - /// dimensions in order of appearance. - ShapedValueMatcher &dim(AllDims tag, CaptureDims captures); - - /// Add an always-succeeding matcher predicate capturing the element type of - /// the value. - ShapedValueMatcher &elementType(CaptureElementType captures); -}; - -/// Construct a new matcher of a value whose type is a `ShapedType`, owned by -/// the given context. -inline ShapedValueMatcher &m_ShapedValue(MatcherContext &context) { - return context.allocate(); -} - -/// Matcher for operations with additional predicates attachable through the -/// fluent, a.k.a. chainable, API. Note that public API must *not* accept -/// additional callbacks even; new predicates should be added instead when -/// necessary. Not only this decreases the depth of the callback stack and -/// increases readability, it also allows us to port the matcher to a -/// declarative format using PDL and/or Transform dialect in the future. The -/// latter will become impossible with arbitrary C++ callbacks. -class CapturingOpMatcher : public CapturingMatcherBase { - friend class CapturingMatcherBase; - friend class MatcherContext; - - template - friend CapturingOpMatcher &m_Operation(MatcherContext &matcherContext); - -public: - using PredicateFn = std::function; - - /// Matches the given operation, hook for `matchPattern`. - bool match(Operation *op); - - /// Resets the captured value to null. This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - CapturingMatcherBase::resetCapture(); - } - - /// Returns the matched operation if the match was successful. - Operation *getCaptured() const { return captured; } - - /// Adds alternative paths for predicates. In practice, this is just a - /// predicate that is satisfied when either the first or the second matcher is - /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., - /// the second alternative will not be processed, and therefore will not - /// capture values, if the first alternative succeeded. - CapturingOpMatcher &alternatives(CapturingOpMatcher &first, - CapturingOpMatcher &second); - - //===-------------------------------------------------------------------===// - // Constraints on adjacent ops. - //===-------------------------------------------------------------------===// - - /// Adds a predicate checking that all ops implementing TilingInterface in the - /// parent of the given type (e.g., a function or a module) were matched by - /// this or nested matchers. This is useful to ensure that the matcher covered - /// the entire parent region, not just a parent of it. This predicate **must** - /// be added *after* all the other predicates that capture. - template - CapturingOpMatcher &allTilableOpsCaptured() { - SmallVector copy; - copy.push_back(this); - getAllNested(copy); - addPredicate([copy = std::move(copy)](Operation *op) { - Operation *parent = op->getParentOfType(); - return checkAllTilableMatched(parent, op, copy); - }); - return *this; - } - - //-------------------------------------------------------------------------// - // Predicates for operands and results. - //-------------------------------------------------------------------------// - - /// Adds a predicate checking that the operation has exactly the given number - /// of operands. - CapturingOpMatcher &operand(NumEqualsTo num); - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by an operation that satisfies the given matcher. - CapturingOpMatcher &operand(int64_t pos, CapturingOpMatcher &nested); - - /// Adds a predicate checking that the `pos`-th operand of the operation - /// satisfies the given value matcher. - CapturingOpMatcher &operand(int64_t pos, CapturingValueMatcher &nested); - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by `arith.constant` with the value 1.0. - // TODO: better matching for attributes. - CapturingOpMatcher &operand(int64_t pos, ConstantFloatOne); - - /// Adds a predicate checking that the operation has exactly the given number - /// of results. - CapturingOpMatcher &result(NumEqualsTo num); - - /// Adds a predicate checking that the `pos`-th result of the operation - /// satisfies the given value matcher. - CapturingOpMatcher &result(int64_t pos, CapturingValueMatcher &nested); - -protected: - /// Constructs a default operation matcher accepting any operation. - CapturingOpMatcher() = default; - - /// Adds a predicate for the matched operation to satisfy. - template - void addPredicate(Fn &&predicate) { - predicates.emplace_back(std::forward(predicate)); - } - - /// Produce the debug output for `create` method in a non-templated way. - static void debugOutputForCreate(ArrayRef opNames); - -private: - /// A list of additional conditions for the operation to match. - SmallVector predicates; - - /// Checks that `matchers` captured all tilable ops nested in `parent` except - /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured. - static bool checkAllTilableMatched(Operation *parent, Operation *op, - ArrayRef matchers); - - /// Creates a matcher for an operation with one of the given types. - template - static CapturingOpMatcher create() { - CapturingOpMatcher matcher; - matcher.addPredicate([](Operation *op) { - debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); - return isa(op); - }); - return matcher; - } - - /// Common util for constant matcher. - CapturingOpMatcher &operand(int64_t position, - std::function floatValueFn); - -protected: - /// Matched value. - Operation *captured = nullptr; -}; - -namespace detail { -/// Prints the debug output from the ConcreteOpMatcher constructor. The -/// implementation must reside in the C++ file so we don't pollute the header -/// with debug includes, and ConcreteOpMatcher is a class template that can only -/// reside in the header. -void debugOutputForConcreteOpMatcherConstructor(StringRef name); -} // namespace detail - -/// Base class for matchers that match a specific op. Adds an initial predicate -/// checking if the op is indeed of the specified kind. -/// Derived classes specializing this for op interfaces MUST also define a -/// specialization of DebugOpKindDescription. -template -class ConcreteOpMatcher : public CapturingOpMatcher { -protected: - using Base = ConcreteOpMatcher; - - static StringRef getConcreteOpDescription() { - return OpTy::getOperationName(); - } - - /// Adds a predicate checking if the op is of the OpTy kind. - ConcreteOpMatcher() { - CapturingOpMatcher::addPredicate([](Operation *op) { - detail::debugOutputForConcreteOpMatcherConstructor( - Derived::getConcreteOpDescription()); - return isa(op); - }); - } - - /// Adds a predicate for the matched operation to satisfy. - template - Derived &addPredicate(FnTy &&predicate) { - // Dispatch to the callback. - CapturingOpMatcher::addPredicate( - [inner = std::move(predicate)](Operation *op) { - return inner(cast(op)); - }); - return static_cast(*this); - } - -public: - /// Adds alternative paths for predicates. In practice, this is just a - /// predicate that is satisfied when either the first or the second matcher is - /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., - /// the second alternative will not be processed, and therefore will not - /// capture values, if the first alternative succeeded. - Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) { - return static_cast( - CapturingOpMatcher::alternatives(first, second)); - } - - /// Adds a predicate checking that all ops implementing TilingInterface in the - /// parent of the given type (e.g., a function or a module) were matched by - /// this or nested matchers. This is useful to ensure that the matcher covered - /// the entire parent region, not just a parent of it. This predicate **must** - /// be added *after* all the other predicates that capture. - template - Derived &allTilableOpsCaptured() { - return static_cast( - CapturingOpMatcher::allTilableOpsCaptured()); - } - - //-------------------------------------------------------------------------// - // Predicates for operands and results. - //-------------------------------------------------------------------------// - - /// Adds a predicate checking that the operation has exactly the given number - /// of operands. - Derived &operand(NumEqualsTo num) { - return static_cast(CapturingOpMatcher::operand(num)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by an operation that satisfies the given matcher. - Derived &operand(int64_t pos, CapturingOpMatcher &nested) { - return static_cast(CapturingOpMatcher::operand(pos, nested)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation - /// satisfies the given value matcher. - Derived &operand(int64_t pos, CapturingValueMatcher &nested) { - return static_cast(CapturingOpMatcher::operand(pos, nested)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by `arith.constant` with the value 1.0. - // TODO: better matching for attributes. - Derived &operand(int64_t pos, ConstantFloatOne c) { - return static_cast(CapturingOpMatcher::operand(pos, c)); - } - - /// Adds a predicate checking that the operation has exactly the given number - /// of results. - Derived &result(NumEqualsTo num) { - return static_cast(CapturingOpMatcher::result(num)); - } - - /// Adds a predicate checking that the `pos`-th result of the operation - /// satisfies the given value matcher. - Derived &result(int64_t pos, CapturingValueMatcher &nested) { - return static_cast(CapturingOpMatcher::result(pos, nested)); - } -}; - -/// Matcher for the `tensor.pad` operation. -class TensorPadOpMatcher - : public ConcreteOpMatcher { - friend class MatcherContext; - - TensorPadOpMatcher() = default; - -public: - /// Adds a predicate checking that the low padding sizes are exactly the given - /// values. - TensorPadOpMatcher &low(ArrayRef sizes); - - /// Adds a predicate checking that the low padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &low(AllDims tag, int64_t size); - - /// Adds a predicate checking that the high padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &high(ArrayRef sizes); - - /// Adds a predicate checking that the high padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &high(AllDims tag, int64_t size); - - /// Adds a predicate checking that the body of the pad only yields values - /// defined outside the pad region. - TensorPadOpMatcher &yieldsExternalValue(); -}; - -inline TensorPadOpMatcher &m_tensorPad(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates a default operation matcher in the given context that accepts any -/// operation. -inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates an operation matcher in the given context that accepts only -/// operations of the kinds provided as template arguments. -template -inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { - return matcherContext.allocate( - CapturingOpMatcher::create()); -} - -/// Matcher for structured aka Linalg operations. -class StructuredOpMatcher - : public ConcreteOpMatcher { - friend class MatcherContext; - - StructuredOpMatcher() = default; - -public: - static StringRef getConcreteOpDescription() { - return "linalg interface implementation"; - } - - /// Creates a matcher for a structured operation with one of the given types. - template - static StructuredOpMatcher create() { - StructuredOpMatcher matcher; - matcher.addPredicate([](Operation *op) { - debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); - return isa(op) && isa(op); - }); - return matcher; - } - - /// Matches a structured operation if either patterns A or B match. - StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B); - - //===-------------------------------------------------------------------===// - // Constraints on op rank and dims. - //===-------------------------------------------------------------------===// - /// Adds a predicate checking that the given rank must be greater than some - /// constant value. - StructuredOpMatcher &rank(NumGreaterEqualTo minRank); - StructuredOpMatcher &rank(NumLowerEqualTo maxRank); - StructuredOpMatcher &rank(NumEqualsTo exactRank); - - /// Adds a predicate checking that the given iteration space dimension is - /// static/dynamic. The dimension index may be negative, in which case - /// dimensions are counted from the last one (i.e. Python-style), or be an - /// AllDims tag, in which case all dimensions are checked. This may be - /// eventually extended to slices and/or lists of dimensions. - StructuredOpMatcher &dim(int64_t dimension, ShapeKind kind) { - return dim(SmallVector{dimension}, kind); - } - StructuredOpMatcher &dim(SmallVector &&dimensions, ShapeKind kind); - StructuredOpMatcher &dim(AllDims tag, ShapeKind kind); - - /// Adds a predicate checking that the given iteration space dimension has the - /// given iterator type, e.g., parallel or reduction. The dimension index may - /// be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style), or be an AllDims tag, in which case all dimensions - /// are checked. This may be eventually extended to slices and/or lists of - /// dimensions. - StructuredOpMatcher &dim(int64_t dimension, utils::IteratorType kind) { - return dim(SmallVector{dimension}, kind); - } - // Ownership may get tricky here so we wrap in an explicit vector. - StructuredOpMatcher &dim(SmallVector &&dimensions, - utils::IteratorType kind); - StructuredOpMatcher &dim(AllDims tag, utils::IteratorType kind); - StructuredOpMatcher &dim(AllDimsExcept &&dimensions, - utils::IteratorType kind); - - /// Adds a predicate checking that the given iteration space dimension is - /// statically known to be divisible by the given value. The dimension index - /// may be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style). - StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy); - - //===-------------------------------------------------------------------===// - // Capture directives. - //===-------------------------------------------------------------------===// - StructuredOpMatcher &rank(CaptureRank capture); - StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture); - StructuredOpMatcher &dim(AllDims tag, CaptureDims captures); - StructuredOpMatcher &indexingMaps(CaptureIndexingMaps indexingMaps); - StructuredOpMatcher &contractionDims(CaptureContractionDims contractionDims); - StructuredOpMatcher &convolutionDims(CaptureConvDims convDims); - - //===-------------------------------------------------------------------===// - // Constraints on input operands. - //===-------------------------------------------------------------------===// - /// Adds a predicate checking that the structured op has the given number of - /// inputs. - StructuredOpMatcher &input(NumEqualsTo num); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the `position`-th operand. The position may be - /// negative, in which case positions are counted from the last one - /// (i.e. Python-style). When the match is optional, the predicate check - /// succeeds as long as the `position` is in bounds. The matcher is executed - /// if there is a defining operation for the input operand. - template - std::enable_if_t::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addInputMatcher( - position, - [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - template - std::enable_if_t::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addInputMatcher( - position, - [&operandMatcher](Value v) { return operandMatcher.match(v); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - - /// Adds a predicate checking that all input operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that all input operands of the structured op - /// have a projected permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation); - - /// Adds a predicate checking that all input operands of the structured op - /// are projected along the given dimension. - StructuredOpMatcher &input(SmallVector &&positions, IsProjected dim); - StructuredOpMatcher &input(int64_t position, IsProjected dim) { - return input(SmallVector{position}, dim); - } - - /// Adds a predicate checking that all input operands of the structured op - /// have identity indexing map. - StructuredOpMatcher &input(AllOperands tag, IsIdentity); - StructuredOpMatcher &input(SmallVector &&positions, IsIdentity); - StructuredOpMatcher &input(int64_t position, IsIdentity) { - return input(SmallVector{position}, IsIdentity()); - } - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op input at the given position is equal to the given value. - StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width); - - /// Capture the elemental type bitwidth of input operand `position`. - StructuredOpMatcher &input(int64_t position, - CaptureElementTypeBitWidth width); - - /// Capture the elemental type of input operand `position`. - StructuredOpMatcher &input(int64_t position, CaptureElementType elem); - - /// Check if input is equal to a known constant. - // TODO: Support matching for constant ops. - StructuredOpMatcher &input(int64_t position, ConstantFloatMinOrMinusInf); - StructuredOpMatcher &input(int64_t position, ConstantFloatZero); - - //===-------------------------------------------------------------------===// - // Constraints on output operands. - //===-------------------------------------------------------------------===// - - /// Adds a predicate checking that the structured op has the given number of - /// outputs. - StructuredOpMatcher &output(NumEqualsTo num); - - /// Adds a predicate checking that all output operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that all output operands of the structured op - /// have a projected permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation); - - /// Adds a predicate checking that all output operands of the structured op - /// have a - StructuredOpMatcher &output(AllOperands tag, IsProjected dim); - - /// Adds a predicate checking that all output operands of the structured op - /// have identity indexing map. - StructuredOpMatcher &output(AllOperands tag, IsIdentity); - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op output at the given position is equal to the given value. - StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width); - - /// Capture the elemental type bitwidth of output operand `position`. - StructuredOpMatcher &output(int64_t position, - CaptureElementTypeBitWidth width); - - /// Capture the elemental type of output operand `position`. - StructuredOpMatcher &output(int64_t position, CaptureElementType elem); - - /// Adds a predicate checking that the output of the structured op is produced - /// by a reduction with a single-operation combinator (such as addf or mulf, - /// but not a compare+select pair). - StructuredOpMatcher &output(int64_t position, SingleCombinerReduction tag); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the init/out operand corresponding to `position`-th - /// output. The position may be negative, in which case positions are counted - /// from the last one (i.e. Python-style). When the match is optional, the - /// predicate check succeeds as long as the `position` is in bounds. The - /// matcher executed if there is a defining operation for the output operand. - template - std::enable_if_t::value, - StructuredOpMatcher &> - output(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addOutputMatcher( - position, - [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - - //===-------------------------------------------------------------------===// - // Constraints on results. - //===-------------------------------------------------------------------===// - - /// Adds a predicate that recursively applies to users of the `position`-th - /// result of the structured op. Succeeds if any user matches the predicate. - /// When the match is optional, the predicate check succeeds as long as the - /// `position` is in bounds, after running the given matcher. - template - std::enable_if_t::value, - StructuredOpMatcher &> - result(int64_t position, HasAnyUse tag, T &resultUserMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addResultMatcher( - position, tag, - [&resultUserMatcher](Operation *op) { - return resultUserMatcher.match(op); - }, - optional); - recordNestedMatcher(resultUserMatcher); - return *this; - } - - //===-------------------------------------------------------------------===// - // Constraints on op region. - //===-------------------------------------------------------------------===// - - /// Return true if the linalg op only contains a single ops and the arguments - /// of the operation match the order of the linalg operand. - /// Example: - /// linalg.generic - /// ins(%0, %1 : tensor, tensor) - /// outs(%2 : tensor) { - /// ^bb0(%arg0: f32, %arg1: f32): - /// %3 = arith.maxf %arg0, %arg1 : f32 - /// linalg.yield %3 : f32 - /// } -> tensor - /// If commutative is set binary operations can have their operands swapped. - template - StructuredOpMatcher &singleOpWithCanonicaleArgs(bool commutative = false) { - return singleOpWithCanonicaleArgs(OpType::getOperationName(), commutative); - } - StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname, - bool commutative); - /// Check if the op is a linalg of with a single float reciprocal op. - StructuredOpMatcher &isFloatReciprocal(); - /// Check if the op is a linalg of with a region containing only a yield op - /// using block arguments in order. - StructuredOpMatcher &passThroughOp(); - - /// Check if the body of the linalg op implements a contraction of the kind - /// result = input1 input2 - template - StructuredOpMatcher &hasContractionBody() { - return hasContractionBody( - [](Operation *op) { return isa(op); }, - [](Operation *op) { return isa(op); }, - ElemOpTy::getOperationName(), ReductionOpTy::getOperationName()); - } - -private: - /// Non-template implementations of nested predicate builders for inputs, - /// outputs and results. Should not be called directly. - void addInputMatcher(int64_t position, - std::function matcher, - OptionalMatch optional); - void addInputMatcher(int64_t position, std::function matcher, - OptionalMatch optional); - void addOutputMatcher(int64_t position, - std::function matcher, - OptionalMatch optional); - void addResultMatcher(int64_t position, HasAnyUse tag, - std::function matcher, - OptionalMatch optional); - - // Common util for constant matcher. - StructuredOpMatcher &input(int64_t position, - std::function floatValueFn); - - /// Non-template implementation of hasContractionBody. Takes callbacks for - /// checking operation kinds and names for error reporting. - StructuredOpMatcher & - hasContractionBody(function_ref isaElemOpTy, - function_ref isaReductionOpTy, - StringRef elemOpName, StringRef reductionOpName); -}; - -/// Creates a matcher of an arbitrary structured op. -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates a matcher that is a copy of the given matcher. -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext, - const StructuredOpMatcher &other) { - return matcherContext.allocate(other); -} - -/// Creates a matcher that accepts as disjunction of the two given matchers. -inline StructuredOpMatcher &m_StructuredOp_Or(MatcherContext &matcherContext, - StructuredOpMatcher &A, - StructuredOpMatcher &B) { - return matcherContext.allocate(A, B); -} - -/// Creates a matcher of a structured op with kinds provided as template -/// arguments. -template -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { - return matcherContext.allocate( - StructuredOpMatcher::create()); -} - -//===---------------------------------------------------------------------===// -// MatchCallback functionality. -//===---------------------------------------------------------------------===// - -/// Additional results of the C++ callback usable in the `match_callback` -/// transform operation. Conceptually, a list of lists of payload operations to -/// be associated with each result handle. -class MatchCallbackResult { -public: - /// Returns the number of lists of payload operations. - int64_t getNumPayloadGroups() const { return payloadGroupLengths.size(); } - - /// Returns the `position`-th list of payload operations. - ArrayRef getPayloadGroup(int64_t position) const; - - /// Adds a new list of payload operations to the list of lists. The new list - /// must not contain null operations. - template - int64_t addPayloadGroup(Range operations) { - int64_t originalLength = payloadOperations.size(); - assert(llvm::all_of(operations, [](Operation *op) -> bool { return op; }) && - "null operation"); - llvm::append_range(payloadOperations, operations); - payloadGroupLengths.push_back(payloadOperations.size() - originalLength); - return payloadGroupLengths.size() - 1; - } - void addPayloadGroup(ArrayRef operations) { - addPayloadGroup>(operations); - } - - /// Adds a new singleton list of payload operation to the list of lists if the - /// operation is non-null, adds an empty list otherwise. Useful for results of - /// optional matches. - void addPotentiallyEmptyPayloadGroup(Operation *op) { - if (!op) { - addPayloadGroup(ArrayRef()); - } else { - addPayloadGroup(ArrayRef(op)); - } - } - -private: - /// The flat list of all payload operations. `payloadGroupLengths` can be used - /// to compute the sublist that corresponds to one nested list. - // TODO: if somebody implements such a flattened vector generically, use it. - SmallVector payloadOperations; - SmallVector payloadGroupLengths; -}; - -/// A transform state extension that maintains the mapping between callback -/// names as strings usable in `match_callback` and their implementations. -class MatchCallbacksRegistry : public transform::TransformState::Extension { -public: - using MatchCallbackFn = std::function; - - /// Constructs the extension. - MatchCallbacksRegistry(transform::TransformState &state) - : transform::TransformState::Extension(state) {} - - /// Registers the given function as a callback with the given name. The name - /// must not be already present in the registry. The callback must be - /// convertible to MatchCallbackFn. - template - void registerCallback(StringRef name, Fn &&fn) { - bool succeeded = callbacks.try_emplace(name, std::forward(fn)).second; - (void)succeeded; - assert(succeeded && "adding a callback with a repeated name"); - } - - /// Returns a pointer to the implementation of the callback with the given - /// name, or null if it is not present in the registry. - const MatchCallbackFn *get(StringRef name) const { - auto iter = callbacks.find(name); - if (iter == callbacks.end()) { - return nullptr; - } - return &iter->getValue(); - } - -private: - llvm::StringMap callbacks; -}; - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. -//===---------------------------------------------------------------------===// - -struct MatchedReductionCaptures { - int64_t reductionRank = 0; - int64_t maybeLeadingRank = 0; - int64_t maybeTrailingRank = 0; - SmallVector leadingOpSizes = {}; - SmallVector reductionOpSizes = {}; - SmallVector trailingOpSizes = {}; - int64_t reductionOutputElementalTypeBitWidth = 0; - int64_t maybeLeadingOutputElementalTypeBitWidth = 0; - int64_t maybeTrailingOutputElementalTypeBitWidth = 0; -}; - -struct MatchedMatmulCaptures { - linalg::ContractionDimensions contractionDims = {}; - Type lhsElementType, rhsElementType, outputElementType; - SmallVector matmulOpSizes = {}; - SmallVector indexingMaps; - - /// Helper functions. - int64_t rank() const { return matmulOpSizes.size(); } - /// Return all batches. - ArrayRef batches() const { return contractionDims.batch; } - /// Return the most minor candidate dimension for `m`. - int64_t m() const { return contractionDims.m.back(); } - /// Return the most minor candidate dimension for `n`. - int64_t n() const { return contractionDims.n.back(); } - /// Return the most minor candidate dimension for `k`. - int64_t k() const { return contractionDims.k.back(); } - /// AffineMap for indexing into the LHS. - AffineMap lhsIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[0]; - } - /// AffineMap for indexing into the RHS. - AffineMap rhsIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[1]; - } - /// AffineMap for indexing into the RES. - AffineMap resIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[2]; - } -}; - -/// Creates a group of matchers for: -/// -/// trailing(reduction(leading(), fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeReductionMatcher(MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&leadingCapture, - StructuredOpMatcher *&trailingCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc); -void makeReductionMatcher(MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc); -/// -/// trailing(matmul(*, *, fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeMatmulMatcher(MatcherContext &matcherContext, - StructuredOpMatcher *&matmulCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&trailingCapture, - MatchedMatmulCaptures &captures, - bool mustMatchEntireFunc); - -/// Create a group of matchers of batch matmul with a fill: -/// -/// batch_matmul(*, *, fill()) -/// -/// and capture various useful quantities. If `mustMatchEntireFunc` is set, the -/// matcher additionally checks if all tileable operations in the functions are -/// captured. -void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&bmmCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::MatchedMatmulCaptures &captures, - bool mustMatchEntireFunc); - -/// Create a group of matchers for a different code sequence of operations -/// matching exactly a softmax operation. -/// -/// %red = reduce_max(%0) -/// %sub = sub(%0, %red) -/// %exp = exp(%sub) -/// %sum = reduce_sum(%exp) -/// %mul = div(%exp, %%sum) -void makeSoftmaxMatcher(MatcherContext &context, - StructuredOpMatcher *&maxReductionCapture, - StructuredOpMatcher *&softmaxRootCapture); - -struct MatchedConvolutionCaptures { - Type inputElementType, filterElementType, outputElementType; - mlir::linalg::ConvolutionDimensions convolutionDims = {}; - SmallVector convolutionOpSizes = {}; - SmallVector trailingOpSizes = {}; - int64_t maybeTrailingOutputElementalTypeBitWidth = 0; - int64_t maybeFillElementalTypeBitWidth = 0; -}; - -/// Creates a group of matchers for: -/// -/// trailing(convolution(input, filter, fill())) -/// -/// where fill is a FillOp and trailing is an elementwise operation, both of -/// which is optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeConvolutionMatcher(MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&trailingCapture, - MatchedConvolutionCaptures &captures, - bool mustMatchEntireFunc); -void makeConvolutionMatcher(MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - MatchedConvolutionCaptures &captures, - bool mustMatchEntireFunc); - -struct MatchedPadCaptures { - int64_t rank = 0; - Type elementType; - SmallVector dims = {}; -}; - -/// Create a matcher for tensor.pad(*) without leading or trailing ops atm. -/// If `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makePadMatcher(MatcherContext &context, CapturingOpMatcher *&padCapture, - MatchedPadCaptures &captures, bool mustMatchEntireFunc); - -/// Wraps the given matcher callback to indicate that it must capture all -/// tilable ops in the parent function. Expects the callback to accept the same -/// arguments as what is expected by MatchCallbacksRegistry::register, followed -/// by a bool. -template -auto wrapAsEntireFuncMatch(Fn &&fn) { - return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - return fn(res, loc, state, handles, true); - }; -} - -/// Wraps the given matcher callback to indicate that it can match subgraphs. -/// Expects the callback to accept the same arguments as what is expected by -/// MatchCallbacksRegistry::register, followed by a bool. -template -auto wrapAsPartialMatch(Fn &&fn) { - return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - return fn(res, loc, state, handles, false); - }; -} - -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt deleted file mode 100644 index 9ed9512dd331..000000000000 --- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -add_mlir_public_c_api_library(IREEDialectsCAPI - Dialects.cpp - LINK_LIBS PUBLIC - IREELinalgTransformDialect - MLIRIR - MLIRLinalgTransformOps - MLIRTransformDialect -) - -iree_dialects_target_includes(IREEDialectsCAPI) diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp deleted file mode 100644 index 3bcb0b0d0b3e..000000000000 --- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects-c/Dialects.h" - -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Pass.h" -#include "mlir/CAPI/Registration.h" -#include "mlir/CAPI/Support.h" -#include "mlir/CAPI/Utils.h" -#include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Support/LLVM.h" - -using namespace mlir; - -//===--------------------------------------------------------------------===// -// TransformDialect -//===--------------------------------------------------------------------===// - -void ireeRegisterTransformExtensions(MlirContext context) { - MLIRContext *ctx = unwrap(context); - DialectRegistry registry; - registry - .addExtensions(); - ctx->appendDialectRegistry(registry); -} diff --git a/llvm-external-projects/iree-dialects/lib/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CMakeLists.txt index 76b98c3aea72..e69de29bb2d1 100644 --- a/llvm-external-projects/iree-dialects/lib/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/CMakeLists.txt @@ -1,3 +0,0 @@ -add_subdirectory(CAPI) -add_subdirectory(Dialect) -add_subdirectory(Transforms) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt deleted file mode 100644 index 1da2860785ea..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(LinalgTransform) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index f33061b2d87c..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt deleted file mode 100644 index 6eb51d339055..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -add_mlir_library(IREELinalgTransformDialect - StructuredTransformOpsExt.cpp - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - IREEDialectsTransforms - MLIRIR - - MLIRAsyncDialect - MLIRControlFlowInterfaces - MLIRLinalgDialect - MLIRPDLDialect - MLIRRewrite - MLIRTransformDialect - MLIRTransformDialectInterfaces - MLIRTransformPDLExtension - - # Transforms - MLIRAffineToStandard - MLIRAsyncTransforms - MLIRLinalgTransforms - MLIRMemRefTransforms - MLIRReconcileUnrealizedCasts - MLIRTensorTransformOps - MLIRTransforms - MLIRVectorToSCF - - # Conversions - MLIRAsyncToLLVM - MLIRIndexToLLVM - MLIRMathToLLVM - MLIRMemRefToLLVM - MLIRSCFToControlFlow - MLIRVectorToLLVM - MLIRVectorToLLVMPass -) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp deleted file mode 100644 index 6f9ae25b15fc..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp +++ /dev/null @@ -1,999 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" - -#include "iree-dialects/Transforms/TransformMatchers.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "transform-ops-ext" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// Additional constraints for PDLMatchOp. -//===----------------------------------------------------------------------===// - -/// Hook for PDL driver to check if an operation (`pdlValues[0]`) is directly -/// nested in a function with the name provided by an attribute -/// (`pdlValues[1]`). -/// TODO: PDL needs user-defined "questions". -static LogicalResult nestedInFunc(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - Operation *operation = pdlValues[0].cast(); - Attribute attr = pdlValues[1].cast(); - - auto func = operation->getParentOfType(); - if (!func) { - return rewriter.notifyMatchFailure(operation, "not nested in a function"); - } - auto functionSymbol = dyn_cast(attr); - if (!functionSymbol) { - return rewriter.notifyMatchFailure(operation, "not a function identifier"); - } - return success(functionSymbol.getLeafReference() == func.getName()); -} - -/// Construct a IRMapping from `linalgOp` to `genericLinalgModelOp`. -/// Walk both ops and check whether all subops are the same. -static LogicalResult -haveIdenticalBodiesImpl(linalg::LinalgOp linalgOp, - linalg::LinalgOp genericLinalgModelOp) { - IRMapping bvm; - bvm.map(linalgOp.getBlock()->getArguments(), - genericLinalgModelOp.getBlock()->getArguments()); - SmallVector linalgBodyOps; - linalgOp.getBlock()->walk( - [&](Operation *op) { linalgBodyOps.push_back(op); }); - - unsigned idx = 0; - WalkResult res = genericLinalgModelOp.getBlock()->walk([&](Operation *op) { - Operation *linalgSubOp = linalgBodyOps[idx++]; - if (op->getName() != linalgSubOp->getName()) { - return WalkResult::interrupt(); - } - if (op->getAttrs() != linalgSubOp->getAttrs()) { - return WalkResult::interrupt(); - } - for (auto it : llvm::zip(op->getOperands(), linalgSubOp->getOperands())) { - if (std::get<0>(it) != bvm.lookupOrNull(std::get<1>(it))) { - return WalkResult::interrupt(); - } - } - bvm.map(linalgSubOp->getResults(), op->getResults()); - return WalkResult::advance(); - }); - - return success(!res.wasInterrupted()); -} - -/// Dispatch body equivalence check depending on case. -static LogicalResult haveEquivalentBodies(linalg::LinalgOp linalgOp, - linalg::LinalgOp genericLinalgModelOp, - PatternRewriter &rewriter) { - if (succeeded(haveIdenticalBodiesImpl(linalgOp, genericLinalgModelOp))) { - return success(); - } - // TODO: haveEquivalentBodiesImpl, see e.g. - // https://gist.github.com/nicolasvasilache/39e89e18c46e02335c16db6ec20a07e3 - return failure(); -} - -/// Succeed when `linalgOp` and `linalgModelOp` are deemed equivalent. -static LogicalResult isEquivalentToOpImpl(PatternRewriter &rewriter, - linalg::LinalgOp linalgOp, - linalg::LinalgOp linalgModelOp) { - // If basic properties do not match, return failure. - { - SmallVector opInputs = linalgOp.getDpsInputs(); - SmallVector modelInputs = linalgModelOp.getDpsInputs(); - ValueRange opOutputs = linalgOp.getDpsInits(); - ValueRange modelOutputs = linalgModelOp.getDpsInits(); - auto notEqualFn = [](std::tuple in) -> bool { - return std::get<0>(in) != std::get<1>(in); - }; - - if (opInputs.size() != modelInputs.size() || - opOutputs.size() != modelOutputs.size() || - llvm::any_of(llvm::zip(opInputs, modelInputs), notEqualFn) || - llvm::any_of(llvm::zip(opOutputs, modelOutputs), notEqualFn) || - linalgOp.getIndexingMaps() != linalgModelOp.getIndexingMaps() || - linalgOp.getIteratorTypesArray() != - linalgModelOp.getIteratorTypesArray()) { - return failure(); - } - } - - // Build the block and go perform a body comparison. - { - // createBlock moves the insertion point, scope it in an RAII block. - OpBuilder::InsertionGuard guard(rewriter); - Region &r = linalgModelOp->getRegion(0); - Block *bodyBlock = rewriter.createBlock( - &r, r.end(), linalgOp.getBlock()->getArgumentTypes(), - llvm::map_to_vector<4>(linalgOp.getBlock()->getArguments(), - [](Value v) { return v.getLoc(); })); - ImplicitLocOpBuilder b(linalgModelOp.getLoc(), rewriter); - auto regionBuilder = linalgModelOp.getRegionBuilder(); - llvm::ArrayRef attrs = {}; - regionBuilder(b, *bodyBlock, attrs, /*emitError=*/{}); - } - - return haveEquivalentBodies(linalgOp, linalgModelOp, rewriter); -} - -/// Check whether the unique Operation* stored in `pdlValues[0]` (assumed) is -/// equivalent to the unique StringRefAttr passed in `pdlValues[1]` (assumed). -/// Equivalence is achieved when either: -/// 1. `pdlValues[0]` has the name stored in `pdlValues[1]`. -/// 2. `pdlValues[0]` and `pdlValues[1]` are both linalg ops and their -/// structured interfaces as well as their bodies are equivalent. -/// Structured interfaces equivalence is a simple attribute level check. -/// Body equivalence is more involved and currently limited: -/// a. the current impl constructs an instance of the op whose name is -/// specified in `pdlValues[1]` and checks for exact body equality. -/// b. a more advanced version would "subtract" the bodies and fold, cse -/// and canonicalize to fixed point. If the result is "all zeros", -/// then the bodies would be equivalent (really isomorphic). -/// 3. other cases TBD (e.g. vector.generic when available). -static LogicalResult isEquivalentToOp(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - Operation *operation = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto modelOpNameAttr = dyn_cast(attribute); - if (!modelOpNameAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - auto modelOpName = modelOpNameAttr.strref(); - - // 1. If op has name `modelOpName`, the match is trivial. - if (operation->getName().getStringRef() == modelOpName) { - return success(); - } - - // 2. Linalg vs Linalg. - // Create op from `modelOpName`. - OperationState modelOpState( - operation->getLoc(), modelOpName, operation->getOperands(), - operation->getResultTypes(), operation->getAttrs()); - modelOpState.addRegion(); - Operation *modelOp = rewriter.create(modelOpState); - auto g1 = llvm::scope_exit([&]() { rewriter.eraseOp(modelOp); }); - linalg::LinalgOp linalgOp = dyn_cast(operation); - linalg::LinalgOp linalgModelOp = dyn_cast(modelOp); - if (linalgOp && linalgModelOp) { - return isEquivalentToOpImpl(rewriter, linalgOp, linalgModelOp); - } - - // 3. TBD - return failure(); -} - -/// Assume that: -/// 1. `pdlValues[0]` is an operands range -/// 2. `pdlValues[1]` contains a DictAttr with `operand_number`, `dim` and -/// `divisor` IntegerAttr entries. -/// Succeed if `operands`[`operand_number`] is a ranked type whose `dim` is a -/// multiple of `divisor`. -/// Note: 0 is the convention to express "do not tile", it is considered to -/// divide everything. -static LogicalResult isDimMultipleOf(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - ValueRange operands = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto dict = dyn_cast(attribute); - if (!dict) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - dim = dimAttr.getInt(); - - int64_t divisor; - auto divisorAttr = dict.getAs("divisor"); - if (!divisorAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - divisor = divisorAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) { - shapedType = dyn_cast(operands[operandNumber].getType()); - } - if (!shapedType || shapedType.getRank() <= dim) { - return failure(); - } - return success(divisor == 0 || (shapedType.getShape()[dim] > 0 && - shapedType.getShape()[dim] % divisor == 0)); -} - -/// Assume that: -/// 1. `pdlValues[0]` is an operands range -/// 2. `pdlValues[1]` contains a DictAttr with `operand_number` and `dim` -/// IntegerAttr entries. -/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is -/// dynamic. -static LogicalResult isDimStatic(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - ValueRange operands = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto dict = dyn_cast(attribute); - if (!dict) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - dim = dimAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) { - shapedType = dyn_cast(operands[operandNumber].getType()); - } - return success(shapedType && !shapedType.isDynamicDim(dim)); -} - -/// Assume that: -/// 1. `pdlValues[0]` is an operands range -/// 2. `pdlValues[1]` contains a DictAttr with `operand_number` and `dim` -/// IntegerAttr entries. -/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is -/// dynamic. -static LogicalResult isDimDynamic(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - ValueRange operands = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto dict = dyn_cast(attribute); - if (!dict) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - dim = dimAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) { - shapedType = dyn_cast(operands[operandNumber].getType()); - } - return success(shapedType && shapedType.isDynamicDim(dim)); -} - -//===----------------------------------------------------------------------===// -// StructuredTransformOpsExtension -//===----------------------------------------------------------------------===// - -mlir::transform_ext::StructuredTransformOpsExtension:: - StructuredTransformOpsExtension() { - registerTransformOps< -#define GET_OP_LIST -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc" - >(); - - addDialectDataInitializer( - [&](transform::PDLMatchHooks &hooks) { - llvm::StringMap constraints; - constraints.try_emplace("nestedInFunc", nestedInFunc); - constraints.try_emplace("isDimDynamic", isDimDynamic); - constraints.try_emplace("isDimMultipleOf", isDimMultipleOf); - constraints.try_emplace("isDimStatic", isDimStatic); - constraints.try_emplace("isEquivalentToOp", isEquivalentToOp); - hooks.mergeInPDLMatchHooks(std::move(constraints)); - }); - - declareDependentDialect(); - declareDependentDialect(); -} - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc" - -//===----------------------------------------------------------------------===// -// ErrorCheckingTrackingListener -//===----------------------------------------------------------------------===// - -void ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( - Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { - // Certain ops can dropped safely. - if (isa(op)) { - LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n"); - return; - } - - SmallVector diags; - diag.takeDiagnostics(diags); - if (!status.succeeded()) { - status.takeDiagnostics(diags); - } - status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); - - status = emitSilenceableFailure( - getTransformOp(), "!!! tracking listener failed to find replacement op"); - status.attachNote(op->getLoc()) << "replaced op"; - for (Value v : values) { - status.attachNote(v.getLoc()) << "replacement value"; - } -} - -//===---------------------------------------------------------------------===// -// MatchCallbackOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::MatchCallbackOp::apply( - mlir::transform::TransformRewriter &rewriter, - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - auto setEmptyResults = [&results, this] { - for (OpResult value : getResults()) { - results.set(value, {}); - } - }; - auto errorOut = [this, &setEmptyResults] { - setEmptyResults(); - return emitSilenceableError(); - }; - - auto *registry = state.getExtension(); - if (!registry) { - return errorOut() << "match registry not available"; - } - - const transform_ext::MatchCallbacksRegistry::MatchCallbackFn *callback = - registry->get(getCallbackName()); - if (!callback) { - return errorOut() << "callback '" << getCallbackName() - << "' not found in the registry"; - } - - MatchCallbackResult result; - DiagnosedSilenceableFailure status = - (*callback)(result, getLoc(), state, getInputs()); - if (!status.succeeded()) { - setEmptyResults(); - if (status.isDefiniteFailure()) { - return status; - } - if (getFailurePropagationMode() == - mlir::transform::FailurePropagationMode::Propagate) { - return emitSilenceableError() << "failed to match"; - } else { - return DiagnosedSilenceableFailure::success(); - } - } - if (getNumResults() != result.getNumPayloadGroups()) { - return errorOut() - << "callback produced a different number of handles than expected ( " - << result.getNumPayloadGroups() << " vs " << getNumResults() << " )"; - } - - for (OpResult value : getResults()) { - results.set(value, result.getPayloadGroup(value.getResultNumber())); - } - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::MatchCallbackOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getInputsMutable(), effects); - mlir::transform::producesHandle(getOutputs(), effects); - // TODO: it doesn't really modify the payload, we need a separate resource for - // this mapping. - mlir::transform::modifiesPayload(effects); -} - -//===---------------------------------------------------------------------===// -// Callbacks for tests driven by RegisterMatchCallbacksOp -//===---------------------------------------------------------------------===// - -/// Match callback for "_test_match_callback" hook. Matches any payload -/// operations associated with operand handles unless they have the -/// "test.iree_transform_do_not_match" attribute, in which case produces a -/// silenceable failure. -static DiagnosedSilenceableFailure -testMatchCallbackCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - bool hadFailures = false; - for (Value handle : handles) { - if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) { - return op->hasAttr("test.iree_transform_do_not_match"); - })) { - res.addPayloadGroup(ArrayRef()); - hadFailures = true; - } else { - res.addPayloadGroup(state.getPayloadOps(handle)); - } - } - if (hadFailures) { - return emitSilenceableFailure(loc) << "failed to match"; - } - return DiagnosedSilenceableFailure::success(); -} - -static DiagnosedSilenceableFailure testRepeatedMatcherUseCallback( - transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - transform_ext::MatcherContext matcherContext; - auto &operand = transform_ext::m_StructuredOp(matcherContext); - auto &first = transform_ext::m_StructuredOp(matcherContext).input(0, operand); - auto &second = transform_ext::m_StructuredOp(matcherContext) - .input(0, operand) - .input(1, first); - - WalkResult walkResult = root->walk([&](Operation *op) { - second.resetCapture(); - if (!matchPattern(op, second)) { - return WalkResult::advance(); - } - - res.addPayloadGroup({first.getCaptured()}); - res.addPayloadGroup({second.getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -static DiagnosedSilenceableFailure -testValueMatcherCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - transform_ext::MatcherContext matcherContext; - auto &operand = transform_ext::m_Value(matcherContext); - auto &first = transform_ext::m_StructuredOp(matcherContext).input(0, operand); - auto &second = transform_ext::m_StructuredOp(matcherContext) - .input(0, operand) - .input(1, first); - - WalkResult walkResult = root->walk([&](Operation *op) { - second.resetCapture(); - if (!matchPattern(op, second)) { - return WalkResult::advance(); - } - - res.addPayloadGroup({first.getCaptured()}); - res.addPayloadGroup({second.getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -static DiagnosedSilenceableFailure testShapedValueMatcherCallback( - transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - int64_t rank; - SmallVector dims; - transform_ext::MatcherContext matcherContext; - auto &value = transform_ext::m_ShapedValue(matcherContext); - value.rank(transform_ext::CaptureRank(rank)) - .dim(transform_ext::AllDims(), transform_ext::CaptureDims(dims)); - auto &opMatcher = - transform_ext::m_Operation(matcherContext); - opMatcher.result(0, value); - - WalkResult walkResult = root->walk([&](Operation *op) { - opMatcher.resetCapture(); - if (!matchPattern(op, opMatcher)) { - return WalkResult::advance(); - } - - op->emitRemark() << "rank: " << rank; - std::string message; - llvm::raw_string_ostream os(message); - llvm::interleaveComma(dims, os); - os.flush(); - op->emitRemark() << "dimensions: " << message; - - res.addPayloadGroup({opMatcher.getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -//===---------------------------------------------------------------------===// -// Callbacks for codegen driven by RegisterMatchCallbacksOp. -//===---------------------------------------------------------------------===// - -/// Match callback for a convolution with optional fill and trailing -/// elementwise operations. Matches *the first* occurrence of such a convolution -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the "fill" op preceding the convolution, if present; -/// - convolution op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -convolutionCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill, *trailing; - transform_ext::MatchedConvolutionCaptures ignore; - transform_ext::MatcherContext matcherContext; - makeConvolutionMatcher(matcherContext, pattern, fill, trailing, ignore, - /*mustMatchEntireFunc=*/true); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. - LLVM_DEBUG({ - DBGS() << "fill:\n"; - if (fill->getCaptured()) { - DBGS() << fill->getCaptured() << "\n"; - } - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - DBGS() << "trailing:\n"; - if (trailing->getCaptured()) { - DBGS() << trailing->getCaptured() << "\n"; - } - }); - - res.addPotentiallyEmptyPayloadGroup(fill->getCaptured()); - res.addPayloadGroup({pattern->getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for a reduction with optional leading and trailing -/// elementwise operations. Matches *the first* occurrence of such a reduction -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - leading elementwise op, if any; -/// - the "fill" op preceding the reduction; -/// - reduction op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -reductionCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles, bool mustMatchEntireFunc) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill, *leading, *trailing; - transform_ext::MatchedReductionCaptures ignore; - transform_ext::MatcherContext matcherContext; - makeReductionMatcher(matcherContext, pattern, fill, leading, trailing, ignore, - mustMatchEntireFunc); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. - LLVM_DEBUG({ - DBGS() << "leading:\n"; - if (leading->getCaptured()) { - DBGS() << leading->getCaptured() << "\n"; - } - DBGS() << "fill: " << fill->getCaptured() << "\n"; - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - DBGS() << "trailing:\n"; - if (trailing->getCaptured()) { - DBGS() << trailing->getCaptured() << "\n"; - } - }); - - res.addPotentiallyEmptyPayloadGroup(leading->getCaptured()); - res.addPayloadGroup({fill->getCaptured()}); - res.addPayloadGroup({pattern->getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for a matmul with fill and optional trailing -/// elementwise operations. Matches *the first* occurrence of such a convolution -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the "fill" op preceding the convolution, if present; -/// - convolution op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -matmulCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill, *trailing; - transform_ext::MatchedMatmulCaptures ignore; - transform_ext::MatcherContext matcherContext; - makeMatmulMatcher(matcherContext, pattern, fill, trailing, ignore, - /*mustMatchEntireFunc=*/true); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. - LLVM_DEBUG({ - DBGS() << "fill:\n"; - if (fill->getCaptured()) { - DBGS() << fill->getCaptured() << "\n"; - } - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - DBGS() << "trailing:\n"; - if (trailing->getCaptured()) { - DBGS() << trailing->getCaptured() << "\n"; - } - }); - - res.addPayloadGroup({fill->getCaptured()}); - res.addPayloadGroup({pattern->getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for linalg.batch_matmul and its linalg.generic equivalent fed -/// by a linalg.fill. -/// -/// Input handles: -/// -/// - the container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the fill op initializing the output; -/// - the main compute op. -static DiagnosedSilenceableFailure -batchMatmulCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill; - transform_ext::MatchedMatmulCaptures ignore; - transform_ext::MatcherContext matcherContext; - transform_ext::makeBatchMatmulMatcher(matcherContext, pattern, fill, ignore, - /*mustMatchEntireFunc*/ true); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly - LLVM_DEBUG({ - DBGS() << "fill:" << fill->getCaptured() << "\n"; - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - }); - - res.addPayloadGroup({fill->getCaptured()}); - res.addPayloadGroup({pattern->getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match batch matmul"; -} - -/// Match callback for a tensor.pad. Matches *the first* occurrence of such pad -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - the container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the pad op. -static DiagnosedSilenceableFailure -padCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, ValueRange handles, - bool mustMatchEntireFunc) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::CapturingOpMatcher *pattern; - transform_ext::MatchedPadCaptures ignore; - transform_ext::MatcherContext matcherContext; - makePadMatcher(matcherContext, pattern, ignore, mustMatchEntireFunc); - - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. - LLVM_DEBUG({ - DBGS() << "pad:\n"; - if (pattern->getCaptured()) { - DBGS() << pattern->getCaptured() << "\n"; - } - }); - - res.addPayloadGroup({pattern->getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -//===---------------------------------------------------------------------===// -// RegisterMatchCallbacksOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply( - mlir::transform::TransformRewriter &rewriter, - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - auto ®istry = state.addExtension(); - registry.registerCallback("_test_match_callback", testMatchCallbackCallback); - registry.registerCallback("_test_repeated_matcher_use_callback", - testRepeatedMatcherUseCallback); - registry.registerCallback("_test_value_matcher_callback", - testValueMatcherCallback); - registry.registerCallback("_test_shaped_value_matcher_callback", - testShapedValueMatcherCallback); - registry.registerCallback("convolution", convolutionCallback); - registry.registerCallback("matmul", matmulCallback); - registry.registerCallback("batch_matmul", batchMatmulCallback); - registry.registerCallback("pad", wrapAsEntireFuncMatch(padCallback)); - registry.registerCallback("reduction", - wrapAsEntireFuncMatch(reductionCallback)); - registry.registerCallback("reduction_partial", - wrapAsPartialMatch(reductionCallback)); - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::RegisterMatchCallbacksOp::getEffects( - SmallVectorImpl &effects) { - // TODO: it doesn't really modify the payload, we need a separate resource for - // this mapping. - mlir::transform::modifiesPayload(effects); -} - -//===---------------------------------------------------------------------===// -// TakeFirstOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform_ext::TakeFirstOp::apply(mlir::transform::TransformRewriter &rewriter, - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - SmallVector concatenated; - bool found = false; - for (Value handle : getInputs()) { - auto payloads = state.getPayloadOps(handle); - if (payloads.empty()) { - continue; - } - if (!found) { - results.set(cast(getFirst()), payloads); - found = true; - } else { - llvm::append_range(concatenated, payloads); - } - } - - if (!found) { - results.set(cast(getFirst()), {}); - } - results.set(cast(getRest()), concatenated); - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::TakeFirstOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getInputsMutable(), effects); - mlir::transform::producesHandle(getOperation()->getOpResults(), effects); -} - -//===---------------------------------------------------------------------===// -// EmitRemarkOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::EmitRemarkOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, - mlir::transform::ApplyToEachResultList &results, - mlir::transform::TransformState &state) { - for (Operation *payload : state.getPayloadOps(getHandle())) { - payload->emitRemark(getMessage()); - } - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::EmitRemarkOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getHandleMutable(), effects); - mlir::transform::onlyReadsPayload(effects); -} diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt deleted file mode 100644 index da2cbd97ae82..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ - -add_mlir_library(IREEDialectsTransforms - TransformMatchers.cpp - - LINK_LIBS PRIVATE - # TODO: break dialect dependency by implementing the transformation separately - # and registering it. - MLIRArithDialect - MLIRAsyncDialect - MLIRFuncDialect - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRMathDialect - - DEPENDS - mlir-headers -) - -iree_dialects_target_includes(IREEDialectsTransforms) diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp deleted file mode 100644 index a90ff9b2c32f..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp +++ /dev/null @@ -1,1845 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/TransformMatchers.h" - -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" - -using namespace mlir; - -#define DEBUG_TYPE "transform-matchers" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " -#define DBGSNL() llvm::dbgs() << "\n[" DEBUG_TYPE "] " - -//===---------------------------------------------------------------------===// -// CapturingMatcherBase -//===---------------------------------------------------------------------===// - -void transform_ext::CapturingMatcherBase::getAllNested( - SmallVectorImpl &nested) { - - SetVector found; - found.insert(nested.begin(), nested.end()); - int64_t start = found.size(); - - auto appendOne = [&found](CapturingMatcherBase &one) { - found.insert(one.nestedCapturingMatchers.begin(), - one.nestedCapturingMatchers.end()); - for (CapturingValueMatcher *valueMatcher : - one.nestedCapturingValueMatchers) { - found.insert(valueMatcher->nestedCapturingMatchers.begin(), - valueMatcher->nestedCapturingMatchers.end()); - } - }; - - appendOne(*this); - for (int64_t position = start; position < found.size(); ++position) { - appendOne(*found[position]); - } - - llvm::append_range(nested, found.getArrayRef()); -} - -void transform_ext::CapturingMatcherBase::getAllNestedValueMatchers( - SmallVectorImpl &nested) { - - SetVector found; - found.insert(nested.begin(), nested.end()); - int64_t start = found.size(); - - auto appendOne = [&found](CapturingMatcherBase &one) { - found.insert(one.nestedCapturingValueMatchers.begin(), - one.nestedCapturingValueMatchers.end()); - for (CapturingOpMatcher *opMatcher : one.nestedCapturingMatchers) { - found.insert(opMatcher->nestedCapturingValueMatchers.begin(), - opMatcher->nestedCapturingValueMatchers.end()); - } - }; - - appendOne(*this); - for (int64_t position = start; position < found.size(); ++position) { - appendOne(*found[position]); - } - - llvm::append_range(nested, found.getArrayRef()); -} - -void transform_ext::CapturingMatcherBase::resetCapture() { - SmallVector nested; - getAllNested(nested); - for (CapturingOpMatcher *matcher : nested) { - matcher->captured = nullptr; - } - SmallVector nestedValue; - getAllNestedValueMatchers(nestedValue); - for (CapturingValueMatcher *matcher : nestedValue) { - matcher->captured = nullptr; - } -} - -//===---------------------------------------------------------------------===// -// CapturingOpMatcher -//===---------------------------------------------------------------------===// - -bool transform_ext::CapturingOpMatcher::checkAllTilableMatched( - Operation *parent, Operation *op, - ArrayRef matchers) { - LLVM_DEBUG(DBGS() << "all tilable ops captured"); - int64_t numTilableOps = 0; - if (!parent) { - return false; - } - parent->walk([&](TilingInterface Op) { ++numTilableOps; }); - - llvm::SmallPtrSet matched; - for (CapturingOpMatcher *nested : matchers) { - if (Operation *captured = nested->getCaptured()) { - matched.insert(captured); - } - } - - // Don't forget to include the root matcher. - matched.insert(op); - return numTilableOps == matched.size(); -} - -bool transform_ext::CapturingOpMatcher::match(Operation *op) { - auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); - LLVM_DEBUG(DBGS() << "matching: " << *op << "\n"); - - if (getCaptured()) { - LLVM_DEBUG(DBGS() << "found an already captured op: "); - if (getCaptured() == op) { - LLVM_DEBUG(llvm::dbgs() << "same\n"); - return true; - } else { - LLVM_DEBUG(llvm::dbgs() << "different\n"); - return false; - } - } - - if (!llvm::all_of(predicates, [op](const PredicateFn &fn) { - bool result = fn(op); - LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); - return result; - })) { - return false; - } - - captured = op; - return true; -} - -void transform_ext::CapturingOpMatcher::debugOutputForCreate( - ArrayRef opNames) { - LLVM_DEBUG(DBGS() << "operation type is one of {"; - llvm::interleaveComma(opNames, llvm::dbgs()); llvm::dbgs() << "}"); -} - -/// Apply the given matcher to the given object, produce debug messages. -template ::template args<0>> -static bool recursiveMatch(Matcher &matcher, Object &object, - StringRef extraMessage = "") { - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "] " << "start recursive match (" - << extraMessage << ") {\n"); - bool result = matcher.match(object); - LLVM_DEBUG(DBGS() << "} end recursive match"); - return result; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::alternatives( - transform_ext::CapturingOpMatcher &first, - transform_ext::CapturingOpMatcher &second) { - addPredicate([&first, &second](Operation *op) { - LLVM_DEBUG(DBGS() << "matching alternatives\n"); - return recursiveMatch(first, op, "alternative 1") || - recursiveMatch(second, op, "alternative 2"); - }); - return *this; -} - -//---------------------------------------------------------------------------// -// Predicates for operands and results. -//---------------------------------------------------------------------------// - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(transform_ext::NumEqualsTo num) { - addPredicate([=](Operation *op) { - LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " operands"); - return num.value == op->getNumOperands(); - }); - return *this; -} - -/// If `pos` is negative, returns the number of the operand in op starting from -/// the last. For example, -1 means the last operand, -2 means the -/// second-to-last, etc. Returns nullopt if pos is out-of-bounds, both positive -/// and negative. -static std::optional remapNegativeOperandNumber(int64_t pos, - Operation *op) { - int64_t updated = pos < 0 ? op->getNumOperands() + pos : pos; - if (updated < 0 || updated >= op->getNumOperands()) { - LLVM_DEBUG(DBGS() << "match operand #" << pos - << "that does not exist in the operation"); - return std::nullopt; - } - return updated; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t pos, - CapturingOpMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - std::optional operandNo = remapNegativeOperandNumber(pos, op); - if (!operandNo) { - return false; - } - LLVM_DEBUG(DBGS() << "operand #" << pos << " is defined by an operation"); - Operation *definingOp = op->getOperand(*operandNo).getDefiningOp(); - if (!definingOp) { - return false; - } - return recursiveMatch(nested, definingOp); - }); - recordNestedMatcher(nested); - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t pos, - CapturingValueMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - std::optional operandNo = remapNegativeOperandNumber(pos, op); - if (!operandNo) { - return false; - } - LLVM_DEBUG(DBGS() << "operand #" << pos << " is"); - Value operand = op->getOperand(*operandNo); - return recursiveMatch(nested, operand); - }); - recordNestedMatcher(nested); - return *this; -} - -transform_ext::CapturingOpMatcher &transform_ext::CapturingOpMatcher::operand( - int64_t position, std::function floatValueFn) { - addPredicate([position, - floatValueFn = std::move(floatValueFn)](Operation *op) -> bool { - std::optional operandNo = remapNegativeOperandNumber(position, op); - if (!operandNo) { - return false; - } - - LLVM_DEBUG(DBGS() << "operand #" << *operandNo - << " is a special floating point constant"); - auto cstOp = - op->getOperand(*operandNo).getDefiningOp(); - if (!cstOp) { - return false; - } - return floatValueFn(cstOp.value()); - }); - - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t position, ConstantFloatOne) { - return operand(position, - [](llvm::APFloat value) { return value.isExactlyValue(1.0); }); -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::result(transform_ext::NumEqualsTo num) { - addPredicate([=](Operation *op) { - LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " results"); - return num.value == op->getNumResults(); - }); - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::result(int64_t pos, - CapturingValueMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - int64_t updated = pos < 0 ? op->getNumResults() + pos : pos; - if (updated < 0 || updated >= op->getNumResults()) { - LLVM_DEBUG(DBGS() << "matching result #" << pos - << " that does not exist in the operation"); - return false; - } - LLVM_DEBUG(DBGS() << "result #" << pos << " is"); - Value result = op->getResult(updated); - return recursiveMatch(nested, result); - }); - recordNestedMatcher(nested); - return *this; -} - -//===---------------------------------------------------------------------===// -// CapturingValueMatcher -//===---------------------------------------------------------------------===// - -namespace { -struct DebugPrintValueWrapper { - Value value; -}; - -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const DebugPrintValueWrapper &wrapper) { - if (auto opResult = dyn_cast(wrapper.value)) { - return os << "op result #" << opResult.getResultNumber() << " in " - << wrapper.value; - } - - auto blockArg = cast(wrapper.value); - os << "block argument #" << blockArg.getArgNumber(); - Block *parentBlock = blockArg.getParentBlock(); - Region *parentRegion = parentBlock->getParent(); - if (!parentRegion) { - os << " of a detached block:\n"; - parentBlock->print(os); - return os; - } - - os << " of block #" - << std::distance(parentRegion->begin(), parentBlock->getIterator()); - Operation *parentOp = parentRegion->getParentOp(); - if (!parentOp) { - os << " of a detached region:\n"; - for (Block &b : *parentRegion) { - b.print(os); - } - return os; - } - - os << " in region #" << parentRegion->getRegionNumber() << " of " - << *parentOp; - return os; -} -} // namespace - -bool transform_ext::CapturingValueMatcher::match(Value value) { - auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); - LLVM_DEBUG(DBGS() << "matching " << DebugPrintValueWrapper{value} << "\n"); - - if (getCaptured()) { - LLVM_DEBUG(DBGS() << "found an already captured value: "); - if (getCaptured() == value) { - LLVM_DEBUG(llvm::dbgs() << "same\n"); - return true; - } else { - LLVM_DEBUG(llvm::dbgs() << "different\n"); - return false; - } - } - - for (const PredicateFn &fn : predicates) { - bool result = fn(value); - LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); - if (!result) { - return false; - } - } - - captured = value; - return true; -} - -transform_ext::ShapedValueMatcher::ShapedValueMatcher() - : CapturingValueMatcher() { - addPredicate([](Value value) { - LLVM_DEBUG(DBGS() << "value is of shaped type"); - return value && isa(value.getType()); - }); -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::rank(transform_ext::CaptureRank capture) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing shaped value rank"); - capture.value = cast(value.getType()).getRank(); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::dim(int64_t dimension, CaptureDim capture) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing shaped value dimension " << dimension); - capture.value = cast(value.getType()).getDimSize(dimension); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::dim(AllDims tag, CaptureDims captures) { - (void)tag; - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing all shaped value dimensions"); - ArrayRef shape = cast(value.getType()).getShape(); - captures.value.assign(shape.begin(), shape.end()); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::elementType(CaptureElementType captures) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing elementType"); - captures.value = cast(value.getType()).getElementType(); - return true; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Constraints on op rank and dims. -//===---------------------------------------------------------------------===// - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank >= " << minRank.value); - return linalgOp.getNumLoops() >= minRank.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value); - return linalgOp.getNumLoops() <= maxRank.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank == " << exactRank.value); - return linalgOp.getNumLoops() == exactRank.value; - }); -} - -StringRef stringifyShapeKind(transform_ext::ShapeKind kind) { - switch (kind) { - case transform_ext::ShapeKind::Static: - return "static"; - case transform_ext::ShapeKind::Dynamic: - return "dynamic"; - } - llvm_unreachable("unhandled shape kind"); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, - ShapeKind kind) { - return addPredicate([dimensions = std::move(dimensions), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimensions ["; - llvm::interleaveComma(dimensions, llvm::dbgs()); - llvm::dbgs() << "] are " << stringifyShapeKind(kind)); - SmallVector shape = linalgOp.getStaticLoopRanges(); - for (auto dimension : dimensions) { - int64_t transformedDimension = - dimension >= 0 ? dimension : shape.size() + dimension; - if (transformedDimension < 0 || transformedDimension >= shape.size()) { - return false; - } - if (ShapedType::isDynamic(shape[transformedDimension]) ^ - (kind == ShapeKind::Static)) { - continue; - } - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind)); - SmallVector shape = linalgOp.getStaticLoopRanges(); - return llvm::all_of(shape, [=](int64_t dimension) { - return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); - }); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, - utils::IteratorType kind) { - return addPredicate([dimensions = std::move(dimensions), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimensions ["; - llvm::interleaveComma(dimensions, llvm::dbgs()); - llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); - int64_t rank = linalgOp.getNumLoops(); - for (auto dimension : dimensions) { - int64_t transformedDimension = - dimension >= 0 ? dimension : rank + dimension; - if (transformedDimension < 0 || transformedDimension >= rank) { - return false; - } - utils::IteratorType iteratorKind = - linalgOp.getIteratorTypesArray()[transformedDimension]; - if (iteratorKind == kind) { - continue; - } - return false; - } - return true; - }); -} -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) { - return dim(AllDimsExcept({}), kind); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims, - utils::IteratorType kind) { - return addPredicate([dimensions = std::move(dims), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all dimensions except ["; - llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs()); - llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); - int64_t rank = linalgOp.getNumLoops(); - llvm::SmallDenseSet excludedDims; - for (int64_t dim : dimensions.getExcluded()) { - excludedDims.insert(dim >= 0 ? dim : rank + dim); - } - - for (auto [index, type] : - llvm::enumerate(linalgOp.getIteratorTypesArray())) { - if (excludedDims.contains(index)) { - continue; - } - if (type == kind) { - continue; - } - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, - DivisibleBy divisibleBy) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by " - << divisibleBy.value); - int64_t rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? dimension : rank + dimension; - if (transformedDimension >= rank) { - return false; - } - - int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension]; - return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0); - }); -} - -//===---------------------------------------------------------------------===// -// Capture directives. -//===---------------------------------------------------------------------===// -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(CaptureRank capture) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture rank"); - capture.value = linalgOp.getNumLoops(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture dimension"); - int64_t rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? dimension : rank + dimension; - if (transformedDimension >= rank) { - return false; - } - - capture.value = linalgOp.getStaticLoopRanges()[transformedDimension]; - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture all dimensions"); - captures.value = linalgOp.getStaticLoopRanges(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::indexingMaps( - CaptureIndexingMaps indexingMaps) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture indexing maps"); - indexingMaps.value = linalgOp.getIndexingMapsArray(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::contractionDims( - CaptureContractionDims contractionDims) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture contraction dimensions"); - StringRef convMessage = linalg::detail::getMatchContractionMessage( - mlir::linalg::detail::isContractionInterfaceImpl( - linalgOp, &contractionDims.value)); - if (convMessage.empty()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); - return false; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture convolution dimensions"); - StringRef convMessage = linalg::detail::getMatchConvolutionMessage( - mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp, - &convDims.value)); - if (convMessage.empty()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); - return false; - }); -} - -transform_ext::StructuredOpMatcher::StructuredOpMatcher( - StructuredOpMatcher &A, StructuredOpMatcher &B) { - - addPredicate([&A, &B](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); - { - auto debugRAII = llvm::scope_exit( - [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (A.match(linalgOp)) { - return true; - } - } - LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); - { - auto debugRAII = llvm::scope_exit( - [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (B.match(linalgOp)) { - return true; - } - } - return false; - }); - recordNestedMatcher(A); - recordNestedMatcher(B); -} - -//===---------------------------------------------------------------------===// -// Constraints on input operands. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addInputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addInputMatcher( - position, - // No need to handle optional inside the lambda, the wrapper will do that. - [matcher = std::move(matcher)](Value value) { - Operation *definingOp = value.getDefiningOp(); - return definingOp && matcher(definingOp); - }, - optional); -} - -void transform_ext::StructuredOpMatcher::addInputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addPredicate([position, optional, matcher = std::move(matcher)]( - linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInputs() + position; - if (transformedPosition >= linalgOp.getNumDpsInputs()) { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " does not exist but match required"); - return false; - } - - LLVM_DEBUG(DBGS() << "input operand #" << position - << (optional.value ? " (optional match) " : " ") - << "is\n"); - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (matcher(linalgOp.getDpsInputOperand(transformedPosition)->get())) { - return true; - } - return optional.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have permutation maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, - IsProjectedPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation()) { - return false; - } - } - return true; - }); -} - -/// Helper to check if the map is an identity map with a projected dim. -static bool isProjectedMap(AffineMap map, int64_t projectedDim) { - if (!map.isProjectedPermutation()) { - return false; - } - int64_t dimCounter = 0; - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - // Skip the project dim. - if (dimCounter == projectedDim) { - dimCounter++; - } - if (map.getDimPosition(i) != dimCounter++) { - return false; - } - } - return true; -} - -/// Helper to turn a potentially negative index to positive within the range -/// [0, ub) and indicate whether the transformed index is in bounds. -static bool makeValidPositiveIndex(int64_t &index, int64_t ub) { - int64_t positiveIndex = index >= 0 ? index : ub + index; - if (positiveIndex < 0 || ub < positiveIndex) { - LLVM_DEBUG(DBGSNL() << " index out of range"); - return false; - } - index = positiveIndex; - return true; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(SmallVector &&positions, - IsProjected dim) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "operands "; - llvm::interleaveComma(positions, llvm::dbgs()); - llvm::dbgs() << " have a permutation maps with " << dim.value - << " projected"); - int64_t updatedDim = dim.value; - if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { - return false; - } - for (int64_t position : positions) { - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, - linalgOp.getNumDpsInputs())) { - return false; - } - OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); - if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), - updatedDim)) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have identity maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(SmallVector &&positions, - IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operands "; - llvm::interleaveComma(positions, llvm::dbgs()); - llvm::dbgs() << " have identity maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (int64_t position : positions) { - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, - linalgOp.getNumDpsInputs())) { - return false; - } - OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); - if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - ElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " has elemental type with bit width " << width.value); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - CaptureElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { - return false; - } - width.value = shapedType.getElementType().getIntOrFloatBitWidth(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - CaptureElementType elem) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " capture element type"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - if (!shapedType) { - LLVM_DEBUG(DBGSNL() << " not a shaped type"); - return false; - } - elem.value = shapedType.getElementType(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(NumEqualsTo num) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "number of input operands == " << num.value); - return linalgOp.getNumDpsInputs() == num.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - ConstantFloatMinOrMinusInf) { - return input(position, [](llvm::APFloat f) { - return (f.isLargest() || f.isInfinity()) && f.isNegative(); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) { - return input(position, [](llvm::APFloat f) { return f.isZero(); }); -} - -transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input( - int64_t position, std::function floatValueFn) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " is a special floating point constant"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto cstOp = linalgOp.getDpsInputOperand(updatedPosition) - ->get() - .getDefiningOp(); - if (!cstOp) { - return false; - } - return floatValueFn(cstOp.value()); - }); -} - -//===---------------------------------------------------------------------===// -// Constraints on output operands. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addOutputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addPredicate([position, optional, matcher = std::move(matcher)]( - linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << (optional.value ? " (optional match) " - : " (mandatory match) ") - << "is produced by\n"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - Operation *definingOp = - linalgOp.getDpsInitOperand(updatedPosition)->get().getDefiningOp(); - if (!definingOp) { - return optional.value; - } - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (matcher(definingOp)) { - return true; - } - return optional.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, - IsProjectedPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isProjectedPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have a maps with projected"); - int64_t updatedDim = dim.value; - if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { - return false; - } - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!isProjectedMap(linalgOp.getMatchingIndexingMap(&operand), - updatedDim)) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - ElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " has elemental type with bit width " << width.value); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - CaptureElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { - LLVM_DEBUG(DBGSNL() << " could not infer element type"); - return false; - } - width.value = shapedType.getElementType().getIntOrFloatBitWidth(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - CaptureElementType elem) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " capture element type"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - if (!shapedType) { - LLVM_DEBUG(DBGSNL() << " not a shaped type"); - return false; - } - elem.value = shapedType.getElementType(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - SingleCombinerReduction tag) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " is populated by a single-combiner reduction"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - SmallVector combinerOps; - return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, - combinerOps) && - llvm::hasSingleElement(combinerOps); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(NumEqualsTo num) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "number of output operands == " << num.value); - return linalgOp.getNumDpsInits() == num.value; - }); -} - -//===---------------------------------------------------------------------===// -// Constraints on results. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addResultMatcher( - int64_t position, HasAnyUse tag, std::function matcher, - OptionalMatch optional) { - addPredicate([matcher = std::move(matcher), optional, - position](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "result #" << position - << (optional.value ? " (optional match) " - : " (mandatory match) ") - << "has a use\n"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp->getNumResults())) { - return false; - } - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (llvm::any_of(linalgOp->getResult(updatedPosition).getUsers(), - [&matcher](Operation *op) { return matcher(op); })) { - return true; - } - return optional.value; - }); -} - -//===-------------------------------------------------------------------===// -// Constraints on op region. -//===-------------------------------------------------------------------===// - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs( - StringRef opcode, bool commutative) { - return addPredicate([=](linalg::LinalgOp linalgOp) { - if (linalgOp.getBlock()->getOperations().size() != 2) { - return false; - } - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (innerOp->getName().getStringRef() != opcode || - innerOp->getNumResults() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if (yieldOp->getNumOperands() != 1) { - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) { - return false; - } - if (commutative && innerOp->getNumOperands() == 2) { - auto arg0 = dyn_cast(innerOp->getOperand(0)); - auto arg1 = dyn_cast(innerOp->getOperand(1)); - if (!arg0 || !arg1) { - return false; - } - if (arg0.getParentBlock() != linalgOp.getBlock() || - arg1.getParentBlock() != linalgOp.getBlock()) { - return false; - } - if (!((arg0.getArgNumber() == 0 && arg1.getArgNumber() == 1) || - (arg1.getArgNumber() == 0 && arg0.getArgNumber() == 1))) { - return false; - } - } else { - for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) { - auto arg = dyn_cast(operand); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != index) { - return false; - } - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::isFloatReciprocal() { - return addPredicate([=](linalg::LinalgOp linalgOp) { - LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation"); - if (linalgOp.getBlock()->getOperations().size() != 2) { - return false; - } - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (!isa(innerOp) || innerOp->getNumResults() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if (yieldOp->getNumOperands() != 1) { - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) { - return false; - } - auto cst = innerOp->getOperand(0).getDefiningOp(); - if (!cst || cst.value().convertToDouble() != 1.0) { - return false; - } - auto arg = dyn_cast(innerOp->getOperand(1)); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != 0) { - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::passThroughOp() { - return addPredicate([=](linalg::LinalgOp linalgOp) { - if (linalgOp.getBlock()->getOperations().size() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - for (auto [index, operand] : llvm::enumerate(yieldOp->getOperands())) { - auto arg = dyn_cast(operand); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != index) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::hasContractionBody( - function_ref isaElemOpTy, - function_ref isaReductionOpTy, StringRef elemOpName, - StringRef reductionOpName) { - return addPredicate([=](linalg::LinalgOp linalgOp) { - LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/" - << reductionOpName << " contraction ("); - auto scopeExitPrinter = - llvm::scope_exit([] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); }); - - Block *body = linalgOp.getBlock(); - if (!llvm::hasNItems(*body, 3)) { - LLVM_DEBUG(llvm::dbgs() << "three-operation body"); - return false; - } - if (body->getNumArguments() != 3) { - LLVM_DEBUG(llvm::dbgs() << "three-argument block"); - return false; - } - - Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin()); - Operation *reductionOp = elemOp->getNextNode(); - Operation *yieldOp = reductionOp->getNextNode(); - if (!isaElemOpTy(elemOp)) { - LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName); - return false; - } - if (!isaReductionOpTy(reductionOp)) { - LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName); - return false; - } - if (yieldOp->getNumOperands() != 1) { - LLVM_DEBUG(llvm::dbgs() << "one value yielded"); - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) { - LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op"); - return false; - } - if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) { - LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result"); - return false; - } - if (reductionOp->getNumOperands() != 2 || - reductionOp->getNumResults() != 1) { - LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result"); - return false; - } - - SmallVector expectedReductionOperands = {body->getArgument(2), - elemOp->getResult(0)}; - if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) && - !llvm::equal(llvm::reverse(expectedReductionOperands), - reductionOp->getOperands())) { - LLVM_DEBUG(llvm::dbgs() << "operands of the second op"); - return false; - } - - ValueRange expectedElemOperands = body->getArguments().take_front(2); - if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) && - !llvm::equal(llvm::reverse(expectedElemOperands), - elemOp->getOperands())) { - LLVM_DEBUG(llvm::dbgs() << "operands of the first op"); - return false; - } - - scopeExitPrinter.release(); - LLVM_DEBUG(llvm::dbgs() << "success)"); - return true; - }); -} - -void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor( - StringRef name) { - LLVM_DEBUG(DBGS() << "op is a " << name << "'"); -} - -//===---------------------------------------------------------------------===// -// TensorPadOpMatcher -//===---------------------------------------------------------------------===// - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::low(ArrayRef sizes) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG({ - DBGS() << "low pad sizes are "; - llvm::interleaveComma(sizes, llvm::dbgs()); - }); - for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedLowPad(), sizes)) { - if (isConstantIntValue(ofr, sz)) { - return false; - } - } - return true; - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "all low pad sizes are " << size); - return llvm::all_of(tensorPad.getMixedLowPad(), [&](OpFoldResult ofr) { - return isConstantIntValue(ofr, size); - }); - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::high(ArrayRef sizes) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG({ - DBGS() << "high pad sizes are "; - llvm::interleaveComma(sizes, llvm::dbgs()); - }); - for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedHighPad(), sizes)) { - if (isConstantIntValue(ofr, sz)) { - return false; - } - } - return true; - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "all high pad sizes are " << size); - return llvm::all_of(tensorPad.getMixedHighPad(), [&](OpFoldResult ofr) { - return isConstantIntValue(ofr, size); - }); - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::yieldsExternalValue() { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value"); - Block *body = tensorPad.getBody(); - if (!llvm::hasSingleElement(*body)) { - return false; - } - return llvm::all_of(body->getTerminator()->getOperands(), - [body](Value operand) { - auto arg = dyn_cast(operand); - return !arg || arg.getOwner() != body; - }); - }); -} - -//===---------------------------------------------------------------------===// -// MatchCallbackResult. -//===---------------------------------------------------------------------===// - -ArrayRef -transform_ext::MatchCallbackResult::getPayloadGroup(int64_t position) const { - assert(position < payloadGroupLengths.size()); - int64_t start = 0; - for (int64_t i = 0; i < position; ++i) { - start += payloadGroupLengths[i]; - } - return llvm::ArrayRef(payloadOperations) - .slice(start, payloadGroupLengths[position]); -} - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. -//===---------------------------------------------------------------------===// - -static constexpr int64_t kCudaWarpSize = 32; - -void transform_ext::makeReductionMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&reductionCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&leadingCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - MatchedReductionCaptures &captures, bool mustMatchEntireFunc) { - // The core part of the matcher is anchored on a particular reduction op. - auto &reduction = - m_StructuredOp(matcherContext) - // Op has at least a parallel a reduction dimension and at - // most 3 parallel dimensions. - // TODO: relax once we have global collapse/expand_shape. - // - .rank(NumGreaterEqualTo(2)) - .rank(NumLowerEqualTo(4)) - .rank(CaptureRank(captures.reductionRank)) - // Op has a single most-minor reduction. - .dim(-1, utils::IteratorType::reduction) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.reductionOpSizes)) - // All other dimensions are parallel. - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - // Single input for now, can be arbitrary projected permutations. - // TODO: Multiple inputs, can be arbitrary projected permutations. - // TODO: Watch out for multiple inputs though as a reduction turns - // into a contraction when mixed with projected - // permutations. A reduction is often bandwidth bound but - // contraction is a different beast that is compute bound - // and has a very different schedule. - // - .input(NumEqualsTo(1)) - .input(AllOperands(), IsProjectedPermutation()) - // Single output supported atm. - // TODO: Multiple outputs. - // - .output(NumEqualsTo(1)) - // A reduction output must be a projected permutation, match it but we - // could also drop this technically. - .output(AllOperands(), IsProjectedPermutation()) - // Only single combiner for now due to reduction warp - // distribution. - // TODO: relax this once reduction distribution is more powerful. - // - .output(0, CaptureElementTypeBitWidth( - captures.reductionOutputElementalTypeBitWidth)) - .output(0, SingleCombinerReduction()); - reductionCapture = &reduction; - - // Mandatory FillOp must create the unique output of the reduction. - // TODO: Relax this, as any map, broadcast, transpose should also work. - // - auto &fill = m_StructuredOp(matcherContext); - reduction = reduction.output(NumEqualsTo(1)).output(0, fill); - fillCapture = &fill; - - // Optional leading or trailing op can be any map, transpose, broadcast but - // not reduce or windowing operation for now. - // It must create the unique input for the reduction. - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - transform_ext::StructuredOpMatcher commonLeadingOrTrailing = - m_StructuredOp(matcherContext) - // All parallel dimensions. - .dim(AllDims(), utils::IteratorType::parallel) - // All inputs are any projected permutation. - .input(AllOperands(), IsProjectedPermutation()) - .output(AllOperands(), IsPermutation()) - // leading and trailing may have 0, 1 or more input as long as they do - // not come from unmatched ops. This extra constraint is taken care of - // separately. This is also a noop but we document it. - // TODO: Base and derived classes, atm this does not compile. - // .input(NumGreaterEqualTo(0)) - // Single output supported atm. - // TODO: extend this. - // - .output(NumEqualsTo(1)); - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - auto &leading = - m_StructuredOp(matcherContext, commonLeadingOrTrailing) - .rank(CaptureRank(captures.maybeLeadingRank)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.leadingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeLeadingOutputElementalTypeBitWidth)); - reduction = reduction.input(0, leading, OptionalMatch()); - leadingCapture = &leading; - - // Optional trailing can be any map, transpose, broadcast but not reduce or - // windowing operation for now. - // It must be fed by the unique input for the reduction. - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - auto &trailing = - m_StructuredOp(matcherContext, commonLeadingOrTrailing) - .rank(CaptureRank(captures.maybeTrailingRank)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeTrailingOutputElementalTypeBitWidth)); - reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - reduction = reduction.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc) { - StructuredOpMatcher *fill; - StructuredOpMatcher *leading; - StructuredOpMatcher *trailing; - makeReductionMatcher(context, reductionCapture, fill, leading, trailing, - captures, mustMatchEntireFunc); -} - -void transform_ext::makeMatmulMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&matmulCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { - auto &matmul = transform_ext::m_StructuredOp(matcherContext) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) - // Capture input/output element types. - .input(0, CaptureElementType(captures.lhsElementType)) - .input(1, CaptureElementType(captures.rhsElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - matmulCapture = &matmul; - // Mandatory FillOp must create the unique output of the reduction. - auto &fill = transform_ext::m_StructuredOp(matcherContext); - matmul = matmul.output(transform_ext::NumEqualsTo(1)).output(0, fill); - fillCapture = &fill; - - auto &trailing = m_StructuredOp(matcherContext); - matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - matmul = matmul.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeBatchMatmulMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&bmmCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { - auto &bmm = - transform_ext::m_StructuredOp( - matcherContext) - .hasContractionBody() - .rank(NumEqualsTo(4)) - .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .dim(-1, utils::IteratorType::reduction) - .contractionDims(CaptureContractionDims(captures.contractionDims)) - .input(NumEqualsTo(2)) - .input(0, CaptureElementType(captures.lhsElementType)) - .input(1, CaptureElementType(captures.rhsElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - bmmCapture = &bmm; - - auto &fill = transform_ext::m_StructuredOp(matcherContext); - bmm = bmm.output(0, fill); - fillCapture = &fill; - - if (mustMatchEntireFunc) { - bmm = bmm.allTilableOpsCaptured(); - } -} - -/// Match sum(%src, broadcast(%reduction)) -static void -matchSubBroadcast(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher &maxReduction, - transform_ext::CapturingValueMatcher &softmaxSourceOperand, - transform_ext::StructuredOpMatcher *&sub) { - using namespace transform_ext; - - auto &broadcast = - transform_ext::m_StructuredOp(matcherContext) - .passThroughOp() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(0, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - broadcast = broadcast.input(0, maxReduction); - - auto &subParallel = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - subParallel = subParallel.input(0, softmaxSourceOperand); - subParallel = subParallel.input(1, broadcast); - - auto &subBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - subBroadcast = subBroadcast.input(0, softmaxSourceOperand); - subBroadcast = subBroadcast.input(1, maxReduction); - auto &subOr = transform_ext::m_StructuredOp_Or(matcherContext, subBroadcast, - subParallel); - sub = &subOr; -} - -/// Match sum(%exp, broadcast(%sum)) -static void matchdivBroadcast(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher &expOperand, - transform_ext::StructuredOpMatcher &sum, - transform_ext::StructuredOpMatcher *&div) { - using namespace transform_ext; - - auto &broadcast = - transform_ext::m_StructuredOp(matcherContext) - .passThroughOp() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(0, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - broadcast = broadcast.input(0, sum); - - auto &divNoBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - divNoBroadcast = divNoBroadcast.input(0, expOperand); - divNoBroadcast = divNoBroadcast.input(1, broadcast); - - auto &divBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - divBroadcast = divBroadcast.input(0, expOperand); - divBroadcast = divBroadcast.input(1, sum); - - auto &divMerge = transform_ext::m_StructuredOp_Or( - matcherContext, divNoBroadcast, divBroadcast); - div = &divMerge; -} - -void transform_ext::makeSoftmaxMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&maxReductionCapture, - transform_ext::StructuredOpMatcher *&softmaxRootCapture) { - auto &softmaxSourceOperand = m_Value(matcherContext); - - auto &fillMinusInf = m_StructuredOp(matcherContext) - .input(0, ConstantFloatMinOrMinusInf()); - auto &maxReduction = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - // Only handle most inner reduction for now. - .dim(-1, utils::IteratorType::reduction) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsProjected(-1)); - maxReduction = maxReduction.input(0, softmaxSourceOperand); - maxReduction = maxReduction.output(0, fillMinusInf); - maxReductionCapture = &maxReduction; - - transform_ext::StructuredOpMatcher *subOperand; - matchSubBroadcast(matcherContext, maxReduction, softmaxSourceOperand, - subOperand); - - auto &expOperand = m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)); - expOperand = expOperand.input(0, *subOperand); - - auto &fillZero = m_StructuredOp(matcherContext) - .input(0, ConstantFloatZero()); - auto &sum = - m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - // Only handle most inner reduction for now. - .dim(-1, utils::IteratorType::reduction) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsProjected(-1)) - .output(NumEqualsTo(1)); - sum = sum.input(0, expOperand); - sum = sum.output(0, fillZero); - - auto &rcpOperand = m_StructuredOp(matcherContext) - .isFloatReciprocal() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)); - rcpOperand = rcpOperand.input(0, sum); - - auto &mulOperand = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - mulOperand = mulOperand.input(0, expOperand); - mulOperand = mulOperand.input(1, rcpOperand); - - transform_ext::StructuredOpMatcher *divOperand; - matchdivBroadcast(matcherContext, expOperand, sum, divOperand); - - auto &softmaxRoot = - transform_ext::m_StructuredOp_Or(matcherContext, mulOperand, *divOperand); - softmaxRootCapture = &softmaxRoot; -} - -/// Matcher for convolutions. -void transform_ext::makeConvolutionMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&convolutionCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { - // The core part of the matcher is anchored on a particular convolution op. - auto &convolution = - m_StructuredOp( - matcherContext) - // Capture convolution dim classifications. - .convolutionDims(CaptureConvDims(captures.convolutionDims)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.convolutionOpSizes)) - // Capture convolution element types. - .input(0, CaptureElementType(captures.inputElementType)) - .input(1, CaptureElementType(captures.filterElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - convolutionCapture = &convolution; - - // Optional FillOp to create the unique output of the convolution. - auto &fill = m_StructuredOp(matcherContext) - .output(0, CaptureElementTypeBitWidth( - captures.maybeFillElementalTypeBitWidth)); - convolution = - convolution.output(NumEqualsTo(1)).output(0, fill, OptionalMatch()); - fillCapture = &fill; - - // Optional trailing op can be any map, transpose, broadcast but - // not reduce or windowing operation for now. - // It must create the unique input for the reduction. - auto &trailing = - m_StructuredOp(matcherContext) - // All parallel dimensions. - .dim(AllDims(), utils::IteratorType::parallel) - // All inputs are any projected permutation. - .input(AllOperands(), IsProjectedPermutation()) - .output(AllOperands(), IsPermutation()) - .output(NumEqualsTo(1)) - .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeTrailingOutputElementalTypeBitWidth)); - - // Optional trailing can be any map, transpose, broadcast but not reduce or - // windowing operation for now. - convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - convolution = - convolution.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeConvolutionMatcher( - transform_ext::MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { - StructuredOpMatcher *fill; - StructuredOpMatcher *trailing; - makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures, - mustMatchEntireFunc); -} - -void transform_ext::makePadMatcher(MatcherContext &context, - CapturingOpMatcher *&padCapture, - MatchedPadCaptures &captures, - bool mustMatchEntireFunc) { - auto &value = transform_ext::m_ShapedValue(context); - value.rank(transform_ext::CaptureRank(captures.rank)) - .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims)) - .elementType(CaptureElementType(captures.elementType)); - auto &opMatcher = transform_ext::m_tensorPad(context) - .result(0, value) - .low(AllDims(), 0) - .yieldsExternalValue(); - if (mustMatchEntireFunc) { - opMatcher = opMatcher.allTilableOpsCaptured(); - } - padCapture = &opMatcher; -} diff --git a/llvm-external-projects/iree-dialects/python/CMakeLists.txt b/llvm-external-projects/iree-dialects/python/CMakeLists.txt index 0fc3506f5420..488e5129834a 100644 --- a/llvm-external-projects/iree-dialects/python/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/python/CMakeLists.txt @@ -15,16 +15,6 @@ declare_mlir_python_sources(IREEDialectsPythonSources.Dialects ADD_TO_PARENT IREEDialectsPythonSources ) -declare_mlir_dialect_extension_python_bindings( - ADD_TO_PARENT IREEDialectsPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler" - TD_FILE dialects/IreeStructuredTransformOps.td - SOURCES - dialects/transform/iree_structured.py - dialects/_iree_structured_transform_ops_ext.py - DIALECT_NAME transform - EXTENSION_NAME iree_structured_transform) - ################################################################################ # Extensions ################################################################################ @@ -36,7 +26,7 @@ declare_mlir_python_extension(IREEDialectsPythonExtensions.Main SOURCES IREEDialectsModule.cpp EMBED_CAPI_LINK_LIBS - IREEDialectsCAPI + MLIRCAPITransformDialect PRIVATE_LINK_LIBS LLVMSupport ) diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp index d6bab61dfe9f..10a50f8e43dc 100644 --- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp +++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp @@ -4,10 +4,10 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects-c/Dialects.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" +#include "mlir-c/Dialect/Transform.h" #include "mlir-c/RegisterEverything.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" @@ -30,7 +30,6 @@ NB_MODULE(_ireeDialects, m) { [](MlirContext context, bool load) { MlirDialectHandle handle = mlirGetDialectHandle__transform__(); mlirDialectHandleRegisterDialect(handle, context); - ireeRegisterTransformExtensions(context); if (load) { mlirDialectHandleLoadDialect(handle, context); } diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td deleted file mode 100644 index 0a10cd36d778..000000000000 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING -#define PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING - -include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td" - -#endif // PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py deleted file mode 100644 index f1f89550bd42..000000000000 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Disable PyType, it does not seem to like the specialization pattern used in -# MLIR. -# pytype: skip-file -try: - from ..ir import * - from ..dialects import pdl - from ._ods_common import ( - extend_opview_class as _ods_extend_opview_class, - segmented_accessor as _ods_segmented_accessor, - equally_sized_accessor as _ods_equally_sized_accessor, - get_default_loc_context as _ods_get_default_loc_context, - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) - from typing import Optional, overload, Sequence, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e -BoolArg = Optional[Union[bool, BoolAttr]] -IntListArg = Optional[Union[Sequence[int], ArrayAttr]] -StringArg = Optional[Union[str, StringAttr]] - - -def _defaulted_ensure(f): - def inner(value, default=None): - assert value is not None or default is not None - return f(default if value is None else value) - - return inner - - -@_defaulted_ensure -def _ensure_int_array_attr(value: IntListArg): - i64 = IntegerType.get_signless(64) - if isinstance(value, Sequence): - return ArrayAttr.get([IntegerAttr.get(i64, i) for i in value]) - return value - - -@_defaulted_ensure -def _ensure_bool_attr(value: BoolArg): - if isinstance(value, bool): - return BoolAttr.get(value) - return value - - -@_defaulted_ensure -def _ensure_string_attr(value: StringArg): - if isinstance(value, str): - return StringAttr.get(value) - return value - - -class LowerToLLVMOp: - """Specialization for the LowerToLLVMOp class.""" - - def __init__( - self, - *, - reassociate_fp_reductions: BoolArg = None, - enable_index_optimizations: BoolArg = None, - enable_arm_neon: BoolArg = None, - enable_arm_sve: BoolArg = None, - enable_amx: BoolArg = None, - enable_x86vector: BoolArg = None, - enable_async: BoolArg = None, - loc=None, - ip=None - ): - super().__init__( - _ensure_bool_attr(reassociate_fp_reductions, False), - _ensure_bool_attr(enable_index_optimizations, False), - _ensure_bool_attr(enable_arm_neon, False), - _ensure_bool_attr(enable_arm_sve, False), - _ensure_bool_attr(enable_amx, False), - _ensure_bool_attr(enable_x86vector, False), - _ensure_bool_attr(enable_async, False), - loc=loc, - ip=ip, - ) diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py deleted file mode 100644 index 563e20bc0b04..000000000000 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .._iree_structured_transform_ops_gen import * diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt index 669a616d9e89..7d699b5d14eb 100644 --- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt @@ -1,6 +1,4 @@ set(LIBS - # Local dialects. - IREELinalgTransformDialect # Core dialects. MLIRAffineDialect MLIRArithDialect diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp index b4a100caa1b2..2824d307fa64 100644 --- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp +++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -71,7 +70,6 @@ int main(int argc, char **argv) { mlir::LLVM::registerInlinerInterface(registry); mlir::linalg::registerTilingInterfaceExternalModels(registry); - registry.addExtensions(); mlir::bufferization::registerTransformDialectExtension(registry); mlir::linalg::registerTransformDialectExtension(registry); mlir::scf::registerTransformDialectExtension(registry);