From 748ebeaa511375a8df6a7574a12c51e7688cec0f Mon Sep 17 00:00:00 2001 From: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> Date: Wed, 27 May 2026 12:43:43 -0400 Subject: [PATCH 1/8] [GlobalOpt] Move softmax matcher out of iree-dialects (#24466) Relocate the StructuredOpMatcher infrastructure and makeSoftmaxMatcher from llvm-external-projects/iree-dialects into the GlobalOptimization directory, which holds the only production user (RaiseSpecialOps). The code is moved verbatim (NFC); RaiseSpecialOps is repointed to the local header and the build deps are switched to the MLIR libraries the matcher needs. The now-unused matchers (reduction/matmul/convolution/pad) come along verbatim and can be trimmed in a follow-up once a build is available to validate. Progress toward retiring the iree-dialects dependency. Signed-off-by: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> --- .../compiler/GlobalOptimization/BUILD.bazel | 10 +- .../GlobalOptimization/CMakeLists.txt | 10 +- .../GlobalOptimization/RaiseSpecialOps.cpp | 2 +- .../GlobalOptimization/TransformMatchers.cpp | 1845 +++++++++++++++++ .../GlobalOptimization/TransformMatchers.h | 1201 +++++++++++ 5 files changed, 3063 insertions(+), 5 deletions(-) create mode 100644 compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp create mode 100644 compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index 37f58d4280a1..88fdc328a7c3 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -71,11 +71,13 @@ iree_compiler_cc_library( "RaiseSpecialOps.cpp", "RemoveZeroExtentTensors.cpp", "SimplifyPackUnpack.cpp", + "TransformMatchers.cpp", "Utils.cpp", "WarnOnUninitializedValues.cpp", ], hdrs = [ "Passes.h", + "TransformMatchers.h", "Utils.h", ], deps = [ @@ -103,29 +105,33 
@@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Modules/IO/Parameters/Transforms", "//compiler/src/iree/compiler/Pipelines:Options", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgInterfaces", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgUtils", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:TransformDialectInterfaces", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 0b5cfb188450..7d98a7ade9bf 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -38,6 +38,7 @@ iree_cc_library( GlobalOptimization HDRS "Passes.h" + "TransformMatchers.h" "Utils.h" SRCS "CleanupNumericNarrowing.cpp" @@ -62,33 +63,38 @@ iree_cc_library( "RaiseSpecialOps.cpp" "RemoveZeroExtentTensors.cpp" "SimplifyPackUnpack.cpp" + "TransformMatchers.cpp" "Utils.cpp" 
"WarnOnUninitializedValues.cpp" DEPS ::PassHeaders ::PassesIncGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAffineDialect + MLIRAnalysis MLIRArithDialect MLIRArithUtils MLIRControlFlowDialect + MLIRFuncDialect MLIRFunctionInterfaces MLIRIR MLIRLinalgDialect + MLIRLinalgInterfacesIncGenLib MLIRLinalgTransforms MLIRLinalgUtils MLIRMathDialect MLIRMemRefDialect MLIRMemRefTransforms MLIRPass + MLIRRewrite MLIRSCFDialect MLIRSCFTransforms MLIRSupport MLIRTensorDialect MLIRTensorTransforms MLIRTensorUtils + MLIRTransformDialect + MLIRTransformDialectInterfaces MLIRTransformUtils MLIRTransforms iree::compiler::Codegen::Common diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index b4e177eddfe4..658b1e9ade29 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -4,13 +4,13 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects/Transforms/TransformMatchers.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/GlobalOptimization/Passes.h" +#include "iree/compiler/GlobalOptimization/TransformMatchers.h" #include "iree/compiler/GlobalOptimization/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" diff --git a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp new file mode 100644 index 000000000000..9a4160fc3588 --- /dev/null +++ b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp @@ -0,0 +1,1845 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/GlobalOptimization/TransformMatchers.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +#define DEBUG_TYPE "transform-matchers" +#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " +#define DBGSNL() llvm::dbgs() << "\n[" DEBUG_TYPE "] " + +//===---------------------------------------------------------------------===// +// CapturingMatcherBase +//===---------------------------------------------------------------------===// + +void transform_ext::CapturingMatcherBase::getAllNested( + SmallVectorImpl &nested) { + + SetVector found; + found.insert(nested.begin(), nested.end()); + int64_t start = found.size(); + + auto appendOne = [&found](CapturingMatcherBase &one) { + found.insert(one.nestedCapturingMatchers.begin(), + one.nestedCapturingMatchers.end()); + for (CapturingValueMatcher *valueMatcher : + one.nestedCapturingValueMatchers) { + found.insert(valueMatcher->nestedCapturingMatchers.begin(), + valueMatcher->nestedCapturingMatchers.end()); + } + }; + + appendOne(*this); + for (int64_t position = start; position < found.size(); ++position) { + appendOne(*found[position]); + } + + llvm::append_range(nested, found.getArrayRef()); +} + +void transform_ext::CapturingMatcherBase::getAllNestedValueMatchers( + SmallVectorImpl &nested) { + + SetVector found; + found.insert(nested.begin(), nested.end()); + int64_t start = found.size(); + + auto appendOne = [&found](CapturingMatcherBase &one) { + found.insert(one.nestedCapturingValueMatchers.begin(), + one.nestedCapturingValueMatchers.end()); + for (CapturingOpMatcher *opMatcher : 
one.nestedCapturingMatchers) { + found.insert(opMatcher->nestedCapturingValueMatchers.begin(), + opMatcher->nestedCapturingValueMatchers.end()); + } + }; + + appendOne(*this); + for (int64_t position = start; position < found.size(); ++position) { + appendOne(*found[position]); + } + + llvm::append_range(nested, found.getArrayRef()); +} + +void transform_ext::CapturingMatcherBase::resetCapture() { + SmallVector nested; + getAllNested(nested); + for (CapturingOpMatcher *matcher : nested) { + matcher->captured = nullptr; + } + SmallVector nestedValue; + getAllNestedValueMatchers(nestedValue); + for (CapturingValueMatcher *matcher : nestedValue) { + matcher->captured = nullptr; + } +} + +//===---------------------------------------------------------------------===// +// CapturingOpMatcher +//===---------------------------------------------------------------------===// + +bool transform_ext::CapturingOpMatcher::checkAllTilableMatched( + Operation *parent, Operation *op, + ArrayRef matchers) { + LLVM_DEBUG(DBGS() << "all tilable ops captured"); + int64_t numTilableOps = 0; + if (!parent) { + return false; + } + parent->walk([&](TilingInterface Op) { ++numTilableOps; }); + + llvm::SmallPtrSet matched; + for (CapturingOpMatcher *nested : matchers) { + if (Operation *captured = nested->getCaptured()) { + matched.insert(captured); + } + } + + // Don't forget to include the root matcher. 
+ matched.insert(op); + return numTilableOps == matched.size(); +} + +bool transform_ext::CapturingOpMatcher::match(Operation *op) { + auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); + LLVM_DEBUG(DBGS() << "matching: " << *op << "\n"); + + if (getCaptured()) { + LLVM_DEBUG(DBGS() << "found an already captured op: "); + if (getCaptured() == op) { + LLVM_DEBUG(llvm::dbgs() << "same\n"); + return true; + } else { + LLVM_DEBUG(llvm::dbgs() << "different\n"); + return false; + } + } + + if (!llvm::all_of(predicates, [op](const PredicateFn &fn) { + bool result = fn(op); + LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); + return result; + })) { + return false; + } + + captured = op; + return true; +} + +void transform_ext::CapturingOpMatcher::debugOutputForCreate( + ArrayRef opNames) { + LLVM_DEBUG(DBGS() << "operation type is one of {"; + llvm::interleaveComma(opNames, llvm::dbgs()); llvm::dbgs() << "}"); +} + +/// Apply the given matcher to the given object, produce debug messages. +template ::template args<0>> +static bool recursiveMatch(Matcher &matcher, Object &object, + StringRef extraMessage = "") { + LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "] " << "start recursive match (" + << extraMessage << ") {\n"); + bool result = matcher.match(object); + LLVM_DEBUG(DBGS() << "} end recursive match"); + return result; +} + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::alternatives( + transform_ext::CapturingOpMatcher &first, + transform_ext::CapturingOpMatcher &second) { + addPredicate([&first, &second](Operation *op) { + LLVM_DEBUG(DBGS() << "matching alternatives\n"); + return recursiveMatch(first, op, "alternative 1") || + recursiveMatch(second, op, "alternative 2"); + }); + return *this; +} + +//---------------------------------------------------------------------------// +// Predicates for operands and results. 
+//---------------------------------------------------------------------------// + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::operand(transform_ext::NumEqualsTo num) { + addPredicate([=](Operation *op) { + LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " operands"); + return num.value == op->getNumOperands(); + }); + return *this; +} + +/// If `pos` is negative, returns the number of the operand in op starting from +/// the last. For example, -1 means the last operand, -2 means the +/// second-to-last, etc. Returns nullopt if pos is out-of-bounds, both positive +/// and negative. +static std::optional remapNegativeOperandNumber(int64_t pos, + Operation *op) { + int64_t updated = pos < 0 ? op->getNumOperands() + pos : pos; + if (updated < 0 || updated >= op->getNumOperands()) { + LLVM_DEBUG(DBGS() << "match operand #" << pos + << "that does not exist in the operation"); + return std::nullopt; + } + return updated; +} + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::operand(int64_t pos, + CapturingOpMatcher &nested) { + addPredicate([pos, &nested](Operation *op) { + std::optional operandNo = remapNegativeOperandNumber(pos, op); + if (!operandNo) { + return false; + } + LLVM_DEBUG(DBGS() << "operand #" << pos << " is defined by an operation"); + Operation *definingOp = op->getOperand(*operandNo).getDefiningOp(); + if (!definingOp) { + return false; + } + return recursiveMatch(nested, definingOp); + }); + recordNestedMatcher(nested); + return *this; +} + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::operand(int64_t pos, + CapturingValueMatcher &nested) { + addPredicate([pos, &nested](Operation *op) { + std::optional operandNo = remapNegativeOperandNumber(pos, op); + if (!operandNo) { + return false; + } + LLVM_DEBUG(DBGS() << "operand #" << pos << " is"); + Value operand = op->getOperand(*operandNo); + return recursiveMatch(nested, operand); + }); + 
recordNestedMatcher(nested); + return *this; +} + +transform_ext::CapturingOpMatcher &transform_ext::CapturingOpMatcher::operand( + int64_t position, std::function floatValueFn) { + addPredicate([position, + floatValueFn = std::move(floatValueFn)](Operation *op) -> bool { + std::optional operandNo = remapNegativeOperandNumber(position, op); + if (!operandNo) { + return false; + } + + LLVM_DEBUG(DBGS() << "operand #" << *operandNo + << " is a special floating point constant"); + auto cstOp = + op->getOperand(*operandNo).getDefiningOp(); + if (!cstOp) { + return false; + } + return floatValueFn(cstOp.value()); + }); + + return *this; +} + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::operand(int64_t position, ConstantFloatOne) { + return operand(position, + [](llvm::APFloat value) { return value.isExactlyValue(1.0); }); +} + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::result(transform_ext::NumEqualsTo num) { + addPredicate([=](Operation *op) { + LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " results"); + return num.value == op->getNumResults(); + }); + return *this; +} + +transform_ext::CapturingOpMatcher & +transform_ext::CapturingOpMatcher::result(int64_t pos, + CapturingValueMatcher &nested) { + addPredicate([pos, &nested](Operation *op) { + int64_t updated = pos < 0 ? 
op->getNumResults() + pos : pos; + if (updated < 0 || updated >= op->getNumResults()) { + LLVM_DEBUG(DBGS() << "matching result #" << pos + << " that does not exist in the operation"); + return false; + } + LLVM_DEBUG(DBGS() << "result #" << pos << " is"); + Value result = op->getResult(updated); + return recursiveMatch(nested, result); + }); + recordNestedMatcher(nested); + return *this; +} + +//===---------------------------------------------------------------------===// +// CapturingValueMatcher +//===---------------------------------------------------------------------===// + +namespace { +struct DebugPrintValueWrapper { + Value value; +}; + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const DebugPrintValueWrapper &wrapper) { + if (auto opResult = dyn_cast(wrapper.value)) { + return os << "op result #" << opResult.getResultNumber() << " in " + << wrapper.value; + } + + auto blockArg = cast(wrapper.value); + os << "block argument #" << blockArg.getArgNumber(); + Block *parentBlock = blockArg.getParentBlock(); + Region *parentRegion = parentBlock->getParent(); + if (!parentRegion) { + os << " of a detached block:\n"; + parentBlock->print(os); + return os; + } + + os << " of block #" + << std::distance(parentRegion->begin(), parentBlock->getIterator()); + Operation *parentOp = parentRegion->getParentOp(); + if (!parentOp) { + os << " of a detached region:\n"; + for (Block &b : *parentRegion) { + b.print(os); + } + return os; + } + + os << " in region #" << parentRegion->getRegionNumber() << " of " + << *parentOp; + return os; +} +} // namespace + +bool transform_ext::CapturingValueMatcher::match(Value value) { + auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); + LLVM_DEBUG(DBGS() << "matching " << DebugPrintValueWrapper{value} << "\n"); + + if (getCaptured()) { + LLVM_DEBUG(DBGS() << "found an already captured value: "); + if (getCaptured() == value) { + LLVM_DEBUG(llvm::dbgs() << "same\n"); + return true; + } else { + 
LLVM_DEBUG(llvm::dbgs() << "different\n"); + return false; + } + } + + for (const PredicateFn &fn : predicates) { + bool result = fn(value); + LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); + if (!result) { + return false; + } + } + + captured = value; + return true; +} + +transform_ext::ShapedValueMatcher::ShapedValueMatcher() + : CapturingValueMatcher() { + addPredicate([](Value value) { + LLVM_DEBUG(DBGS() << "value is of shaped type"); + return value && isa(value.getType()); + }); +} + +transform_ext::ShapedValueMatcher & +transform_ext::ShapedValueMatcher::rank(transform_ext::CaptureRank capture) { + addPredicate([=](Value value) { + LLVM_DEBUG(DBGS() << "capturing shaped value rank"); + capture.value = cast(value.getType()).getRank(); + return true; + }); + return *this; +} + +transform_ext::ShapedValueMatcher & +transform_ext::ShapedValueMatcher::dim(int64_t dimension, CaptureDim capture) { + addPredicate([=](Value value) { + LLVM_DEBUG(DBGS() << "capturing shaped value dimension " << dimension); + capture.value = cast(value.getType()).getDimSize(dimension); + return true; + }); + return *this; +} + +transform_ext::ShapedValueMatcher & +transform_ext::ShapedValueMatcher::dim(AllDims tag, CaptureDims captures) { + (void)tag; + addPredicate([=](Value value) { + LLVM_DEBUG(DBGS() << "capturing all shaped value dimensions"); + ArrayRef shape = cast(value.getType()).getShape(); + captures.value.assign(shape.begin(), shape.end()); + return true; + }); + return *this; +} + +transform_ext::ShapedValueMatcher & +transform_ext::ShapedValueMatcher::elementType(CaptureElementType captures) { + addPredicate([=](Value value) { + LLVM_DEBUG(DBGS() << "capturing elementType"); + captures.value = cast(value.getType()).getElementType(); + return true; + }); + return *this; +} + +//===---------------------------------------------------------------------===// +// Constraints on op rank and dims. 
+//===---------------------------------------------------------------------===// + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "rank >= " << minRank.value); + return linalgOp.getNumLoops() >= minRank.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value); + return linalgOp.getNumLoops() <= maxRank.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "rank == " << exactRank.value); + return linalgOp.getNumLoops() == exactRank.value; + }); +} + +StringRef stringifyShapeKind(transform_ext::ShapeKind kind) { + switch (kind) { + case transform_ext::ShapeKind::Static: + return "static"; + case transform_ext::ShapeKind::Dynamic: + return "dynamic"; + } + llvm_unreachable("unhandled shape kind"); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, + ShapeKind kind) { + return addPredicate([dimensions = std::move(dimensions), + kind](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "dimensions ["; + llvm::interleaveComma(dimensions, llvm::dbgs()); + llvm::dbgs() << "] are " << stringifyShapeKind(kind)); + SmallVector shape = linalgOp.getStaticLoopRanges(); + for (auto dimension : dimensions) { + int64_t transformedDimension = + dimension >= 0 ? 
dimension : shape.size() + dimension; + if (transformedDimension < 0 || transformedDimension >= shape.size()) { + return false; + } + if (ShapedType::isDynamic(shape[transformedDimension]) ^ + (kind == ShapeKind::Static)) { + continue; + } + return false; + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind)); + SmallVector shape = linalgOp.getStaticLoopRanges(); + return llvm::all_of(shape, [=](int64_t dimension) { + return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); + }); + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, + utils::IteratorType kind) { + return addPredicate([dimensions = std::move(dimensions), + kind](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "dimensions ["; + llvm::interleaveComma(dimensions, llvm::dbgs()); + llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); + int64_t rank = linalgOp.getNumLoops(); + for (auto dimension : dimensions) { + int64_t transformedDimension = + dimension >= 0 ? 
dimension : rank + dimension; + if (transformedDimension < 0 || transformedDimension >= rank) { + return false; + } + utils::IteratorType iteratorKind = + linalgOp.getIteratorTypesArray()[transformedDimension]; + if (iteratorKind == kind) { + continue; + } + return false; + } + return true; + }); +} +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) { + return dim(AllDimsExcept({}), kind); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims, + utils::IteratorType kind) { + return addPredicate([dimensions = std::move(dims), + kind](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all dimensions except ["; + llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs()); + llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); + int64_t rank = linalgOp.getNumLoops(); + llvm::SmallDenseSet excludedDims; + for (int64_t dim : dimensions.getExcluded()) { + excludedDims.insert(dim >= 0 ? dim : rank + dim); + } + + for (auto [index, type] : + llvm::enumerate(linalgOp.getIteratorTypesArray())) { + if (excludedDims.contains(index)) { + continue; + } + if (type == kind) { + continue; + } + return false; + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::dim(int64_t dimension, + DivisibleBy divisibleBy) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by " + << divisibleBy.value); + int64_t rank = linalgOp.getNumLoops(); + int64_t transformedDimension = + dimension >= 0 ? 
dimension : rank + dimension;
+    if (transformedDimension < 0 || transformedDimension >= rank) {
+      return false; // out of range on either side, e.g. -(rank + 1)
+    }
+
+    int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension];
+    return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0);
+  });
+}
+
+//===---------------------------------------------------------------------===//
+// Capture directives.
+//===---------------------------------------------------------------------===//
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::rank(CaptureRank capture) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture rank");
+    capture.value = linalgOp.getNumLoops();
+    return true;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture dimension");
+    int64_t rank = linalgOp.getNumLoops();
+    int64_t transformedDimension =
+        dimension >= 0 ?
dimension : rank + dimension;
+    if (transformedDimension < 0 || transformedDimension >= rank) {
+      return false; // out of range on either side, e.g. -(rank + 1)
+    }
+
+    capture.value = linalgOp.getStaticLoopRanges()[transformedDimension];
+    return true;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture all dimensions");
+    captures.value = linalgOp.getStaticLoopRanges();
+    return true;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::indexingMaps(
+    CaptureIndexingMaps indexingMaps) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture indexing maps");
+    indexingMaps.value = linalgOp.getIndexingMapsArray();
+    return true;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::contractionDims(
+    CaptureContractionDims contractionDims) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture contraction dimensions");
+    StringRef convMessage = linalg::detail::getMatchContractionMessage(
+        mlir::linalg::detail::isContractionInterfaceImpl(
+            linalgOp, &contractionDims.value));
+    if (convMessage.empty()) {
+      return true;
+    }
+    LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")");
+    return false;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture convolution dimensions");
+    StringRef convMessage = linalg::detail::getMatchConvolutionMessage(
+        mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp,
+                                                         &convDims.value));
+    if (convMessage.empty()) {
+      return true;
+    }
+    LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")");
+    return false;
+  });
+}
+
+transform_ext::StructuredOpMatcher::StructuredOpMatcher(
+    StructuredOpMatcher &A, StructuredOpMatcher &B) {
+ + addPredicate([&A, &B](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); + { + auto debugRAII = llvm::scope_exit( + [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); + if (A.match(linalgOp)) { + return true; + } + } + LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); + { + auto debugRAII = llvm::scope_exit( + [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); + if (B.match(linalgOp)) { + return true; + } + } + return false; + }); + recordNestedMatcher(A); + recordNestedMatcher(B); +} + +//===---------------------------------------------------------------------===// +// Constraints on input operands. +//===---------------------------------------------------------------------===// + +void transform_ext::StructuredOpMatcher::addInputMatcher( + int64_t position, std::function matcher, + OptionalMatch optional) { + addInputMatcher( + position, + // No need to handle optional inside the lambda, the wrapper will do that. + [matcher = std::move(matcher)](Value value) { + Operation *definingOp = value.getDefiningOp(); + return definingOp && matcher(definingOp); + }, + optional); +} + +void transform_ext::StructuredOpMatcher::addInputMatcher( + int64_t position, std::function matcher, + OptionalMatch optional) { + addPredicate([position, optional, matcher = std::move(matcher)]( + linalg::LinalgOp linalgOp) -> bool { + int64_t transformedPosition = + position >= 0 ? position : linalgOp.getNumDpsInputs() + position; + if (transformedPosition >= linalgOp.getNumDpsInputs()) { + LLVM_DEBUG(DBGS() << "input operand #" << position + << " does not exist but match required"); + return false; + } + + LLVM_DEBUG(DBGS() << "input operand #" << position + << (optional.value ? " (optional match) " : " ") + << "is\n"); + + // We MUST run the matcher at this point, even if the match is optional, + // to allow for capture. 
+ LLVM_DEBUG(DBGS() << "start recursive match {\n"); + auto debugRAII = + llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); + if (matcher(linalgOp.getDpsInputOperand(transformedPosition)->get())) { + return true; + } + return optional.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all input operands have permutation maps"); + // all_of with a lambda requires const-casting dance, so using a loop. + for (OpOperand *operand : linalgOp.getDpsInputOperands()) { + if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(AllOperands tag, + IsProjectedPermutation) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps"); + // all_of with a lambda requires const-casting dance, so using a loop. + for (OpOperand *operand : linalgOp.getDpsInputOperands()) { + if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation()) { + return false; + } + } + return true; + }); +} + +/// Helper to check if the map is an identity map with a projected dim. +static bool isProjectedMap(AffineMap map, int64_t projectedDim) { + if (!map.isProjectedPermutation()) { + return false; + } + int64_t dimCounter = 0; + for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { + // Skip the project dim. + if (dimCounter == projectedDim) { + dimCounter++; + } + if (map.getDimPosition(i) != dimCounter++) { + return false; + } + } + return true; +} + +/// Helper to turn a potentially negative index to positive within the range +/// [0, ub) and indicate whether the transformed index is in bounds. 
+static bool makeValidPositiveIndex(int64_t &index, int64_t ub) {
+  int64_t positiveIndex = index >= 0 ? index : ub + index;
+  if (positiveIndex < 0 || positiveIndex >= ub) { // valid range is [0, ub)
+    LLVM_DEBUG(DBGSNL() << " index out of range");
+    return false;
+  }
+  index = positiveIndex;
+  return true;
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(SmallVector &&positions,
+                                          IsProjected dim) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "operands ";
+               llvm::interleaveComma(positions, llvm::dbgs());
+               llvm::dbgs() << " have a permutation maps with " << dim.value
+                            << " projected");
+    int64_t updatedDim = dim.value;
+    if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) {
+      return false;
+    }
+    for (int64_t position : positions) {
+      int64_t updatedPosition = position;
+      if (!makeValidPositiveIndex(updatedPosition,
+                                  linalgOp.getNumDpsInputs())) {
+        return false;
+      }
+      OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition);
+      if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand),
+                          updatedDim)) {
+        return false;
+      }
+    }
+    return true;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "all input operands have identity maps");
+    // all_of with a lambda requires const-casting dance, so using a loop.
+ for (OpOperand *operand : linalgOp.getDpsInputOperands()) { + if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(SmallVector &&positions, + IsIdentity) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "input operands "; + llvm::interleaveComma(positions, llvm::dbgs()); + llvm::dbgs() << " have identity maps"); + // all_of with a lambda requires const-casting dance, so using a loop. + for (int64_t position : positions) { + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, + linalgOp.getNumDpsInputs())) { + return false; + } + OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); + if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(int64_t position, + ElementTypeBitWidth width) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "input operand #" << position + << " has elemental type with bit width " << width.value); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { + return false; + } + auto shapedType = dyn_cast( + linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); + return shapedType && shapedType.getElementType().isIntOrFloat() && + shapedType.getElementType().getIntOrFloatBitWidth() == width.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(int64_t position, + CaptureElementTypeBitWidth width) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, 
linalgOp.getNumDpsInputs())) { + return false; + } + auto shapedType = dyn_cast( + linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); + if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { + return false; + } + width.value = shapedType.getElementType().getIntOrFloatBitWidth(); + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(int64_t position, + CaptureElementType elem) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "input operand #" << position + << " capture element type"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { + return false; + } + auto shapedType = dyn_cast( + linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); + if (!shapedType) { + LLVM_DEBUG(DBGSNL() << " not a shaped type"); + return false; + } + elem.value = shapedType.getElementType(); + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(NumEqualsTo num) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "number of input operands == " << num.value); + return linalgOp.getNumDpsInputs() == num.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(int64_t position, + ConstantFloatMinOrMinusInf) { + return input(position, [](llvm::APFloat f) { + return (f.isLargest() || f.isInfinity()) && f.isNegative(); + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) { + return input(position, [](llvm::APFloat f) { return f.isZero(); }); +} + +transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input( + int64_t position, std::function floatValueFn) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "input operand #" << position + << " is a 
special floating point constant"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { + return false; + } + auto cstOp = linalgOp.getDpsInputOperand(updatedPosition) + ->get() + .getDefiningOp(); + if (!cstOp) { + return false; + } + return floatValueFn(cstOp.value()); + }); +} + +//===---------------------------------------------------------------------===// +// Constraints on output operands. +//===---------------------------------------------------------------------===// + +void transform_ext::StructuredOpMatcher::addOutputMatcher( + int64_t position, std::function matcher, + OptionalMatch optional) { + addPredicate([position, optional, matcher = std::move(matcher)]( + linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "output operand #" << position + << (optional.value ? " (optional match) " + : " (mandatory match) ") + << "is produced by\n"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { + return false; + } + Operation *definingOp = + linalgOp.getDpsInitOperand(updatedPosition)->get().getDefiningOp(); + if (!definingOp) { + return optional.value; + } + // We MUST run the matcher at this point, even if the match is optional, + // to allow for capture. 
+ LLVM_DEBUG(DBGS() << "start recursive match {\n"); + auto debugRAII = + llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); + if (matcher(definingOp)) { + return true; + } + return optional.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all output operands have permutation maps"); + for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { + if (!linalgOp.getMatchingIndexingMap(&operand).isPermutation()) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(AllOperands tag, + IsProjectedPermutation) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps"); + for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { + if (!linalgOp.getMatchingIndexingMap(&operand).isProjectedPermutation()) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all output operands have a maps with projected"); + int64_t updatedDim = dim.value; + if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { + return false; + } + // all_of with a lambda requires const-casting dance, so using a loop. 
+ for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { + if (!isProjectedMap(linalgOp.getMatchingIndexingMap(&operand), + updatedDim)) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps"); + for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { + if (!linalgOp.getMatchingIndexingMap(&operand).isIdentity()) { + return false; + } + } + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(int64_t position, + ElementTypeBitWidth width) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "output operand #" << position + << " has elemental type with bit width " << width.value); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { + return false; + } + auto shapedType = dyn_cast( + linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); + return shapedType && shapedType.getElementType().isIntOrFloat() && + shapedType.getElementType().getIntOrFloatBitWidth() == width.value; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(int64_t position, + CaptureElementTypeBitWidth width) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { + return false; + } + auto shapedType = dyn_cast( + linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); + if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { + LLVM_DEBUG(DBGSNL() << " could not infer element type"); + return false; + } + width.value = 
shapedType.getElementType().getIntOrFloatBitWidth(); + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(int64_t position, + CaptureElementType elem) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "output operand #" << position + << " capture element type"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { + return false; + } + auto shapedType = dyn_cast( + linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); + if (!shapedType) { + LLVM_DEBUG(DBGSNL() << " not a shaped type"); + return false; + } + elem.value = shapedType.getElementType(); + return true; + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(int64_t position, + SingleCombinerReduction tag) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "output operand #" << position + << " is populated by a single-combiner reduction"); + int64_t updatedPosition = position; + if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { + return false; + } + SmallVector combinerOps; + return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, + combinerOps) && + llvm::hasSingleElement(combinerOps); + }); +} + +transform_ext::StructuredOpMatcher & +transform_ext::StructuredOpMatcher::output(NumEqualsTo num) { + return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { + LLVM_DEBUG(DBGS() << "number of output operands == " << num.value); + return linalgOp.getNumDpsInits() == num.value; + }); +} + +//===---------------------------------------------------------------------===// +// Constraints on results. 
+//===---------------------------------------------------------------------===//
+
+/// Registers a predicate that looks for a user of the op result at `position`
+/// accepted by `matcher`.
+///
+/// `position` may be negative to count from the last result; an out-of-range
+/// position fails the predicate. If no user is accepted, the outcome is
+/// `optional.value`.
+void transform_ext::StructuredOpMatcher::addResultMatcher(
+    int64_t position, HasAnyUse tag, std::function matcher,
+    OptionalMatch optional) {
+  addPredicate([matcher = std::move(matcher), optional,
+                position](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "result #" << position
+                      << (optional.value ? " (optional match) "
+                                         : " (mandatory match) ")
+                      << "has a use\n");
+    int64_t resultIndex = position;
+    const bool inRange =
+        makeValidPositiveIndex(resultIndex, linalgOp->getNumResults());
+    if (!inRange) {
+      return false;
+    }
+
+    // The nested matcher has to run even for an optional match, otherwise it
+    // would never get the chance to capture anything.
+    LLVM_DEBUG(DBGS() << "start recursive match {\n");
+    auto closeDebugScope =
+        llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
+    // Stop at the first user the nested matcher accepts.
+    for (Operation *user : linalgOp->getResult(resultIndex).getUsers()) {
+      if (matcher(user)) {
+        return true;
+      }
+    }
+    return optional.value;
+  });
+}
+
+//===-------------------------------------------------------------------===//
+// Constraints on op region.
+//===-------------------------------------------------------------------===//
+
+/// Adds a predicate checking that the op body consists of exactly one
+/// single-result operation named `opcode` followed by a terminator that
+/// yields that result, and that the operation's operands are block arguments
+/// of the body in canonical order (operand #i is argument #i).
+///
+/// When `commutative` is set and the operation has two operands, arguments
+/// #0 and #1 are accepted in either order.
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs(
+    StringRef opcode, bool commutative) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
+    if (linalgOp.getBlock()->getOperations().size() != 2) {
+      return false;
+    }
+    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
+    if (innerOp->getName().getStringRef() != opcode ||
+        innerOp->getNumResults() != 1) {
+      return false;
+    }
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    if (yieldOp->getNumOperands() != 1) {
+      return false;
+    }
+    if (yieldOp->getOperand(0).getDefiningOp() != innerOp) {
+      return false;
+    }
+    if (commutative && innerOp->getNumOperands() == 2) {
+      auto arg0 = dyn_cast(innerOp->getOperand(0));
+      auto arg1 = dyn_cast(innerOp->getOperand(1));
+      if (!arg0 || !arg1) {
+        return false;
+      }
+      if (arg0.getParentBlock() != linalgOp.getBlock() ||
+          arg1.getParentBlock() != linalgOp.getBlock()) {
+        return false;
+      }
+      if (!((arg0.getArgNumber() == 0 && arg1.getArgNumber() == 1) ||
+            (arg1.getArgNumber() == 0 && arg0.getArgNumber() == 1))) {
+        return false;
+      }
+    } else {
+      for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) {
+        auto arg = dyn_cast(operand);
+        if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+            arg.getArgNumber() != index) {
+          return false;
+        }
+      }
+    }
+    return true;
+  });
+}
+
+/// Adds a predicate checking that the op body is a single operation whose
+/// first operand is a constant equal to 1.0 and whose second operand is block
+/// argument #0, followed by a terminator yielding its result.
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::isFloatReciprocal() {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
+    LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation");
+    if (linalgOp.getBlock()->getOperations().size() != 2) {
+      return false;
+    }
+    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
+    if (!isa(innerOp) || innerOp->getNumResults() != 1) {
+      return false;
+    }
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    if (yieldOp->getNumOperands() != 1) {
+      return false;
+    }
+    if (yieldOp->getOperand(0).getDefiningOp() != innerOp) {
+      return false;
+    }
+    auto cst = innerOp->getOperand(0).getDefiningOp();
+    if (!cst || cst.value().convertToDouble() != 1.0) {
+      return false;
+    }
+    auto arg = dyn_cast(innerOp->getOperand(1));
+    if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+        arg.getArgNumber() != 0) {
+      return false;
+    }
+    return true;
+  });
+}
+
+/// Adds a predicate checking that the op body contains only a terminator
+/// whose operand #i is block argument #i of the body.
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::passThroughOp() {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
+    if (linalgOp.getBlock()->getOperations().size() != 1) {
+      return false;
+    }
+    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
+    for (auto [index, operand] : llvm::enumerate(yieldOp->getOperands())) {
+      auto arg = dyn_cast(operand);
+      if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
+          arg.getArgNumber() != index) {
+        return false;
+      }
+    }
+    return true;
+  });
+}
+
+/// Adds a predicate checking that the op body has three block arguments and
+/// three operations: an elementwise op accepted by `isaElemOpTy` applied to
+/// arguments #0 and #1, a reduction op accepted by `isaReductionOpTy`
+/// combining that result with argument #2, and a terminator yielding the
+/// reduction result. Operand order of both ops may be swapped.
+///
+/// `elemOpName` and `reductionOpName` are only used in debug output.
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::hasContractionBody(
+    function_ref isaElemOpTy,
+    function_ref isaReductionOpTy, StringRef elemOpName,
+    StringRef reductionOpName) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
+    LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/"
+                      << reductionOpName << " contraction (");
+    // Each early return below first prints which expectation was not met;
+    // this scope exit then appends the common " check failed)" suffix.
+    auto scopeExitPrinter =
+        llvm::scope_exit([] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); });
+
+    Block *body = linalgOp.getBlock();
+    if (!llvm::hasNItems(*body, 3)) {
+      LLVM_DEBUG(llvm::dbgs() << "three-operation body");
+      return false;
+    }
+    if (body->getNumArguments() != 3) {
+      LLVM_DEBUG(llvm::dbgs() << "three-argument block");
+      return false;
+    }
+
+    Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin());
+    Operation *reductionOp = elemOp->getNextNode();
+    Operation *yieldOp = reductionOp->getNextNode();
+    if (!isaElemOpTy(elemOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName);
+      return false;
+    }
+    if (!isaReductionOpTy(reductionOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName);
+      return false;
+    }
+    if (yieldOp->getNumOperands() != 1) {
+      LLVM_DEBUG(llvm::dbgs() << "one value yielded");
+      return false;
+    }
+    if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) {
+      LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op");
+      return false;
+    }
+    if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) {
+      LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result");
+      return false;
+    }
+    if (reductionOp->getNumOperands() != 2 ||
+        reductionOp->getNumResults() != 1) {
+      LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result");
+      return false;
+    }
+
+    SmallVector expectedReductionOperands = {body->getArgument(2),
+                                             elemOp->getResult(0)};
+    if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) &&
+        !llvm::equal(llvm::reverse(expectedReductionOperands),
+                     reductionOp->getOperands())) {
+      LLVM_DEBUG(llvm::dbgs() << "operands of the second op");
+      return false;
+    }
+
+    ValueRange expectedElemOperands = body->getArguments().take_front(2);
+    if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) &&
+        !llvm::equal(llvm::reverse(expectedElemOperands),
+                     elemOp->getOperands())) {
+      LLVM_DEBUG(llvm::dbgs() << "operands of the first op");
+      return false;
+    }
+
+    // All checks passed: cancel the failure suffix and report success.
+    scopeExitPrinter.release();
+    LLVM_DEBUG(llvm::dbgs() << "success)");
+    return true;
+  });
+}
+
+/// Emits the debug line announcing which concrete op kind `name` a matcher
+/// was constructed for.
+void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor(
+    StringRef name) {
+  // NOTE(review): the trailing quote has no opening counterpart in the
+  // message; left as is since it only affects debug output.
+  LLVM_DEBUG(DBGS() << "op is a " << name << "'");
+}
+
+//===---------------------------------------------------------------------===//
+// TensorPadOpMatcher
+//===---------------------------------------------------------------------===//
+
+/// Adds a predicate checking that the low padding sizes are the constants
+/// given in `sizes`, compared position by position.
+///
+/// NOTE(review): `sizes` is captured as a non-owning view and the zip stops
+/// at the shorter range; callers presumably pass storage that outlives the
+/// matcher and one entry per padded dimension -- confirm.
+transform_ext::TensorPadOpMatcher &
+transform_ext::TensorPadOpMatcher::low(ArrayRef sizes) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
+    LLVM_DEBUG({
+      DBGS() << "low pad sizes are ";
+      llvm::interleaveComma(sizes, llvm::dbgs());
+    });
+    for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedLowPad(), sizes)) {
+      // Reject as soon as one pad size differs from the expected constant.
+      if (!isConstantIntValue(ofr, sz)) {
+        return false;
+      }
+    }
+    return true;
+  });
+}
+
+/// Adds a predicate checking that every low padding size is the constant
+/// `size`.
+transform_ext::TensorPadOpMatcher &
+transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
+    LLVM_DEBUG(DBGS() << "all low pad sizes are " << size);
+    return llvm::all_of(tensorPad.getMixedLowPad(), [&](OpFoldResult ofr) {
+      return isConstantIntValue(ofr, size);
+    });
+  });
+}
+
+/// Adds a predicate checking that the high padding sizes are the constants
+/// given in `sizes`, compared position by position.
+///
+/// NOTE(review): `sizes` is captured as a non-owning view and the zip stops
+/// at the shorter range; callers presumably pass storage that outlives the
+/// matcher and one entry per padded dimension -- confirm.
+transform_ext::TensorPadOpMatcher &
+transform_ext::TensorPadOpMatcher::high(ArrayRef sizes) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
+    LLVM_DEBUG({
+      DBGS() << "high pad sizes are ";
+      llvm::interleaveComma(sizes, llvm::dbgs());
+    });
+    for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedHighPad(), sizes)) {
+      // Reject as soon as one pad size differs from the expected constant.
+      if (!isConstantIntValue(ofr, sz)) {
+        return false;
+      }
+    }
+    return true;
+  });
+}
+
+/// Adds a predicate checking that every high padding size is the constant
+/// `size`.
+transform_ext::TensorPadOpMatcher &
+transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
+    LLVM_DEBUG(DBGS() << "all high pad sizes are " << size);
+    return llvm::all_of(tensorPad.getMixedHighPad(), [&](OpFoldResult ofr) {
+      return isConstantIntValue(ofr, size);
+    });
+  });
+}
+
+/// Adds a predicate checking that the pad body contains only its terminator
+/// and that none of the yielded values is a block argument of that body.
+transform_ext::TensorPadOpMatcher &
+transform_ext::TensorPadOpMatcher::yieldsExternalValue() {
+  return addPredicate([=](tensor::PadOp tensorPad) {
+    LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value");
+    Block *body = tensorPad.getBody();
+    if (!llvm::hasSingleElement(*body)) {
+      return false;
+    }
+    return llvm::all_of(body->getTerminator()->getOperands(),
+                        [body](Value operand) {
+                          auto arg = dyn_cast(operand);
+                          return !arg || arg.getOwner() != body;
+                        });
+  });
+}
+
+//===---------------------------------------------------------------------===//
+// MatchCallbackResult.
+//===---------------------------------------------------------------------===//
+
+/// Returns the payload operations belonging to group `position`.
+///
+/// Groups are stored back to back in `payloadOperations`; the start of a group
+/// is the sum of the lengths of all preceding groups.
+ArrayRef
+transform_ext::MatchCallbackResult::getPayloadGroup(int64_t position) const {
+  assert(position < payloadGroupLengths.size());
+  int64_t start = 0;
+  for (int64_t i = 0; i < position; ++i) {
+    start += payloadGroupLengths[i];
+  }
+  return llvm::ArrayRef(payloadOperations)
+      .slice(start, payloadGroupLengths[position]);
+}
+
+//===---------------------------------------------------------------------===//
+// Case-specific matcher builders.
+//===---------------------------------------------------------------------===//
+
+static constexpr int64_t kCudaWarpSize = 32;
+
+/// Builds a matcher anchored on a reduction op, with a mandatory producer of
+/// its single output (`fillCapture`), an optional producer of its single
+/// input (`leadingCapture`) and an optional user of its result
+/// (`trailingCapture`).
+///
+/// All four capture pointers are set to matchers created in `matcherContext`.
+/// Ranks, sizes and element bit widths are recorded into `captures` when the
+/// matcher runs. When `mustMatchEntireFunc` is set, the reduction matcher
+/// additionally gets the `allTilableOpsCaptured` constraint.
+void transform_ext::makeReductionMatcher(
+    transform_ext::MatcherContext &matcherContext,
+    transform_ext::StructuredOpMatcher *&reductionCapture,
+    transform_ext::StructuredOpMatcher *&fillCapture,
+    transform_ext::StructuredOpMatcher *&leadingCapture,
+    transform_ext::StructuredOpMatcher *&trailingCapture,
+    MatchedReductionCaptures &captures, bool mustMatchEntireFunc) {
+  // The core part of the matcher is anchored on a particular reduction op.
+  auto &reduction =
+      m_StructuredOp(matcherContext)
+          // Op has at least a parallel and a reduction dimension and at
+          // most 3 parallel dimensions.
+          // TODO: relax once we have global collapse/expand_shape.
+          //
+          .rank(NumGreaterEqualTo(2))
+          .rank(NumLowerEqualTo(4))
+          .rank(CaptureRank(captures.reductionRank))
+          // Op has a single most-minor reduction.
+          .dim(-1, utils::IteratorType::reduction)
+          // Capture op sizes.
+          .dim(AllDims(), CaptureDims(captures.reductionOpSizes))
+          // All other dimensions are parallel.
+          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+          // Single input for now, can be arbitrary projected permutations.
+          // TODO: Multiple inputs, can be arbitrary projected permutations.
+          // TODO: Watch out for multiple inputs though as a reduction turns
+          //       into a contraction when mixed with projected
+          //       permutations. A reduction is often bandwidth bound but
+          //       contraction is a different beast that is compute bound
+          //       and has a very different schedule.
+          //
+          .input(NumEqualsTo(1))
+          .input(AllOperands(), IsProjectedPermutation())
+          // Single output supported atm.
+          // TODO: Multiple outputs.
+          //
+          .output(NumEqualsTo(1))
+          // A reduction output must be a projected permutation, match it but we
+          // could also drop this technically.
+          .output(AllOperands(), IsProjectedPermutation())
+          // Only single combiner for now due to reduction warp
+          // distribution.
+          // TODO: relax this once reduction distribution is more powerful.
+          //
+          .output(0, CaptureElementTypeBitWidth(
+                         captures.reductionOutputElementalTypeBitWidth))
+          .output(0, SingleCombinerReduction());
+  reductionCapture = &reduction;
+
+  // Mandatory FillOp must create the unique output of the reduction.
+  // TODO: Relax this, as any map, broadcast, transpose should also work.
+  //
+  auto &fill = m_StructuredOp(matcherContext);
+  reduction = reduction.output(NumEqualsTo(1)).output(0, fill);
+  fillCapture = &fill;
+
+  // Optional leading or trailing op can be any map, transpose, broadcast but
+  // not reduce or windowing operation for now.
+  // It must create the unique input for the reduction.
+  // TODO: match more optional leading ops, one per input of the reduction.
+  // TODO: careful about multi-output and turning into a contraction.
+  //
+  transform_ext::StructuredOpMatcher commonLeadingOrTrailing =
+      m_StructuredOp(matcherContext)
+          // All parallel dimensions.
+          .dim(AllDims(), utils::IteratorType::parallel)
+          // All inputs are any projected permutation.
+          .input(AllOperands(), IsProjectedPermutation())
+          .output(AllOperands(), IsPermutation())
+          // leading and trailing may have 0, 1 or more input as long as they do
+          // not come from unmatched ops. This extra constraint is taken care of
+          // separately. This is also a noop but we document it.
+          // TODO: Base and derived classes, atm this does not compile.
+          // .input(NumGreaterEqualTo(0))
+          // Single output supported atm.
+          // TODO: extend this.
+          //
+          .output(NumEqualsTo(1));
+  // TODO: match more optional leading ops, one per input of the reduction.
+  // TODO: careful about multi-output and turning into a contraction.
+  //
+  auto &leading =
+      m_StructuredOp(matcherContext, commonLeadingOrTrailing)
+          .rank(CaptureRank(captures.maybeLeadingRank))
+          // Capture op sizes.
+          .dim(AllDims(), CaptureDims(captures.leadingOpSizes))
+          // Capture output elemental type.
+          .output(0, CaptureElementTypeBitWidth(
+                         captures.maybeLeadingOutputElementalTypeBitWidth));
+  reduction = reduction.input(0, leading, OptionalMatch());
+  leadingCapture = &leading;
+
+  // Optional trailing can be any map, transpose, broadcast but not reduce or
+  // windowing operation for now.
+  // It must be fed by the unique input for the reduction.
+  // TODO: match more optional leading ops, one per input of the reduction.
+  // TODO: careful about multi-output and turning into a contraction.
+  //
+  auto &trailing =
+      m_StructuredOp(matcherContext, commonLeadingOrTrailing)
+          .rank(CaptureRank(captures.maybeTrailingRank))
+          // Capture op sizes.
+          .dim(AllDims(), CaptureDims(captures.trailingOpSizes))
+          // Capture output elemental type.
+          .output(0, CaptureElementTypeBitWidth(
+                         captures.maybeTrailingOutputElementalTypeBitWidth));
+  reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch());
+  if (mustMatchEntireFunc) {
+    reduction = reduction.allTilableOpsCaptured();
+  }
+  trailingCapture = &trailing;
+}
+
+/// Convenience overload that only exposes the reduction matcher; the fill,
+/// leading and trailing matchers are still created but their pointers are
+/// discarded.
+void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context,
+                                         StructuredOpMatcher *&reductionCapture,
+                                         MatchedReductionCaptures &captures,
+                                         bool mustMatchEntireFunc) {
+  // Assigned by the callee below; not read afterwards.
+  StructuredOpMatcher *fill;
+  StructuredOpMatcher *leading;
+  StructuredOpMatcher *trailing;
+  makeReductionMatcher(context, reductionCapture, fill, leading, trailing,
+                       captures, mustMatchEntireFunc);
+}
+
+/// Builds a matcher anchored on a matmul-like op that records op sizes and
+/// lhs/rhs/output element types into `captures`, with a mandatory producer of
+/// its single output (`fillCapture`) and an optional user of its result
+/// (`trailingCapture`). When `mustMatchEntireFunc` is set, the matmul matcher
+/// additionally gets the `allTilableOpsCaptured` constraint.
+void transform_ext::makeMatmulMatcher(
+    transform_ext::MatcherContext &matcherContext,
+    transform_ext::StructuredOpMatcher *&matmulCapture,
+    transform_ext::StructuredOpMatcher *&fillCapture,
+    transform_ext::StructuredOpMatcher *&trailingCapture,
+    transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
+  auto &matmul = transform_ext::m_StructuredOp(matcherContext)
+                     // Capture op sizes.
+                     .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
+                     // Capture input/output element types.
+                     .input(0, CaptureElementType(captures.lhsElementType))
+                     .input(1, CaptureElementType(captures.rhsElementType))
+                     .output(0, CaptureElementType(captures.outputElementType));
+  matmulCapture = &matmul;
+  // Mandatory FillOp must create the unique output of the reduction.
+  auto &fill = transform_ext::m_StructuredOp(matcherContext);
+  matmul = matmul.output(transform_ext::NumEqualsTo(1)).output(0, fill);
+  fillCapture = &fill;
+
+  auto &trailing = m_StructuredOp(matcherContext);
+  matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch());
+  if (mustMatchEntireFunc) {
+    matmul = matmul.allTilableOpsCaptured();
+  }
+  trailingCapture = &trailing;
+}
+
+/// Builds a matcher for a rank-4 op with a contraction body whose last
+/// dimension is a reduction and all others are parallel, with two inputs.
+/// Op sizes, contraction dims and element types are recorded into `captures`.
+/// The producer of output #0 must match `fillCapture`. When
+/// `mustMatchEntireFunc` is set, the matcher additionally gets the
+/// `allTilableOpsCaptured` constraint.
+void transform_ext::makeBatchMatmulMatcher(
+    transform_ext::MatcherContext &matcherContext,
+    transform_ext::StructuredOpMatcher *&bmmCapture,
+    transform_ext::StructuredOpMatcher *&fillCapture,
+    transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
+  auto &bmm =
+      transform_ext::m_StructuredOp(
+          matcherContext)
+          .hasContractionBody()
+          .rank(NumEqualsTo(4))
+          .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
+          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+          .dim(-1, utils::IteratorType::reduction)
+          .contractionDims(CaptureContractionDims(captures.contractionDims))
+          .input(NumEqualsTo(2))
+          .input(0, CaptureElementType(captures.lhsElementType))
+          .input(1, CaptureElementType(captures.rhsElementType))
+          .output(0, CaptureElementType(captures.outputElementType));
+  bmmCapture = &bmm;
+
+  auto &fill = transform_ext::m_StructuredOp(matcherContext);
+  bmm = bmm.output(0, fill);
+  fillCapture = &fill;
+
+  if (mustMatchEntireFunc) {
+    bmm = bmm.allTilableOpsCaptured();
+  }
+}
+
+/// Match sub(%src, broadcast(%reduction))
+///
+/// Two shapes are accepted and combined with `m_StructuredOp_Or`: either the
+/// second input comes from an explicit pass-through broadcast of
+/// `maxReduction` (identity map on both inputs), or `maxReduction` is consumed
+/// directly through an indexing map projecting the last dimension. The
+/// resulting matcher is returned through `sub`.
+static void
+matchSubBroadcast(transform_ext::MatcherContext &matcherContext,
+                  transform_ext::StructuredOpMatcher &maxReduction,
+                  transform_ext::CapturingValueMatcher &softmaxSourceOperand,
+                  transform_ext::StructuredOpMatcher *&sub) {
+  using namespace transform_ext;
+
+  // Explicit broadcast op: body only yields its input, input map projects the
+  // last dimension.
+  auto &broadcast =
+      transform_ext::m_StructuredOp(matcherContext)
+          .passThroughOp()
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(1))
+          .input(0, IsProjected(-1))
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+  broadcast = broadcast.input(0, maxReduction);
+
+  // Variant consuming the explicit broadcast.
+  auto &subParallel =
+      transform_ext::m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs()
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(2))
+          .input(0, IsIdentity())
+          .input(1, IsIdentity())
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+  subParallel = subParallel.input(0, softmaxSourceOperand);
+  subParallel = subParallel.input(1, broadcast);
+
+  // Variant where the broadcast is folded into the indexing map of input #1.
+  auto &subBroadcast =
+      transform_ext::m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs()
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(2))
+          .input(0, IsIdentity())
+          .input(1, IsProjected(-1))
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+  subBroadcast = subBroadcast.input(0, softmaxSourceOperand);
+  subBroadcast = subBroadcast.input(1, maxReduction);
+  auto &subOr = transform_ext::m_StructuredOp_Or(matcherContext, subBroadcast,
+                                                 subParallel);
+  sub = &subOr;
+}
+
+/// Match div(%exp, broadcast(%sum))
+///
+/// Two shapes are accepted and combined with `m_StructuredOp_Or`: either the
+/// second input comes from an explicit pass-through broadcast of `sum`
+/// (identity map on both inputs), or `sum` is consumed directly through an
+/// indexing map projecting the last dimension. The resulting matcher is
+/// returned through `div`.
+static void matchdivBroadcast(transform_ext::MatcherContext &matcherContext,
+                              transform_ext::StructuredOpMatcher &expOperand,
+                              transform_ext::StructuredOpMatcher &sum,
+                              transform_ext::StructuredOpMatcher *&div) {
+  using namespace transform_ext;
+
+  // Explicit broadcast op: body only yields its input, input map projects the
+  // last dimension.
+  auto &broadcast =
+      transform_ext::m_StructuredOp(matcherContext)
+          .passThroughOp()
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(1))
+          .input(0, IsProjected(-1))
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+  broadcast = broadcast.input(0, sum);
+
+  // Variant consuming the explicit broadcast.
+  auto &divNoBroadcast =
+      transform_ext::m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs()
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(2))
+          .input(0, IsIdentity())
+          .input(1, IsIdentity())
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+
+  divNoBroadcast = divNoBroadcast.input(0, expOperand);
+  divNoBroadcast = divNoBroadcast.input(1, broadcast);
+
+  // Variant where the broadcast is folded into the indexing map of input #1.
+  auto &divBroadcast =
+      transform_ext::m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs()
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(2))
+          .input(0, IsIdentity())
+          .input(1, IsProjected(-1))
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+
+  divBroadcast = divBroadcast.input(0, expOperand);
+  divBroadcast = divBroadcast.input(1, sum);
+
+  auto &divMerge = transform_ext::m_StructuredOp_Or(
+      matcherContext, divNoBroadcast, divBroadcast);
+  div = &divMerge;
+}
+
+/// Builds a matcher for a softmax expressed as a chain of structured ops.
+///
+/// The chain, in producer-to-consumer order, is: a reduction over the last
+/// dimension initialized from a min-or-minus-infinity constant
+/// (`maxReductionCapture`), the sub step built by `matchSubBroadcast`, a
+/// single-op elementwise step (`expOperand`), a second reduction over the last
+/// dimension initialized from a zero constant (`sum`), and finally either a
+/// reciprocal followed by a commutative two-input op (`mulOperand`) or the div
+/// step built by `matchdivBroadcast`. `softmaxRootCapture` is set to the
+/// disjunction of those two final alternatives.
+void transform_ext::makeSoftmaxMatcher(
+    transform_ext::MatcherContext &matcherContext,
+    transform_ext::StructuredOpMatcher *&maxReductionCapture,
+    transform_ext::StructuredOpMatcher *&softmaxRootCapture) {
+  // Shared value matcher so that the max reduction and the sub step are
+  // constrained against the same source operand matcher.
+  auto &softmaxSourceOperand = m_Value(matcherContext);
+
+  auto &fillMinusInf = m_StructuredOp(matcherContext)
+                           .input(0, ConstantFloatMinOrMinusInf());
+  auto &maxReduction =
+      transform_ext::m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs(/*commutative=*/true)
+          // Only handle most inner reduction for now.
+          .dim(-1, utils::IteratorType::reduction)
+          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+          .input(NumEqualsTo(1))
+          .input(AllOperands(), IsIdentity())
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsProjected(-1));
+  maxReduction = maxReduction.input(0, softmaxSourceOperand);
+  maxReduction = maxReduction.output(0, fillMinusInf);
+  maxReductionCapture = &maxReduction;
+
+  transform_ext::StructuredOpMatcher *subOperand;
+  matchSubBroadcast(matcherContext, maxReduction, softmaxSourceOperand,
+                    subOperand);
+
+  auto &expOperand = m_StructuredOp(matcherContext)
+                         .singleOpWithCanonicaleArgs()
+                         .dim(AllDims(), utils::IteratorType::parallel)
+                         .input(NumEqualsTo(1))
+                         .input(AllOperands(), IsIdentity())
+                         .output(AllOperands(), IsIdentity())
+                         .output(NumEqualsTo(1));
+  expOperand = expOperand.input(0, *subOperand);
+
+  auto &fillZero = m_StructuredOp(matcherContext)
+                       .input(0, ConstantFloatZero());
+  auto &sum =
+      m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs(/*commutative=*/true)
+          // Only handle most inner reduction for now.
+          .dim(-1, utils::IteratorType::reduction)
+          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+          .input(NumEqualsTo(1))
+          .input(AllOperands(), IsIdentity())
+          .output(AllOperands(), IsProjected(-1))
+          .output(NumEqualsTo(1));
+  sum = sum.input(0, expOperand);
+  sum = sum.output(0, fillZero);
+
+  auto &rcpOperand = m_StructuredOp(matcherContext)
+                         .isFloatReciprocal()
+                         .dim(AllDims(), utils::IteratorType::parallel)
+                         .input(NumEqualsTo(1))
+                         .input(AllOperands(), IsIdentity())
+                         .output(AllOperands(), IsIdentity())
+                         .output(NumEqualsTo(1));
+  rcpOperand = rcpOperand.input(0, sum);
+
+  auto &mulOperand =
+      transform_ext::m_StructuredOp(matcherContext)
+          .singleOpWithCanonicaleArgs(/*commutative=*/true)
+          .dim(AllDims(), utils::IteratorType::parallel)
+          .input(NumEqualsTo(2))
+          .input(0, IsIdentity())
+          .input(1, IsProjected(-1))
+          .output(NumEqualsTo(1))
+          .output(AllOperands(), IsIdentity());
+
+  mulOperand = mulOperand.input(0, expOperand);
+  mulOperand = mulOperand.input(1, rcpOperand);
+
+  transform_ext::StructuredOpMatcher *divOperand;
+  matchdivBroadcast(matcherContext, expOperand, sum, divOperand);
+
+  // The root is either the reciprocal-and-multiply form or the division form.
+  auto &softmaxRoot =
+      transform_ext::m_StructuredOp_Or(matcherContext, mulOperand, *divOperand);
+  softmaxRootCapture = &softmaxRoot;
+}
+
+/// Matcher for convolutions.
+///
+/// Builds a matcher anchored on a convolution op that records convolution dim
+/// classifications, op sizes and element types into `captures`, with an
+/// optional producer of its single output (`fillCapture`) and an optional
+/// user of its result (`trailingCapture`). When `mustMatchEntireFunc` is set,
+/// the convolution matcher additionally gets the `allTilableOpsCaptured`
+/// constraint.
+void transform_ext::makeConvolutionMatcher(
+    transform_ext::MatcherContext &matcherContext,
+    transform_ext::StructuredOpMatcher *&convolutionCapture,
+    transform_ext::StructuredOpMatcher *&fillCapture,
+    transform_ext::StructuredOpMatcher *&trailingCapture,
+    MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) {
+  // The core part of the matcher is anchored on a particular convolution op.
+  auto &convolution =
+      m_StructuredOp(
+          matcherContext)
+          // Capture convolution dim classifications.
+          .convolutionDims(CaptureConvDims(captures.convolutionDims))
+          // Capture op sizes.
+          .dim(AllDims(), CaptureDims(captures.convolutionOpSizes))
+          // Capture convolution element types.
+          .input(0, CaptureElementType(captures.inputElementType))
+          .input(1, CaptureElementType(captures.filterElementType))
+          .output(0, CaptureElementType(captures.outputElementType));
+  convolutionCapture = &convolution;
+
+  // Optional FillOp to create the unique output of the convolution.
+  auto &fill = m_StructuredOp(matcherContext)
+                   .output(0, CaptureElementTypeBitWidth(
+                                  captures.maybeFillElementalTypeBitWidth));
+  convolution =
+      convolution.output(NumEqualsTo(1)).output(0, fill, OptionalMatch());
+  fillCapture = &fill;
+
+  // Optional trailing op can be any map, transpose, broadcast but
+  // not reduce or windowing operation for now.
+  // It must create the unique input for the reduction.
+  auto &trailing =
+      m_StructuredOp(matcherContext)
+          // All parallel dimensions.
+          .dim(AllDims(), utils::IteratorType::parallel)
+          // All inputs are any projected permutation.
+          .input(AllOperands(), IsProjectedPermutation())
+          .output(AllOperands(), IsPermutation())
+          .output(NumEqualsTo(1))
+          .dim(AllDims(), CaptureDims(captures.trailingOpSizes))
+          // Capture output elemental type.
+          .output(0, CaptureElementTypeBitWidth(
+                         captures.maybeTrailingOutputElementalTypeBitWidth));
+
+  // Optional trailing can be any map, transpose, broadcast but not reduce or
+  // windowing operation for now.
+  convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch());
+  if (mustMatchEntireFunc) {
+    convolution =
+        convolution.allTilableOpsCaptured();
+  }
+  trailingCapture = &trailing;
+}
+
+/// Convenience overload that only exposes the convolution matcher; the fill
+/// and trailing matchers are still created but their pointers are discarded.
+void transform_ext::makeConvolutionMatcher(
+    transform_ext::MatcherContext &context,
+    StructuredOpMatcher *&convolutionCapture,
+    MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) {
+  // Assigned by the callee below; not read afterwards.
+  StructuredOpMatcher *fill;
+  StructuredOpMatcher *trailing;
+  makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures,
+                         mustMatchEntireFunc);
+}
+
+/// Builds a matcher for a tensor pad op whose low padding is the constant 0 in
+/// every dimension and whose body yields a value defined outside of it. Rank,
+/// dimensions and element type of result #0 are recorded into `captures`.
+/// When `mustMatchEntireFunc` is set, the matcher additionally gets the
+/// `allTilableOpsCaptured` constraint.
+void transform_ext::makePadMatcher(MatcherContext &context,
+                                   CapturingOpMatcher *&padCapture,
+                                   MatchedPadCaptures &captures,
+                                   bool mustMatchEntireFunc) {
+  auto &value = transform_ext::m_ShapedValue(context);
+  value.rank(transform_ext::CaptureRank(captures.rank))
+      .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims))
+      .elementType(CaptureElementType(captures.elementType));
+  auto &opMatcher = transform_ext::m_tensorPad(context)
+                        .result(0, value)
+                        .low(AllDims(), 0)
+                        .yieldsExternalValue();
+  if (mustMatchEntireFunc) {
+    opMatcher = opMatcher.allTilableOpsCaptured();
+  }
+  padCapture = &opMatcher;
+}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h
new file mode 100644
index 000000000000..6de72611a792
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h
@@ -0,0 +1,1201 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_GLOBALOPTIMIZATION_TRANSFORMMATCHERS_H_ +#define IREE_COMPILER_GLOBALOPTIMIZATION_TRANSFORMMATCHERS_H_ + +#include +#include +#include + +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace transform_ext { + +//===---------------------------------------------------------------------===// +// StructuredOpMatcher and predicates. +//===---------------------------------------------------------------------===// + +class StructuredOpMatcher; +class MatcherContext; +StructuredOpMatcher &m_StructuredOp(MatcherContext &); + +/// A tag indicating the shape being static or dynamic, for use with the +/// structured op matcher. +enum class ShapeKind { Static, Dynamic }; + +/// A placeholder indicating the structured op matcher to check the predicate +/// for all dimensions. +struct AllDims {}; + +/// A predicate indicating the structured op matcher to check the predicate for +/// all dimensions except the specified ones. +struct AllDimsExcept { + explicit AllDimsExcept(std::initializer_list range) { + llvm::append_range(exceptions, range); + } + ArrayRef getExcluded() const { return llvm::ArrayRef(exceptions); } + +private: + SmallVector exceptions; +}; + +/// A placeholder indicating the structured op matcher to check the predicate +/// for all operands of the relevant kind. +struct AllOperands {}; + +/// Base class for single-value captures. Concrete captures should inherit this +/// and forward the constructor via `using Base::Base`. +template +struct CaptureStaticValue { + using Base = CaptureStaticValue; + explicit CaptureStaticValue(T &value) : value(value) {} + T &value; +}; + +/// Captures the (static) size of the dimension. 
+struct CaptureDim : public CaptureStaticValue { + using Base::Base; +}; + +/// Captures the (static) sizes of multiple dimensions. +struct CaptureDims : public CaptureStaticValue> { + using Base::Base; +}; + +/// Captures the indexing maps of the target operation. +struct CaptureIndexingMaps : public CaptureStaticValue> { + using Base::Base; +}; + +/// Captures the contraction dimensions of the target operation. +struct CaptureContractionDims + : public CaptureStaticValue { + using Base::Base; +}; + +/// Captures the convolution dimensions of the target operation. +struct CaptureConvDims + : public CaptureStaticValue { + using Base::Base; +}; + +/// Captures the rank of the operation. +struct CaptureRank : public CaptureStaticValue { + using Base::Base; +}; + +/// Captures the bitwidth of an element type. +struct CaptureElementTypeBitWidth : public CaptureStaticValue { + using Base::Base; +}; + +/// Captures the element type. +struct CaptureElementType : public CaptureStaticValue { + using Base::Base; +}; + +template +struct CaptureAttribute : public CaptureStaticValue { + static_assert(std::is_base_of_v, + "can only capture a subclass of Attribute"); + using CaptureStaticValue::CaptureStaticValue; +}; + +/// A tag indicating to look for any user of the operation's result that would +/// satisfy the predicate. +struct HasAnyUse {}; + +/// Base class for predicate parameters that can be described with a single +/// value. Concrete predicate parameters should inherit this and forward the +/// constructor via `using Base::Base`. +template +struct SingleValuePredicateParam { + using Base = SingleValuePredicateParam; + explicit SingleValuePredicateParam(T value) : value(value) {} + const T value; +}; + +/// Indicates that the dimension must be divisible by the given value. +struct DivisibleBy : public SingleValuePredicateParam { + using Base::Base; +}; + +/// Indicates that the number of entities must be equal to the given value. 
+struct NumEqualsTo : public SingleValuePredicateParam { + using Base::Base; +}; + +/// Indicates that the number of entities must be at least the given value. +struct NumGreaterEqualTo : public SingleValuePredicateParam { + using Base::Base; +}; + +/// Indicates that the number of entities must be at most the given value. +struct NumLowerEqualTo : public SingleValuePredicateParam { + using Base::Base; +}; + +/// Indicates that the bit width of the elemental type must be equal to +/// the given value. +struct ElementTypeBitWidth : public SingleValuePredicateParam { + using Base::Base; +}; + +/// Predicate tag indicating that the affine map is a permutation. +struct IsPermutation {}; + +/// Predicate tag indicating that the affine map is a projected permutation. +struct IsProjectedPermutation {}; + +/// Predicate tag indicating that the affine map is a projection of the given +/// dimension. +struct IsProjected : public SingleValuePredicateParam { + using Base::Base; +}; +/// Predicate tag indicating that the affine map is an identity. +struct IsIdentity {}; + +/// Predicate tag indicating that the operand is a special float constant. +struct ConstantFloatMinOrMinusInf {}; +struct ConstantFloatZero {}; +struct ConstantFloatOne {}; + +/// Indicates that the match is optional. The matcher is still expected to run +/// and capture if successful. Pass false to make the match mandatory. +struct OptionalMatch : public SingleValuePredicateParam { + OptionalMatch() : Base(true) {} + explicit OptionalMatch(bool set) : Base(set) {} +}; + +/// Predicate tag indicating that the reduction is produced by a single combiner +/// operation. +struct SingleCombinerReduction {}; + +class CapturingOpMatcher; +class CapturingValueMatcher; + +/// Base class for capturing matchers that can be owned by the context. +class CapturingMatcherBase { +public: + // Virtual destructor so unique pointers are deallocated correctly. 
+ // TODO: if efficiency is a problem, consider disallowing non-trivial + // destructors for subclasses. + virtual ~CapturingMatcherBase() = default; + +protected: + /// Informs the matcher that it has another, nested matcher. Derived classes + /// must call this to keep track of nested matchers for capture resetting + /// purposes. + template + void recordNestedMatcher(T &nested) { + if constexpr (std::is_base_of_v) { + nestedCapturingMatchers.push_back(&nested); + } + if constexpr (std::is_base_of_v) { + nestedCapturingValueMatchers.push_back(&nested); + } + } + + /// Appends all nested capturing matchers of a certain kind, excluding this + /// one, to `nested`. + void getAllNested(SmallVectorImpl &nested); + void + getAllNestedValueMatchers(SmallVectorImpl &nested); + + /// Resets nested capturing matchers but does NOT reset the current one. + void resetCapture(); + +private: + /// A list of (recursively) nested capturing matchers that should be reset + /// when the current matcher is. + SmallVector nestedCapturingMatchers; + SmallVector nestedCapturingValueMatchers; +}; + +/// A context object holding capturing matchers, must outlive any individual +/// matcher. When matching complex subgraphs, the caller often doesn't care +/// about all intermediate nodes (operations) in the graph and shouldn't need to +/// hold matcher objects for those. These matchers can be created in this +/// context. +class MatcherContext { +public: + /// Create a new matcher of the specified type owned by this context. + template + std::enable_if_t, T> & + allocate(Args &&...args) { + // Need to call "new" explicitly as make_unique wouldn't have access to the + // private constructor when this class would. + ownedMatchers.emplace_back( + std::unique_ptr(new T(std::forward(args)...))); + return *static_cast(ownedMatchers.back().get()); + } + +private: + /// Owning list of matchers. 
+ // TODO: If this becomes inefficient, consider something like BumpPtrAllocator + // that derived classes can use to store their members as well. + SmallVector> ownedMatchers; +}; + +/// Base class for value matchers that capture the matched value. Stores a list +/// of predicates and requires all of them to match for the value to match. Once +/// a value matched, any repeated use just verifies that equality of the value. +class CapturingValueMatcher : public CapturingMatcherBase { + friend class CapturingMatcherBase; + friend class MatcherContext; + + using PredicateFn = std::function; + +public: + /// Resets the captured value to null. This should be called if the same + /// pattern needs to be applied more than once as it may keep captured values + /// for optional nested predicates from the previous application. + void resetCapture() { + captured = nullptr; + CapturingMatcherBase::resetCapture(); + } + + /// Returns the matched value if the match was successful. + Value getCaptured() const { return captured; } + + /// Matches the given value, hook for `matchPattern`. + bool match(Value value); + +protected: + CapturingValueMatcher() = default; + + /// Adds a predicate to the end of the predicate list for this value matcher. + template + void addPredicate(Fn &&predicate) { + predicates.emplace_back(std::forward(predicate)); + } + + /// The captured value. + Value captured = nullptr; + +private: + /// Additional predicates to be checked on the value. + SmallVector predicates; +}; + +/// Creates a matcher of an arbitrary value. +inline CapturingValueMatcher &m_Value(MatcherContext &context) { + return context.allocate(); +} + +/// Matcher for typed values whose type implements the `ShapedType` interface. +/// Allows for matching the components of the shaped type such as rank and +/// dimensions. 
+class ShapedValueMatcher : public CapturingValueMatcher { + friend class MatcherContext; + + ShapedValueMatcher(); + +public: + /// Add an always-succeeding matcher predicate capturing the rank. + ShapedValueMatcher &rank(CaptureRank capture); + + /// Add an always-succeeding matcher predicate capturing the size of the + /// dimension identified by the first argument. + ShapedValueMatcher &dim(int64_t dimension, CaptureDim capture); + + /// Add an always-succeeding matcher predicate capturing the sizes of all + /// dimensions in order of appearance. + ShapedValueMatcher &dim(AllDims tag, CaptureDims captures); + + /// Add an always-succeeding matcher predicate capturing the element type of + /// the value. + ShapedValueMatcher &elementType(CaptureElementType captures); +}; + +/// Construct a new matcher of a value whose type is a `ShapedType`, owned by +/// the given context. +inline ShapedValueMatcher &m_ShapedValue(MatcherContext &context) { + return context.allocate(); +} + +/// Matcher for operations with additional predicates attachable through the +/// fluent, a.k.a. chainable, API. Note that public API must *not* accept +/// additional callbacks even; new predicates should be added instead when +/// necessary. Not only this decreases the depth of the callback stack and +/// increases readability, it also allows us to port the matcher to a +/// declarative format using PDL and/or Transform dialect in the future. The +/// latter will become impossible with arbitrary C++ callbacks. +class CapturingOpMatcher : public CapturingMatcherBase { + friend class CapturingMatcherBase; + friend class MatcherContext; + + template + friend CapturingOpMatcher &m_Operation(MatcherContext &matcherContext); + +public: + using PredicateFn = std::function; + + /// Matches the given operation, hook for `matchPattern`. + bool match(Operation *op); + + /// Resets the captured value to null. 
This should be called if the same + /// pattern needs to be applied more than once as it may keep captured values + /// for optional nested predicates from the previous application. + void resetCapture() { + captured = nullptr; + CapturingMatcherBase::resetCapture(); + } + + /// Returns the matched operation if the match was successful. + Operation *getCaptured() const { return captured; } + + /// Adds alternative paths for predicates. In practice, this is just a + /// predicate that is satisfied when either the first or the second matcher is + /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., + /// the second alternative will not be processed, and therefore will not + /// capture values, if the first alternative succeeded. + CapturingOpMatcher &alternatives(CapturingOpMatcher &first, + CapturingOpMatcher &second); + + //===-------------------------------------------------------------------===// + // Constraints on adjacent ops. + //===-------------------------------------------------------------------===// + + /// Adds a predicate checking that all ops implementing TilingInterface in the + /// parent of the given type (e.g., a function or a module) were matched by + /// this or nested matchers. This is useful to ensure that the matcher covered + /// the entire parent region, not just a part of it. This predicate **must** + /// be added *after* all the other predicates that capture. + template + CapturingOpMatcher &allTilableOpsCaptured() { + SmallVector copy; + copy.push_back(this); + getAllNested(copy); + addPredicate([copy = std::move(copy)](Operation *op) { + Operation *parent = op->getParentOfType(); + return checkAllTilableMatched(parent, op, copy); + }); + return *this; + } + + //-------------------------------------------------------------------------// + // Predicates for operands and results. 
+ //-------------------------------------------------------------------------// + + /// Adds a predicate checking that the operation has exactly the given number + /// of operands. + CapturingOpMatcher &operand(NumEqualsTo num); + + /// Adds a predicate checking that the `pos`-th operand of the operation is + /// defined by an operation that satisfies the given matcher. + CapturingOpMatcher &operand(int64_t pos, CapturingOpMatcher &nested); + + /// Adds a predicate checking that the `pos`-th operand of the operation + /// satisfies the given value matcher. + CapturingOpMatcher &operand(int64_t pos, CapturingValueMatcher &nested); + + /// Adds a predicate checking that the `pos`-th operand of the operation is + /// defined by `arith.constant` with the value 1.0. + // TODO: better matching for attributes. + CapturingOpMatcher &operand(int64_t pos, ConstantFloatOne); + + /// Adds a predicate checking that the operation has exactly the given number + /// of results. + CapturingOpMatcher &result(NumEqualsTo num); + + /// Adds a predicate checking that the `pos`-th result of the operation + /// satisfies the given value matcher. + CapturingOpMatcher &result(int64_t pos, CapturingValueMatcher &nested); + +protected: + /// Constructs a default operation matcher accepting any operation. + CapturingOpMatcher() = default; + + /// Adds a predicate for the matched operation to satisfy. + template + void addPredicate(Fn &&predicate) { + predicates.emplace_back(std::forward(predicate)); + } + + /// Produce the debug output for `create` method in a non-templated way. + static void debugOutputForCreate(ArrayRef opNames); + +private: + /// A list of additional conditions for the operation to match. + SmallVector predicates; + + /// Checks that `matchers` captured all tilable ops nested in `parent` except + /// for `op`. This is an implementation detail of allTilableOpsCaptured. 
+ static bool checkAllTilableMatched(Operation *parent, Operation *op, + ArrayRef matchers); + + /// Creates a matcher for an operation with one of the given types. + template + static CapturingOpMatcher create() { + CapturingOpMatcher matcher; + matcher.addPredicate([](Operation *op) { + debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); + return isa(op); + }); + return matcher; + } + + /// Common util for constant matcher. + CapturingOpMatcher &operand(int64_t position, + std::function floatValueFn); + +protected: + /// Matched value. + Operation *captured = nullptr; +}; + +namespace detail { +/// Prints the debug output from the ConcreteOpMatcher constructor. The +/// implementation must reside in the C++ file so we don't pollute the header +/// with debug includes, and ConcreteOpMatcher is a class template that can only +/// reside in the header. +void debugOutputForConcreteOpMatcherConstructor(StringRef name); +} // namespace detail + +/// Base class for matchers that match a specific op. Adds an initial predicate +/// checking if the op is indeed of the specified kind. +/// Derived classes specializing this for op interfaces MUST also define a +/// specialization of DebugOpKindDescription. +template +class ConcreteOpMatcher : public CapturingOpMatcher { +protected: + using Base = ConcreteOpMatcher; + + static StringRef getConcreteOpDescription() { + return OpTy::getOperationName(); + } + + /// Adds a predicate checking if the op is of the OpTy kind. + ConcreteOpMatcher() { + CapturingOpMatcher::addPredicate([](Operation *op) { + detail::debugOutputForConcreteOpMatcherConstructor( + Derived::getConcreteOpDescription()); + return isa(op); + }); + } + + /// Adds a predicate for the matched operation to satisfy. + template + Derived &addPredicate(FnTy &&predicate) { + // Dispatch to the callback. 
+ CapturingOpMatcher::addPredicate( + [inner = std::move(predicate)](Operation *op) { + return inner(cast(op)); + }); + return static_cast(*this); + } + +public: + /// Adds alternative paths for predicates. In practice, this is just a + /// predicate that is satisfied when either the first or the second matcher is + /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., + /// the second alternative will not be processed, and therefore will not + /// capture values, if the first alternative succeeded. + Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) { + return static_cast( + CapturingOpMatcher::alternatives(first, second)); + } + + /// Adds a predicate checking that all ops implementing TilingInterface in the + /// parent of the given type (e.g., a function or a module) were matched by + /// this or nested matchers. This is useful to ensure that the matcher covered + /// the entire parent region, not just a part of it. This predicate **must** + /// be added *after* all the other predicates that capture. + template + Derived &allTilableOpsCaptured() { + return static_cast( + CapturingOpMatcher::allTilableOpsCaptured()); + } + + //-------------------------------------------------------------------------// + // Predicates for operands and results. + //-------------------------------------------------------------------------// + + /// Adds a predicate checking that the operation has exactly the given number + /// of operands. + Derived &operand(NumEqualsTo num) { + return static_cast(CapturingOpMatcher::operand(num)); + } + + /// Adds a predicate checking that the `pos`-th operand of the operation is + /// defined by an operation that satisfies the given matcher. + Derived &operand(int64_t pos, CapturingOpMatcher &nested) { + return static_cast(CapturingOpMatcher::operand(pos, nested)); + } + + /// Adds a predicate checking that the `pos`-th operand of the operation + /// satisfies the given value matcher. 
+ Derived &operand(int64_t pos, CapturingValueMatcher &nested) { + return static_cast(CapturingOpMatcher::operand(pos, nested)); + } + + /// Adds a predicate checking that the `pos`-th operand of the operation is + /// defined by `arith.constant` with the value 1.0. + // TODO: better matching for attributes. + Derived &operand(int64_t pos, ConstantFloatOne c) { + return static_cast(CapturingOpMatcher::operand(pos, c)); + } + + /// Adds a predicate checking that the operation has exactly the given number + /// of results. + Derived &result(NumEqualsTo num) { + return static_cast(CapturingOpMatcher::result(num)); + } + + /// Adds a predicate checking that the `pos`-th result of the operation + /// satisfies the given value matcher. + Derived &result(int64_t pos, CapturingValueMatcher &nested) { + return static_cast(CapturingOpMatcher::result(pos, nested)); + } +}; + +/// Matcher for the `tensor.pad` operation. +class TensorPadOpMatcher + : public ConcreteOpMatcher { + friend class MatcherContext; + + TensorPadOpMatcher() = default; + +public: + /// Adds a predicate checking that the low padding sizes are exactly the given + /// values. + TensorPadOpMatcher &low(ArrayRef sizes); + + /// Adds a predicate checking that the low padding sizes for all dimensions + /// are exactly the same given value. + TensorPadOpMatcher &low(AllDims tag, int64_t size); + + /// Adds a predicate checking that the high padding sizes are exactly the given + /// values. + TensorPadOpMatcher &high(ArrayRef sizes); + + /// Adds a predicate checking that the high padding sizes for all dimensions + /// are exactly the same given value. + TensorPadOpMatcher &high(AllDims tag, int64_t size); + + /// Adds a predicate checking that the body of the pad only yields values + /// defined outside the pad region. 
+ TensorPadOpMatcher &yieldsExternalValue(); +}; + +inline TensorPadOpMatcher &m_tensorPad(MatcherContext &matcherContext) { + return matcherContext.allocate(); +} + +/// Creates a default operation matcher in the given context that accepts any +/// operation. +inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { + return matcherContext.allocate(); +} + +/// Creates an operation matcher in the given context that accepts only +/// operations of the kinds provided as template arguments. +template +inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { + return matcherContext.allocate( + CapturingOpMatcher::create()); +} + +/// Matcher for structured aka Linalg operations. +class StructuredOpMatcher + : public ConcreteOpMatcher { + friend class MatcherContext; + + StructuredOpMatcher() = default; + +public: + static StringRef getConcreteOpDescription() { + return "linalg interface implementation"; + } + + /// Creates a matcher for a structured operation with one of the given types. + template + static StructuredOpMatcher create() { + StructuredOpMatcher matcher; + matcher.addPredicate([](Operation *op) { + debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); + return isa(op) && isa(op); + }); + return matcher; + } + + /// Matches a structured operation if either patterns A or B match. + StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B); + + //===-------------------------------------------------------------------===// + // Constraints on op rank and dims. + //===-------------------------------------------------------------------===// + /// Adds a predicate checking that the given rank must be greater than some + /// constant value. + StructuredOpMatcher &rank(NumGreaterEqualTo minRank); + StructuredOpMatcher &rank(NumLowerEqualTo maxRank); + StructuredOpMatcher &rank(NumEqualsTo exactRank); + + /// Adds a predicate checking that the given iteration space dimension is + /// static/dynamic. 
The dimension index may be negative, in which case + /// dimensions are counted from the last one (i.e. Python-style), or be an + /// AllDims tag, in which case all dimensions are checked. This may be + /// eventually extended to slices and/or lists of dimensions. + StructuredOpMatcher &dim(int64_t dimension, ShapeKind kind) { + return dim(SmallVector{dimension}, kind); + } + StructuredOpMatcher &dim(SmallVector &&dimensions, ShapeKind kind); + StructuredOpMatcher &dim(AllDims tag, ShapeKind kind); + + /// Adds a predicate checking that the given iteration space dimension has the + /// given iterator type, e.g., parallel or reduction. The dimension index may + /// be negative, in which case dimensions are counted from the last one + /// (i.e. Python-style), or be an AllDims tag, in which case all dimensions + /// are checked. This may be eventually extended to slices and/or lists of + /// dimensions. + StructuredOpMatcher &dim(int64_t dimension, utils::IteratorType kind) { + return dim(SmallVector{dimension}, kind); + } + // Ownership may get tricky here so we wrap in an explicit vector. + StructuredOpMatcher &dim(SmallVector &&dimensions, + utils::IteratorType kind); + StructuredOpMatcher &dim(AllDims tag, utils::IteratorType kind); + StructuredOpMatcher &dim(AllDimsExcept &&dimensions, + utils::IteratorType kind); + + /// Adds a predicate checking that the given iteration space dimension is + /// statically known to be divisible by the given value. The dimension index + /// may be negative, in which case dimensions are counted from the last one + /// (i.e. Python-style). + StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy); + + //===-------------------------------------------------------------------===// + // Capture directives. 
+ //===-------------------------------------------------------------------===// + StructuredOpMatcher &rank(CaptureRank capture); + StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture); + StructuredOpMatcher &dim(AllDims tag, CaptureDims captures); + StructuredOpMatcher &indexingMaps(CaptureIndexingMaps indexingMaps); + StructuredOpMatcher &contractionDims(CaptureContractionDims contractionDims); + StructuredOpMatcher &convolutionDims(CaptureConvDims convDims); + + //===-------------------------------------------------------------------===// + // Constraints on input operands. + //===-------------------------------------------------------------------===// + /// Adds a predicate checking that the structured op has the given number of + /// inputs. + StructuredOpMatcher &input(NumEqualsTo num); + + /// Adds a predicate that recursively applies other predicates to the + /// operation defining the `position`-th operand. The position may be + /// negative, in which case positions are counted from the last one + /// (i.e. Python-style). When the match is optional, the predicate check + /// succeeds as long as the `position` is in bounds. The matcher is executed + /// if there is a defining operation for the input operand. 
+ template + std::enable_if_t::value, + StructuredOpMatcher &> + input(int64_t position, T &operandMatcher, + OptionalMatch optional = OptionalMatch(false)) { + addInputMatcher( + position, + [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, + optional); + recordNestedMatcher(operandMatcher); + return *this; + } + template + std::enable_if_t::value, + StructuredOpMatcher &> + input(int64_t position, T &operandMatcher, + OptionalMatch optional = OptionalMatch(false)) { + addInputMatcher( + position, + [&operandMatcher](Value v) { return operandMatcher.match(v); }, + optional); + recordNestedMatcher(operandMatcher); + return *this; + } + + /// Adds a predicate checking that all input operands of the structured op + /// have a permutation indexing map. + StructuredOpMatcher &input(AllOperands tag, IsPermutation); + + /// Adds a predicate checking that all input operands of the structured op + /// have a projected permutation indexing map. + StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation); + + /// Adds a predicate checking that all input operands of the structured op + /// are projected along the given dimension. + StructuredOpMatcher &input(SmallVector &&positions, IsProjected dim); + StructuredOpMatcher &input(int64_t position, IsProjected dim) { + return input(SmallVector{position}, dim); + } + + /// Adds a predicate checking that all input operands of the structured op + /// have identity indexing map. + StructuredOpMatcher &input(AllOperands tag, IsIdentity); + StructuredOpMatcher &input(SmallVector &&positions, IsIdentity); + StructuredOpMatcher &input(int64_t position, IsIdentity) { + return input(SmallVector{position}, IsIdentity()); + } + + /// Adds a predicate checking that the bit width of the elemental type of the + /// structured op input at the given position is equal to the given value. 
+ StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width); + + /// Capture the elemental type bitwidth of input operand `position`. + StructuredOpMatcher &input(int64_t position, + CaptureElementTypeBitWidth width); + + /// Capture the elemental type of input operand `position`. + StructuredOpMatcher &input(int64_t position, CaptureElementType elem); + + /// Check if input is equal to a known constant. + // TODO: Support matching for constant ops. + StructuredOpMatcher &input(int64_t position, ConstantFloatMinOrMinusInf); + StructuredOpMatcher &input(int64_t position, ConstantFloatZero); + + //===-------------------------------------------------------------------===// + // Constraints on output operands. + //===-------------------------------------------------------------------===// + + /// Adds a predicate checking that the structured op has the given number of + /// outputs. + StructuredOpMatcher &output(NumEqualsTo num); + + /// Adds a predicate checking that all output operands of the structured op + /// have a permutation indexing map. + StructuredOpMatcher &output(AllOperands tag, IsPermutation); + + /// Adds a predicate checking that all output operands of the structured op + /// have a projected permutation indexing map. + StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation); + + /// Adds a predicate checking that all output operands of the structured op + /// are projected along the given dimension. + StructuredOpMatcher &output(AllOperands tag, IsProjected dim); + + /// Adds a predicate checking that all output operands of the structured op + /// have an identity indexing map. + StructuredOpMatcher &output(AllOperands tag, IsIdentity); + + /// Adds a predicate checking that the bit width of the elemental type of the + /// structured op output at the given position is equal to the given value. + StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width); + + /// Capture the elemental type bitwidth of output operand `position`. 
+ StructuredOpMatcher &output(int64_t position, + CaptureElementTypeBitWidth width); + + /// Capture the elemental type of output operand `position`. + StructuredOpMatcher &output(int64_t position, CaptureElementType elem); + + /// Adds a predicate checking that the output of the structured op is produced + /// by a reduction with a single-operation combinator (such as addf or mulf, + /// but not a compare+select pair). + StructuredOpMatcher &output(int64_t position, SingleCombinerReduction tag); + + /// Adds a predicate that recursively applies other predicates to the + /// operation defining the init/out operand corresponding to `position`-th + /// output. The position may be negative, in which case positions are counted + /// from the last one (i.e. Python-style). When the match is optional, the + /// predicate check succeeds as long as the `position` is in bounds. The + /// matcher executed if there is a defining operation for the output operand. + template + std::enable_if_t::value, + StructuredOpMatcher &> + output(int64_t position, T &operandMatcher, + OptionalMatch optional = OptionalMatch(false)) { + addOutputMatcher( + position, + [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, + optional); + recordNestedMatcher(operandMatcher); + return *this; + } + + //===-------------------------------------------------------------------===// + // Constraints on results. + //===-------------------------------------------------------------------===// + + /// Adds a predicate that recursively applies to users of the `position`-th + /// result of the structured op. Succeeds if any user matches the predicate. + /// When the match is optional, the predicate check succeeds as long as the + /// `position` is in bounds, after running the given matcher. 
+ template + std::enable_if_t::value, + StructuredOpMatcher &> + result(int64_t position, HasAnyUse tag, T &resultUserMatcher, + OptionalMatch optional = OptionalMatch(false)) { + addResultMatcher( + position, tag, + [&resultUserMatcher](Operation *op) { + return resultUserMatcher.match(op); + }, + optional); + recordNestedMatcher(resultUserMatcher); + return *this; + } + + //===-------------------------------------------------------------------===// + // Constraints on op region. + //===-------------------------------------------------------------------===// + + /// Return true if the linalg op only contains a single ops and the arguments + /// of the operation match the order of the linalg operand. + /// Example: + /// linalg.generic + /// ins(%0, %1 : tensor, tensor) + /// outs(%2 : tensor) { + /// ^bb0(%arg0: f32, %arg1: f32): + /// %3 = arith.maxf %arg0, %arg1 : f32 + /// linalg.yield %3 : f32 + /// } -> tensor + /// If commutative is set binary operations can have their operands swapped. + template + StructuredOpMatcher &singleOpWithCanonicaleArgs(bool commutative = false) { + return singleOpWithCanonicaleArgs(OpType::getOperationName(), commutative); + } + StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname, + bool commutative); + /// Check if the op is a linalg of with a single float reciprocal op. + StructuredOpMatcher &isFloatReciprocal(); + /// Check if the op is a linalg of with a region containing only a yield op + /// using block arguments in order. 
+ StructuredOpMatcher &passThroughOp(); + + /// Check if the body of the linalg op implements a contraction of the kind + /// result = input1 input2 + template + StructuredOpMatcher &hasContractionBody() { + return hasContractionBody( + [](Operation *op) { return isa(op); }, + [](Operation *op) { return isa(op); }, + ElemOpTy::getOperationName(), ReductionOpTy::getOperationName()); + } + +private: + /// Non-template implementations of nested predicate builders for inputs, + /// outputs and results. Should not be called directly. + void addInputMatcher(int64_t position, + std::function matcher, + OptionalMatch optional); + void addInputMatcher(int64_t position, std::function matcher, + OptionalMatch optional); + void addOutputMatcher(int64_t position, + std::function matcher, + OptionalMatch optional); + void addResultMatcher(int64_t position, HasAnyUse tag, + std::function matcher, + OptionalMatch optional); + + // Common util for constant matcher. + StructuredOpMatcher &input(int64_t position, + std::function floatValueFn); + + /// Non-template implementation of hasContractionBody. Takes callbacks for + /// checking operation kinds and names for error reporting. + StructuredOpMatcher & + hasContractionBody(function_ref isaElemOpTy, + function_ref isaReductionOpTy, + StringRef elemOpName, StringRef reductionOpName); +}; + +/// Creates a matcher of an arbitrary structured op. +inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { + return matcherContext.allocate(); +} + +/// Creates a matcher that is a copy of the given matcher. +inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext, + const StructuredOpMatcher &other) { + return matcherContext.allocate(other); +} + +/// Creates a matcher that accepts as disjunction of the two given matchers. 
+inline StructuredOpMatcher &m_StructuredOp_Or(MatcherContext &matcherContext, + StructuredOpMatcher &A, + StructuredOpMatcher &B) { + return matcherContext.allocate(A, B); +} + +/// Creates a matcher of a structured op with kinds provided as template +/// arguments. +template +inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { + return matcherContext.allocate( + StructuredOpMatcher::create()); +} + +//===---------------------------------------------------------------------===// +// MatchCallback functionality. +//===---------------------------------------------------------------------===// + +/// Additional results of the C++ callback usable in the `match_callback` +/// transform operation. Conceptually, a list of lists of payload operations to +/// be associated with each result handle. +class MatchCallbackResult { +public: + /// Returns the number of lists of payload operations. + int64_t getNumPayloadGroups() const { return payloadGroupLengths.size(); } + + /// Returns the `position`-th list of payload operations. + ArrayRef getPayloadGroup(int64_t position) const; + + /// Adds a new list of payload operations to the list of lists. The new list + /// must not contain null operations. + template + int64_t addPayloadGroup(Range operations) { + int64_t originalLength = payloadOperations.size(); + assert(llvm::all_of(operations, [](Operation *op) -> bool { return op; }) && + "null operation"); + llvm::append_range(payloadOperations, operations); + payloadGroupLengths.push_back(payloadOperations.size() - originalLength); + return payloadGroupLengths.size() - 1; + } + void addPayloadGroup(ArrayRef operations) { + addPayloadGroup>(operations); + } + + /// Adds a new singleton list of payload operation to the list of lists if the + /// operation is non-null, adds an empty list otherwise. Useful for results of + /// optional matches. 
+ void addPotentiallyEmptyPayloadGroup(Operation *op) { + if (!op) { + addPayloadGroup(ArrayRef()); + } else { + addPayloadGroup(ArrayRef(op)); + } + } + +private: + /// The flat list of all payload operations. `payloadGroupLengths` can be used + /// to compute the sublist that corresponds to one nested list. + // TODO: if somebody implements such a flattened vector generically, use it. + SmallVector payloadOperations; + SmallVector payloadGroupLengths; +}; + +/// A transform state extension that maintains the mapping between callback +/// names as strings usable in `match_callback` and their implementations. +class MatchCallbacksRegistry : public transform::TransformState::Extension { +public: + using MatchCallbackFn = std::function; + + /// Constructs the extension. + MatchCallbacksRegistry(transform::TransformState &state) + : transform::TransformState::Extension(state) {} + + /// Registers the given function as a callback with the given name. The name + /// must not be already present in the registry. The callback must be + /// convertible to MatchCallbackFn. + template + void registerCallback(StringRef name, Fn &&fn) { + bool succeeded = callbacks.try_emplace(name, std::forward(fn)).second; + (void)succeeded; + assert(succeeded && "adding a callback with a repeated name"); + } + + /// Returns a pointer to the implementation of the callback with the given + /// name, or null if it is not present in the registry. + const MatchCallbackFn *get(StringRef name) const { + auto iter = callbacks.find(name); + if (iter == callbacks.end()) { + return nullptr; + } + return &iter->getValue(); + } + +private: + llvm::StringMap callbacks; +}; + +//===---------------------------------------------------------------------===// +// Case-specific matcher builders. 
+//===---------------------------------------------------------------------===// + +struct MatchedReductionCaptures { + int64_t reductionRank = 0; + int64_t maybeLeadingRank = 0; + int64_t maybeTrailingRank = 0; + SmallVector leadingOpSizes = {}; + SmallVector reductionOpSizes = {}; + SmallVector trailingOpSizes = {}; + int64_t reductionOutputElementalTypeBitWidth = 0; + int64_t maybeLeadingOutputElementalTypeBitWidth = 0; + int64_t maybeTrailingOutputElementalTypeBitWidth = 0; +}; + +struct MatchedMatmulCaptures { + linalg::ContractionDimensions contractionDims = {}; + Type lhsElementType, rhsElementType, outputElementType; + SmallVector matmulOpSizes = {}; + SmallVector indexingMaps; + + /// Helper functions. + int64_t rank() const { return matmulOpSizes.size(); } + /// Return all batches. + ArrayRef batches() const { return contractionDims.batch; } + /// Return the most minor candidate dimension for `m`. + int64_t m() const { return contractionDims.m.back(); } + /// Return the most minor candidate dimension for `n`. + int64_t n() const { return contractionDims.n.back(); } + /// Return the most minor candidate dimension for `k`. + int64_t k() const { return contractionDims.k.back(); } + /// AffineMap for indexing into the LHS. + AffineMap lhsIndexing() const { + assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); + return indexingMaps[0]; + } + /// AffineMap for indexing into the RHS. + AffineMap rhsIndexing() const { + assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); + return indexingMaps[1]; + } + /// AffineMap for indexing into the RES. + AffineMap resIndexing() const { + assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); + return indexingMaps[2]; + } +}; + +/// Creates a group of matchers for: +/// +/// trailing(reduction(leading(), fill())) +/// +/// where trailing and leading are elementwise operations whose presence is +/// optional. Each matcher will capture the corresponding operation. 
If +/// `mustMatchEntireFunc` is set, the matcher additionally checks if all +/// tileable operations in the functions are captured. +void makeReductionMatcher(MatcherContext &context, + StructuredOpMatcher *&reductionCapture, + StructuredOpMatcher *&fillCapture, + StructuredOpMatcher *&leadingCapture, + StructuredOpMatcher *&trailingCapture, + MatchedReductionCaptures &captures, + bool mustMatchEntireFunc); +void makeReductionMatcher(MatcherContext &context, + StructuredOpMatcher *&reductionCapture, + MatchedReductionCaptures &captures, + bool mustMatchEntireFunc); +/// +/// trailing(matmul(*, *, fill())) +/// +/// where trailing and leading are elementwise operations whose presence is +/// optional. Each matcher will capture the corresponding operation. If +/// `mustMatchEntireFunc` is set, the matcher additionally checks if all +/// tileable operations in the functions are captured. +void makeMatmulMatcher(MatcherContext &matcherContext, + StructuredOpMatcher *&matmulCapture, + StructuredOpMatcher *&fillCapture, + StructuredOpMatcher *&trailingCapture, + MatchedMatmulCaptures &captures, + bool mustMatchEntireFunc); + +/// Create a group of matchers of batch mamtul with a fill: +/// +/// batch_matmul(*, *, fill()) +/// +/// and capture various useful quantities. If `mustMatchEntireFunc` is set, the +/// matcher additionally checks if all tileable operations in the functions are +/// captured. +void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext, + transform_ext::StructuredOpMatcher *&bmmCapture, + transform_ext::StructuredOpMatcher *&fillCapture, + transform_ext::MatchedMatmulCaptures &captures, + bool mustMatchEntireFunc); + +/// Create a group of matchers for a different code sequence of operations +/// matching exactly a softmax operation. 
+/// +/// %red = reduce_max(%0) +/// %sub = sub(%0, %red) +/// %exp = exp(%sub) +/// %sum = reduce_sum(%exp) +/// %mul = div(%exp, %%sum) +void makeSoftmaxMatcher(MatcherContext &context, + StructuredOpMatcher *&maxReductionCapture, + StructuredOpMatcher *&softmaxRootCapture); + +struct MatchedConvolutionCaptures { + Type inputElementType, filterElementType, outputElementType; + mlir::linalg::ConvolutionDimensions convolutionDims = {}; + SmallVector convolutionOpSizes = {}; + SmallVector trailingOpSizes = {}; + int64_t maybeTrailingOutputElementalTypeBitWidth = 0; + int64_t maybeFillElementalTypeBitWidth = 0; +}; + +/// Creates a group of matchers for: +/// +/// trailing(convolution(input, filter, fill())) +/// +/// where fill is a FillOp and trailing is an elementwise operation, both of +/// which is optional. Each matcher will capture the corresponding operation. If +/// `mustMatchEntireFunc` is set, the matcher additionally checks if all +/// tileable operations in the functions are captured. +void makeConvolutionMatcher(MatcherContext &context, + StructuredOpMatcher *&convolutionCapture, + StructuredOpMatcher *&fillCapture, + StructuredOpMatcher *&trailingCapture, + MatchedConvolutionCaptures &captures, + bool mustMatchEntireFunc); +void makeConvolutionMatcher(MatcherContext &context, + StructuredOpMatcher *&convolutionCapture, + MatchedConvolutionCaptures &captures, + bool mustMatchEntireFunc); + +struct MatchedPadCaptures { + int64_t rank = 0; + Type elementType; + SmallVector dims = {}; +}; + +/// Create a matcher for tensor.pad(*) without leading or trailing ops atm. +/// If `mustMatchEntireFunc` is set, the matcher additionally checks if all +/// tileable operations in the functions are captured. +void makePadMatcher(MatcherContext &context, CapturingOpMatcher *&padCapture, + MatchedPadCaptures &captures, bool mustMatchEntireFunc); + +/// Wraps the given matcher callback to indicate that it must capture all +/// tilable ops in the parent function. 
Expects the callback to accept the same +/// arguments as what is expected by MatchCallbacksRegistry::register, followed +/// by a bool. +template +auto wrapAsEntireFuncMatch(Fn &&fn) { + return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, + const mlir::transform::TransformState &state, + ValueRange handles) { + return fn(res, loc, state, handles, true); + }; +} + +/// Wraps the given matcher callback to indicate that it can match subgraphs. +/// Expects the callback to accept the same arguments as what is expected by +/// MatchCallbacksRegistry::register, followed by a bool. +template +auto wrapAsPartialMatch(Fn &&fn) { + return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, + const mlir::transform::TransformState &state, + ValueRange handles) { + return fn(res, loc, state, handles, false); + }; +} + +} // namespace transform_ext +} // namespace mlir + +#endif // IREE_COMPILER_GLOBALOPTIMIZATION_TRANSFORMMATCHERS_H_ From 971c671726f7a2c4fffdf839b3031278d41f66f2 Mon Sep 17 00:00:00 2001 From: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> Date: Wed, 27 May 2026 12:43:54 -0400 Subject: [PATCH 2/8] [Codegen] Move ErrorCheckingTrackingListener out of iree-dialects (#24466) Relocate ErrorCheckingTrackingListener into Codegen/Common (namespace mlir::iree_compiler) and repoint its users in the Common and LLVMGPU transform extensions. Drop the StructuredTransformOpsExtension registration and includes from CommonDialectRegistration and Interfaces: its transform ops (transform.iree.register_match_callbacks/match_callback/take_first/emit_remark) have no production users and their tests were removed in #24500. Build deps are updated accordingly. Progress toward retiring the iree-dialects dependency. 
Signed-off-by: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> --- .../iree/compiler/Codegen/Common/BUILD.bazel | 5 +- .../compiler/Codegen/Common/CMakeLists.txt | 5 +- .../Common/CommonDialectRegistration.cpp | 4 +- .../Common/ErrorCheckingTrackingListener.cpp | 40 ++++++++++++++++ .../Common/ErrorCheckingTrackingListener.h | 48 +++++++++++++++++++ .../Common/TransformExtensions/BUILD.bazel | 2 - .../Common/TransformExtensions/CMakeLists.txt | 2 - .../TransformExtensions/CommonExtensions.cpp | 4 +- .../compiler/Codegen/Interfaces/BUILD.bazel | 1 - .../Codegen/Interfaces/CMakeLists.txt | 1 - .../Codegen/Interfaces/Interfaces.cpp | 8 +--- .../LLVMGPU/TransformExtensions/BUILD.bazel | 2 - .../TransformExtensions/CMakeLists.txt | 2 - .../TransformExtensions/LLVMGPUExtensions.cpp | 3 +- 14 files changed, 98 insertions(+), 29 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 81167e58463b..fa6dd7e8158d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -103,6 +103,7 @@ iree_compiler_cc_library( "EncodingUtils.cpp", "EraseDeadAllocAndStores.cpp", "EraseHALDescriptorTypeFromMemRef.cpp", + "ErrorCheckingTrackingListener.cpp", "FastMathPatterns.cpp", "FissionTransferOpsInControlFlow.cpp", "FlattenMemRefSubspan.cpp", @@ -190,6 +191,7 @@ iree_compiler_cc_library( "CombineLayoutTransformation.h", "EmulateNarrowType.h", "EncodingUtils.h", + "ErrorCheckingTrackingListener.h", "FastMathPatterns.h", "PassUtils.h", "Passes.h", @@ -234,7 +236,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Utils", - 
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", @@ -332,7 +333,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", @@ -377,7 +377,6 @@ iree_compiler_cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:DialectUtils", # TransformExtensions (needed for registration in the pass) - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", "//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions", "//compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions:LLVMCPUExtensions", "//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index a3f5d46293dd..0b38a523d4c4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -56,6 +56,7 @@ iree_cc_library( "CombineLayoutTransformation.h" "EmulateNarrowType.h" "EncodingUtils.h" + "ErrorCheckingTrackingListener.h" "FastMathPatterns.h" "PassUtils.h" "Passes.h" @@ -96,6 +97,7 @@ iree_cc_library( "EncodingUtils.cpp" "EraseDeadAllocAndStores.cpp" "EraseHALDescriptorTypeFromMemRef.cpp" + "ErrorCheckingTrackingListener.cpp" "FastMathPatterns.cpp" "FissionTransferOpsInControlFlow.cpp" "FlattenMemRefSubspan.cpp" @@ -180,7 +182,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREELinalgTransformDialect LLVMCore LLVMSupport 
MLIRAMDGPUDialect @@ -287,8 +288,6 @@ iree_cc_library( ::Common ::PassHeaders ::PassesIncGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAMDGPUTransforms MLIRAffineDialect diff --git a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp index d7187490c02c..69733a6d8b3b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" @@ -97,8 +96,7 @@ void registerTransformDialectTranslationDependentDialects( vector::registerBufferizableOpInterfaceExternalModels(registry); registry.addExtensions< - mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension, - transform_ext::StructuredTransformOpsExtension>(); + mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension>(); iree_compiler::registerTransformDialectCommonExtension(registry); iree_compiler::registerTransformDialectLLVMCPUExtension(registry); iree_compiler::registerTransformDialectLLVMGPUExtension(registry); diff --git a/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp new file mode 100644 index 000000000000..6d4353ab1a43 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp @@ -0,0 +1,40 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "iree-codegen-error-checking-tracking-listener" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir::iree_compiler { + +void ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( + Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { + // Certain ops can dropped safely. + if (isa(op)) { + LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n"); + return; + } + + SmallVector diags; + diag.takeDiagnostics(diags); + if (!status.succeeded()) { + status.takeDiagnostics(diags); + } + status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); + + status = emitSilenceableFailure( + getTransformOp(), "!!! tracking listener failed to find replacement op"); + status.attachNote(op->getLoc()) << "replaced op"; + for (Value v : values) { + status.attachNote(v.getLoc()) << "replacement value"; + } +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h new file mode 100644 index 000000000000..35ea97d88136 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h @@ -0,0 +1,48 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_COMMON_ERRORCHECKINGTRACKINGLISTENER_H_ +#define IREE_COMPILER_CODEGEN_COMMON_ERRORCHECKINGTRACKINGLISTENER_H_ + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +namespace mlir::iree_compiler { + +/// A tracking listener for tensor IR that checks for payload replacement +/// errors. +class ErrorCheckingTrackingListener : public transform::TrackingListener { +public: + using transform::TrackingListener::TrackingListener; + + ~ErrorCheckingTrackingListener() override { + assert(status.succeeded() && "must check listener error state"); + } + + /// Return "true" if this tracking listener had a failure. + bool failed() const { return !status.succeeded(); } + + /// Check and return the current error state of this listener. In case of a + /// failure state, only the most recent error is returned. Afterwards, resets + /// the error state. + DiagnosedSilenceableFailure checkAndResetError() { + DiagnosedSilenceableFailure result(std::move(status)); + status = DiagnosedSilenceableFailure::success(); + return result; + } + +private: + void + notifyPayloadReplacementNotFound(Operation *op, ValueRange values, + DiagnosedSilenceableFailure &&diag) override; + + /// The error state of this listener. "Success" indicates that no error + /// happened so far. Otherwise, the status contains the most recent error. 
+ DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); +}; + +} // namespace mlir::iree_compiler + +#endif // IREE_COMPILER_CODEGEN_COMMON_ERRORCHECKINGTRACKINGLISTENER_H_ diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel index 296f03528d75..6601bda95e5e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel @@ -75,8 +75,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt index 2f91228d2fde..a98181a05315 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "CommonExtensionsOps.cpp.inc" DEPS ::CommonExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAffineDialect MLIRAffineUtils diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index b630323b05b1..8113593b4b6e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -6,8 +6,7 @@ #include 
"CommonExtensions.h" -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" -#include "iree-dialects/Transforms/TransformMatchers.h" +#include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" @@ -48,6 +47,7 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel index bcba86ba95cd..7ff9311888eb 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel @@ -61,7 +61,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces:ExternalModels", "//compiler/src/iree/compiler/Codegen/ExternalInterfaces:ExternalModels", "//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:SideEffectInterfaces", diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt index c22bc95db95d..646bba0bd068 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt @@ -25,7 +25,6 @@ iree_cc_library( ::TensorMaskingOpInterface ::UKernelOpInterface ::VectorizableOpInterface - 
IREELinalgTransformDialect MLIRAMDGPUDialect MLIRAffineDialect MLIRAffineTransformOps diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp index e21cde562ec4..67db978cdd7e 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp @@ -18,9 +18,6 @@ #include "iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -// TODO: Remove this dependency once the transform dialect extensions -// have a better registration mechanism. -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Interfaces/TensorMaskingOpInterface.h" #include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h" @@ -90,11 +87,8 @@ void registerCodegenInterfaces(DialectRegistry ®istry) { registerPCFExternalInterfaces(registry); registerBufferizationInterfaces(registry); registerTensorMaskingOpInterface(registry); - // TODO: Remove this dependency once the transform dialect extensions - // have a better registration mechanism. // TODO: when warranted, move to its own file. 
- registry.addExtensions(); + registry.addExtensions(); registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { linalg::GenericOp::attachInterface(*ctx); }); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel index df19585b67ab..3ee6e782ff41 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel @@ -66,8 +66,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms", "//compiler/src/iree/compiler/Codegen/LLVMGPU/Utils", "//compiler/src/iree/compiler/Codegen/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AffineDialect", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt index e7ff813e374e..a98db3845cf4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "LLVMGPUExtensionsOps.cpp.inc" DEPS ::LLVMGPUExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAMDGPUDialect MLIRAffineDialect diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 36a59eca50a7..04387136a2bc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ 
b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -6,7 +6,7 @@ #include "LLVMGPUExtensions.h" -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" +#include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" @@ -29,6 +29,7 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" From f5aa584cf6128191e2019f75a036491aaec2abc9 Mon Sep 17 00:00:00 2001 From: Alex-Wengg Date: Sat, 30 May 2026 22:34:23 -0400 Subject: [PATCH 3/8] [NFC] Apply clang-format include ordering fixes Pure reordering of include blocks in four files to satisfy the pre-commit clang-format hook (CI lint job was failing on this PR): - Codegen/Common/ErrorCheckingTrackingListener.cpp - Codegen/Interfaces/Interfaces.cpp - GlobalOptimization/TransformMatchers.cpp - GlobalOptimization/TransformMatchers.h clang-format wants llvm/* headers ordered before mlir/* within each include block (alphabetical). No removals, no behavior change. 
Signed-off-by: Alex-Wengg --- .../Codegen/Common/ErrorCheckingTrackingListener.cpp | 2 +- .../src/iree/compiler/Codegen/Interfaces/Interfaces.cpp | 8 ++++---- .../compiler/GlobalOptimization/TransformMatchers.cpp | 6 +++--- .../iree/compiler/GlobalOptimization/TransformMatchers.h | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp index 6d4353ab1a43..91285ec78cd3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.cpp @@ -6,8 +6,8 @@ #include "iree/compiler/Codegen/Common/ErrorCheckingTrackingListener.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "llvm/Support/Debug.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #define DEBUG_TYPE "iree-codegen-error-checking-tracking-listener" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp index 67db978cdd7e..69483e8aa7f0 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Interfaces/Interfaces.h" +#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/Interfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.h" #include "iree/compiler/Codegen/Dialect/Map/ExternalInterfaces/Interfaces.h" @@ -15,11 +16,8 @@ #include "iree/compiler/Codegen/Interfaces/HoistableRegionOpInterface.h" #include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h" #include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h" -#include 
"iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" #include "iree/compiler/Codegen/Interfaces/TensorMaskingOpInterface.h" +#include "iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h" #include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h" #include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h" #include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.h" @@ -30,6 +28,7 @@ #include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" #include "mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" @@ -47,6 +46,7 @@ #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h" #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir::iree_compiler { diff --git a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp index 9a4160fc3588..2c7ce728e0a9 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp @@ -6,15 +6,15 @@ #include "iree/compiler/GlobalOptimization/TransformMatchers.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include 
"mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Interfaces/FunctionInterfaces.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" using namespace mlir; diff --git a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h index 6de72611a792..49d33f83c07d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h +++ b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h @@ -11,12 +11,12 @@ #include #include +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/Matchers.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/StringMap.h" namespace mlir { namespace transform_ext { From a1ac23bd0bfa86566c1a803c4ffa18a23db686c4 Mon Sep 17 00:00:00 2001 From: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:23:38 -0400 Subject: [PATCH 4/8] [GlobalOpt] Reimplement softmax matcher natively (#24466) Replace the wholesale relocation of the iree-dialects TransformMatchers DSL with a self-contained, native matcher in RaiseSpecialOps.cpp, per the review feedback that this should port only what RaiseSpecialOps needs rather than carry the generic StructuredOpMatcher framework into GlobalOptimization. - Add a local matchSoftmax() plus small helpers that walk the softmax linalg-op graph directly (reduce_max -> sub -> exp -> reduce_add -> mul/reciprocal or div), handling both implicit (projected map) and explicit (pass-through generic) broadcasts. This is behaviorally faithful to makeSoftmaxMatcher, including the same-source invariant. 
- Delete GlobalOptimization/TransformMatchers.{h,cpp} (~3000 lines) and drop the transform-dialect build deps that only existed for them. The iree-dialects deps stay removed. - Add negative lit cases (wrong max init, mismatched source) alongside the existing softmax raising tests. Signed-off-by: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> --- .../compiler/GlobalOptimization/BUILD.bazel | 8 - .../GlobalOptimization/CMakeLists.txt | 8 - .../GlobalOptimization/RaiseSpecialOps.cpp | 291 ++- .../GlobalOptimization/TransformMatchers.cpp | 1845 ----------------- .../GlobalOptimization/TransformMatchers.h | 1201 ----------- .../test/raise_special_ops.mlir | 109 + 6 files changed, 391 insertions(+), 3071 deletions(-) delete mode 100644 compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp delete mode 100644 compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index 88fdc328a7c3..03dd920131e8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -71,13 +71,11 @@ iree_compiler_cc_library( "RaiseSpecialOps.cpp", "RemoveZeroExtentTensors.cpp", "SimplifyPackUnpack.cpp", - "TransformMatchers.cpp", "Utils.cpp", "WarnOnUninitializedValues.cpp", ], hdrs = [ "Passes.h", - "TransformMatchers.h", "Utils.h", ], deps = [ @@ -107,31 +105,25 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgInterfaces", 
"@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgUtils", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:TransformDialectInterfaces", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 7d98a7ade9bf..050b3c950fa7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -38,7 +38,6 @@ iree_cc_library( GlobalOptimization HDRS "Passes.h" - "TransformMatchers.h" "Utils.h" SRCS "CleanupNumericNarrowing.cpp" @@ -63,7 +62,6 @@ iree_cc_library( "RaiseSpecialOps.cpp" "RemoveZeroExtentTensors.cpp" "SimplifyPackUnpack.cpp" - "TransformMatchers.cpp" "Utils.cpp" "WarnOnUninitializedValues.cpp" DEPS @@ -71,30 +69,24 @@ iree_cc_library( ::PassesIncGen LLVMSupport MLIRAffineDialect - MLIRAnalysis MLIRArithDialect MLIRArithUtils MLIRControlFlowDialect - MLIRFuncDialect MLIRFunctionInterfaces MLIRIR MLIRLinalgDialect - MLIRLinalgInterfacesIncGenLib MLIRLinalgTransforms MLIRLinalgUtils MLIRMathDialect MLIRMemRefDialect MLIRMemRefTransforms MLIRPass - MLIRRewrite MLIRSCFDialect MLIRSCFTransforms MLIRSupport MLIRTensorDialect MLIRTensorTransforms MLIRTensorUtils - MLIRTransformDialect - MLIRTransformDialectInterfaces MLIRTransformUtils MLIRTransforms iree::compiler::Codegen::Common diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp 
b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index 658b1e9ade29..cb9a999d17dd 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -10,7 +10,6 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/GlobalOptimization/Passes.h" -#include "iree/compiler/GlobalOptimization/TransformMatchers.h" #include "iree/compiler/GlobalOptimization/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" @@ -19,6 +18,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinAttributes.h" @@ -28,7 +28,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; -using transform_ext::StructuredOpMatcher; namespace mlir::iree_compiler::GlobalOptimization { @@ -374,6 +373,284 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { // Softmax Raising //===----------------------------------------------------------------------===// +// Recognizing a numerically-stabilized softmax means walking a small graph of +// `linalg.generic` ops. The helpers below match the individual nodes of that +// graph directly, replacing the generic `transform_ext` matcher framework that +// used to live in iree-dialects. 
The matched dataflow, working back from the +// root op, is: +// +// max = reduce_max(src) (reduction over the innermost dim) +// sub = src - broadcast(max) +// exp = exp(sub) +// sum = reduce_add(exp) (reduction over the innermost dim) +// root = exp * reciprocal(sum) OR exp / broadcast(sum) +// +// where each broadcast is either implicit (the consumer indexing map drops the +// innermost dim) or explicit (a pass-through `linalg.generic` that broadcasts +// along the innermost dim). + +// Returns true if every iterator of `op` is parallel. +static bool isAllParallel(linalg::GenericOp op) { + return llvm::all_of(op.getIteratorTypesArray(), [](utils::IteratorType t) { + return t == utils::IteratorType::parallel; + }); +} + +// Returns true if `operand`'s indexing map is the identity. +static bool isIdentityOperand(linalg::GenericOp op, OpOperand *operand) { + return op.getMatchingIndexingMap(operand).isIdentity(); +} + +// Returns true if `operand`'s indexing map is the projection that drops the +// innermost (last) loop dim, e.g. (d0, d1, d2) -> (d0, d1). +static bool dropsInnermostDim(linalg::GenericOp op, OpOperand *operand) { + AffineMap map = op.getMatchingIndexingMap(operand); + unsigned numLoops = op.getNumLoops(); + if (!map.isProjectedPermutation() || map.getNumResults() + 1 != numLoops) { + return false; + } + for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { + if (map.getDimPosition(i) != i) { + return false; + } + } + return true; +} + +// Returns the single compute op of `op`'s body when the body is exactly that op +// followed by a `linalg.yield` of its single result; otherwise nullptr. 
+static Operation *getSingleComputeOp(linalg::GenericOp op) { + Block *body = op.getBlock(); + if (body->getOperations().size() != 2) { + return nullptr; + } + Operation *computeOp = &body->front(); + auto yieldOp = cast(body->getTerminator()); + if (computeOp->getNumResults() != 1 || yieldOp.getNumOperands() != 1 || + yieldOp.getOperand(0).getDefiningOp() != computeOp) { + return nullptr; + } + return computeOp; +} + +// Returns true if `v` references block argument number `argNumber` of `block`. +static bool isBlockArg(Value v, Block *block, unsigned argNumber) { + auto arg = dyn_cast(v); + return arg && arg.getOwner() == block && arg.getArgNumber() == argNumber; +} + +// Returns true if `op`'s body is a single `OpTy` whose operands are the leading +// block arguments in order. When `commutative`, a two-operand body may also use +// the arguments in swapped order. +template +static bool hasSingleOpBody(linalg::GenericOp op, bool commutative = false) { + Operation *computeOp = getSingleComputeOp(op); + if (!isa_and_nonnull(computeOp)) { + return false; + } + Block *body = op.getBlock(); + if (commutative && computeOp->getNumOperands() == 2) { + return (isBlockArg(computeOp->getOperand(0), body, 0) && + isBlockArg(computeOp->getOperand(1), body, 1)) || + (isBlockArg(computeOp->getOperand(0), body, 1) && + isBlockArg(computeOp->getOperand(1), body, 0)); + } + for (auto [index, operand] : llvm::enumerate(computeOp->getOperands())) { + if (!isBlockArg(operand, body, index)) { + return false; + } + } + return true; +} + +// Returns true if `op`'s body just yields its (single) block argument, i.e. it +// is a pure broadcast/transpose of the input. 
+static bool isPassThroughBody(linalg::GenericOp op) { + Block *body = op.getBlock(); + if (body->getOperations().size() != 1) { + return false; + } + auto yieldOp = cast(body->getTerminator()); + for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) { + if (!isBlockArg(operand, body, index)) { + return false; + } + } + return true; +} + +// Returns true if `v` is a float constant satisfying `pred`. +static bool isFloatConstant(Value v, llvm::function_ref pred) { + auto cstOp = v.getDefiningOp(); + return cstOp && pred(cstOp.value()); +} + +// Returns true if `v` is a `linalg.fill` whose scalar value is a float constant +// satisfying `pred`. +static bool isFilledWith(Value v, llvm::function_ref pred) { + auto fillOp = v.getDefiningOp(); + return fillOp && isFloatConstant(fillOp.getInputs()[0], pred); +} + +// Matches a single-input `linalg.generic` reduction over the innermost dim with +// a commutative `OpTy` body and an init produced by a `linalg.fill` of a +// constant satisfying `fillPred`. Returns the reduced value on success. +template +static Value +matchInnermostReduction(Value v, llvm::function_ref fillPred) { + auto op = v.getDefiningOp(); + if (!op || op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1) { + return {}; + } + SmallVector iterators = op.getIteratorTypesArray(); + if (iterators.empty() || iterators.back() != utils::IteratorType::reduction) { + return {}; + } + for (unsigned i = 0, e = iterators.size() - 1; i < e; ++i) { + if (iterators[i] != utils::IteratorType::parallel) { + return {}; + } + } + OpOperand *input = op.getDpsInputOperand(0); + OpOperand *init = op.getDpsInitOperand(0); + if (!isIdentityOperand(op, input) || !dropsInnermostDim(op, init) || + !hasSingleOpBody(op, /*commutative=*/true) || + !isFilledWith(init->get(), fillPred)) { + return {}; + } + return input->get(); +} + +// Matches a single-input all-parallel `linalg.generic` with identity-mapped +// input and output and an `OpTy` body. 
Returns the input value on success. +template +static Value matchUnaryElementwise(Value v) { + auto op = v.getDefiningOp(); + if (!op || op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1 || + !isAllParallel(op)) { + return {}; + } + if (!isIdentityOperand(op, op.getDpsInputOperand(0)) || + !isIdentityOperand(op, op.getDpsInitOperand(0)) || + !hasSingleOpBody(op)) { + return {}; + } + return op.getDpsInputOperand(0)->get(); +} + +// Matches a reciprocal `linalg.generic` (body `divf 1.0, %arg0`) that is +// single-input, all-parallel and identity-mapped. Returns the input value. +static Value matchReciprocal(Value v) { + auto op = v.getDefiningOp(); + if (!op || op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1 || + !isAllParallel(op)) { + return {}; + } + if (!isIdentityOperand(op, op.getDpsInputOperand(0)) || + !isIdentityOperand(op, op.getDpsInitOperand(0))) { + return {}; + } + auto divOp = dyn_cast_or_null(getSingleComputeOp(op)); + if (!divOp || + !isFloatConstant(divOp.getOperand(0), + [](APFloat f) { return f.convertToDouble() == 1.0; }) || + !isBlockArg(divOp.getOperand(1), op.getBlock(), 0)) { + return {}; + } + return op.getDpsInputOperand(0)->get(); +} + +// Resolves `operand` of `consumer`, which is expected to carry the broadcast of +// an innermost-reduced value, to that pre-broadcast value. Handles both the +// implicit broadcast (the operand's map drops the innermost dim) and the +// explicit broadcast (the operand is produced by a pass-through generic that +// broadcasts along the innermost dim). 
+static Value resolveBroadcastedReduction(linalg::GenericOp consumer, + OpOperand *operand) { + if (dropsInnermostDim(consumer, operand)) { + return operand->get(); + } + if (!isIdentityOperand(consumer, operand)) { + return {}; + } + auto bcast = operand->get().getDefiningOp(); + if (!bcast || bcast.getNumDpsInputs() != 1 || bcast.getNumDpsInits() != 1 || + !isAllParallel(bcast) || !isPassThroughBody(bcast) || + !dropsInnermostDim(bcast, bcast.getDpsInputOperand(0)) || + !isIdentityOperand(bcast, bcast.getDpsInitOperand(0))) { + return {}; + } + return bcast.getDpsInputOperand(0)->get(); +} + +// Recognizes the softmax graph rooted at `rootOp` and returns the softmax +// source value on success. +static FailureOr matchSoftmax(linalg::LinalgOp rootOp) { + auto root = dyn_cast(rootOp.getOperation()); + if (!root || root.getNumDpsInputs() != 2 || root.getNumDpsInits() != 1 || + !isAllParallel(root)) { + return failure(); + } + // The numerator (exp(...)) always flows in as input 0 with an identity map. + OpOperand *numerator = root.getDpsInputOperand(0); + OpOperand *denominator = root.getDpsInputOperand(1); + if (!isIdentityOperand(root, numerator)) { + return failure(); + } + + // The root normalizes either as `exp * reciprocal(sum)` or `exp / sum`. + Value sumValue; + if (hasSingleOpBody(root, /*commutative=*/true)) { + if (!dropsInnermostDim(root, denominator)) { + return failure(); + } + sumValue = matchReciprocal(denominator->get()); + } else if (hasSingleOpBody(root)) { + sumValue = resolveBroadcastedReduction(root, denominator); + } + if (!sumValue) { + return failure(); + } + + // exp = exp(sub); sum = reduce_add(exp). + Value expValue = numerator->get(); + Value subValue = matchUnaryElementwise(expValue); + Value summedValue = matchInnermostReduction( + sumValue, [](APFloat f) { return f.isZero(); }); + if (!subValue || summedValue != expValue) { + return failure(); + } + + // sub = src - broadcast(max). 
+ auto subOp = subValue.getDefiningOp(); + if (!subOp || subOp.getNumDpsInputs() != 2 || subOp.getNumDpsInits() != 1 || + !isAllParallel(subOp) || + !isIdentityOperand(subOp, subOp.getDpsInitOperand(0)) || + !hasSingleOpBody(subOp)) { + return failure(); + } + OpOperand *subSource = subOp.getDpsInputOperand(0); + if (!isIdentityOperand(subOp, subSource)) { + return failure(); + } + Value maxValue = + resolveBroadcastedReduction(subOp, subOp.getDpsInputOperand(1)); + if (!maxValue) { + return failure(); + } + + // max = reduce_max(src), reducing the same source the subtraction reads. + Value source = subSource->get(); + Value reducedValue = + matchInnermostReduction(maxValue, [](APFloat f) { + return (f.isLargest() || f.isInfinity()) && f.isNegative(); + }); + if (reducedValue != source) { + return failure(); + } + return source; +} + class RaiseSoftmax : public OpInterfaceRewritePattern { public: using OpInterfaceRewritePattern::OpInterfaceRewritePattern; @@ -386,19 +663,15 @@ class RaiseSoftmax : public OpInterfaceRewritePattern { return failure(); } - transform_ext::MatcherContext matcherContext; - transform_ext::StructuredOpMatcher *maxReduction; - transform_ext::StructuredOpMatcher *softmaxroot; - makeSoftmaxMatcher(matcherContext, maxReduction, softmaxroot); - if (!matchPattern(linalgOp, *softmaxroot)) { + FailureOr src = matchSoftmax(linalgOp); + if (failed(src)) { return rewriter.notifyMatchFailure(linalgOp, "failed to match softmax root"); } - Value src = maxReduction->getCaptured()->getOperand(0); rewriter.setInsertionPoint(linalgOp); rewriter.replaceOpWithNewOp( - linalgOp, linalgOp->getResultTypes(), src, + linalgOp, linalgOp->getResultTypes(), *src, linalgOp.getDpsInitOperand(0)->get(), linalgOp.getNumLoops() - 1); return success(); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp deleted file mode 100644 index 2c7ce728e0a9..000000000000 --- 
a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.cpp +++ /dev/null @@ -1,1845 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/GlobalOptimization/TransformMatchers.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Interfaces/FunctionInterfaces.h" - -using namespace mlir; - -#define DEBUG_TYPE "transform-matchers" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " -#define DBGSNL() llvm::dbgs() << "\n[" DEBUG_TYPE "] " - -//===---------------------------------------------------------------------===// -// CapturingMatcherBase -//===---------------------------------------------------------------------===// - -void transform_ext::CapturingMatcherBase::getAllNested( - SmallVectorImpl &nested) { - - SetVector found; - found.insert(nested.begin(), nested.end()); - int64_t start = found.size(); - - auto appendOne = [&found](CapturingMatcherBase &one) { - found.insert(one.nestedCapturingMatchers.begin(), - one.nestedCapturingMatchers.end()); - for (CapturingValueMatcher *valueMatcher : - one.nestedCapturingValueMatchers) { - found.insert(valueMatcher->nestedCapturingMatchers.begin(), - valueMatcher->nestedCapturingMatchers.end()); - } - }; - - appendOne(*this); - for (int64_t position = start; position < found.size(); ++position) { - appendOne(*found[position]); - } - - llvm::append_range(nested, found.getArrayRef()); -} - -void transform_ext::CapturingMatcherBase::getAllNestedValueMatchers( - SmallVectorImpl &nested) { - - SetVector found; - 
found.insert(nested.begin(), nested.end()); - int64_t start = found.size(); - - auto appendOne = [&found](CapturingMatcherBase &one) { - found.insert(one.nestedCapturingValueMatchers.begin(), - one.nestedCapturingValueMatchers.end()); - for (CapturingOpMatcher *opMatcher : one.nestedCapturingMatchers) { - found.insert(opMatcher->nestedCapturingValueMatchers.begin(), - opMatcher->nestedCapturingValueMatchers.end()); - } - }; - - appendOne(*this); - for (int64_t position = start; position < found.size(); ++position) { - appendOne(*found[position]); - } - - llvm::append_range(nested, found.getArrayRef()); -} - -void transform_ext::CapturingMatcherBase::resetCapture() { - SmallVector nested; - getAllNested(nested); - for (CapturingOpMatcher *matcher : nested) { - matcher->captured = nullptr; - } - SmallVector nestedValue; - getAllNestedValueMatchers(nestedValue); - for (CapturingValueMatcher *matcher : nestedValue) { - matcher->captured = nullptr; - } -} - -//===---------------------------------------------------------------------===// -// CapturingOpMatcher -//===---------------------------------------------------------------------===// - -bool transform_ext::CapturingOpMatcher::checkAllTilableMatched( - Operation *parent, Operation *op, - ArrayRef matchers) { - LLVM_DEBUG(DBGS() << "all tilable ops captured"); - int64_t numTilableOps = 0; - if (!parent) { - return false; - } - parent->walk([&](TilingInterface Op) { ++numTilableOps; }); - - llvm::SmallPtrSet matched; - for (CapturingOpMatcher *nested : matchers) { - if (Operation *captured = nested->getCaptured()) { - matched.insert(captured); - } - } - - // Don't forget to include the root matcher. 
- matched.insert(op); - return numTilableOps == matched.size(); -} - -bool transform_ext::CapturingOpMatcher::match(Operation *op) { - auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); - LLVM_DEBUG(DBGS() << "matching: " << *op << "\n"); - - if (getCaptured()) { - LLVM_DEBUG(DBGS() << "found an already captured op: "); - if (getCaptured() == op) { - LLVM_DEBUG(llvm::dbgs() << "same\n"); - return true; - } else { - LLVM_DEBUG(llvm::dbgs() << "different\n"); - return false; - } - } - - if (!llvm::all_of(predicates, [op](const PredicateFn &fn) { - bool result = fn(op); - LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); - return result; - })) { - return false; - } - - captured = op; - return true; -} - -void transform_ext::CapturingOpMatcher::debugOutputForCreate( - ArrayRef opNames) { - LLVM_DEBUG(DBGS() << "operation type is one of {"; - llvm::interleaveComma(opNames, llvm::dbgs()); llvm::dbgs() << "}"); -} - -/// Apply the given matcher to the given object, produce debug messages. -template ::template args<0>> -static bool recursiveMatch(Matcher &matcher, Object &object, - StringRef extraMessage = "") { - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "] " << "start recursive match (" - << extraMessage << ") {\n"); - bool result = matcher.match(object); - LLVM_DEBUG(DBGS() << "} end recursive match"); - return result; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::alternatives( - transform_ext::CapturingOpMatcher &first, - transform_ext::CapturingOpMatcher &second) { - addPredicate([&first, &second](Operation *op) { - LLVM_DEBUG(DBGS() << "matching alternatives\n"); - return recursiveMatch(first, op, "alternative 1") || - recursiveMatch(second, op, "alternative 2"); - }); - return *this; -} - -//---------------------------------------------------------------------------// -// Predicates for operands and results. 
-//---------------------------------------------------------------------------// - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(transform_ext::NumEqualsTo num) { - addPredicate([=](Operation *op) { - LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " operands"); - return num.value == op->getNumOperands(); - }); - return *this; -} - -/// If `pos` is negative, returns the number of the operand in op starting from -/// the last. For example, -1 means the last operand, -2 means the -/// second-to-last, etc. Returns nullopt if pos is out-of-bounds, both positive -/// and negative. -static std::optional remapNegativeOperandNumber(int64_t pos, - Operation *op) { - int64_t updated = pos < 0 ? op->getNumOperands() + pos : pos; - if (updated < 0 || updated >= op->getNumOperands()) { - LLVM_DEBUG(DBGS() << "match operand #" << pos - << "that does not exist in the operation"); - return std::nullopt; - } - return updated; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t pos, - CapturingOpMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - std::optional operandNo = remapNegativeOperandNumber(pos, op); - if (!operandNo) { - return false; - } - LLVM_DEBUG(DBGS() << "operand #" << pos << " is defined by an operation"); - Operation *definingOp = op->getOperand(*operandNo).getDefiningOp(); - if (!definingOp) { - return false; - } - return recursiveMatch(nested, definingOp); - }); - recordNestedMatcher(nested); - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t pos, - CapturingValueMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - std::optional operandNo = remapNegativeOperandNumber(pos, op); - if (!operandNo) { - return false; - } - LLVM_DEBUG(DBGS() << "operand #" << pos << " is"); - Value operand = op->getOperand(*operandNo); - return recursiveMatch(nested, operand); - }); - 
recordNestedMatcher(nested); - return *this; -} - -transform_ext::CapturingOpMatcher &transform_ext::CapturingOpMatcher::operand( - int64_t position, std::function floatValueFn) { - addPredicate([position, - floatValueFn = std::move(floatValueFn)](Operation *op) -> bool { - std::optional operandNo = remapNegativeOperandNumber(position, op); - if (!operandNo) { - return false; - } - - LLVM_DEBUG(DBGS() << "operand #" << *operandNo - << " is a special floating point constant"); - auto cstOp = - op->getOperand(*operandNo).getDefiningOp(); - if (!cstOp) { - return false; - } - return floatValueFn(cstOp.value()); - }); - - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t position, ConstantFloatOne) { - return operand(position, - [](llvm::APFloat value) { return value.isExactlyValue(1.0); }); -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::result(transform_ext::NumEqualsTo num) { - addPredicate([=](Operation *op) { - LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " results"); - return num.value == op->getNumResults(); - }); - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::result(int64_t pos, - CapturingValueMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - int64_t updated = pos < 0 ? 
op->getNumResults() + pos : pos; - if (updated < 0 || updated >= op->getNumResults()) { - LLVM_DEBUG(DBGS() << "matching result #" << pos - << " that does not exist in the operation"); - return false; - } - LLVM_DEBUG(DBGS() << "result #" << pos << " is"); - Value result = op->getResult(updated); - return recursiveMatch(nested, result); - }); - recordNestedMatcher(nested); - return *this; -} - -//===---------------------------------------------------------------------===// -// CapturingValueMatcher -//===---------------------------------------------------------------------===// - -namespace { -struct DebugPrintValueWrapper { - Value value; -}; - -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const DebugPrintValueWrapper &wrapper) { - if (auto opResult = dyn_cast(wrapper.value)) { - return os << "op result #" << opResult.getResultNumber() << " in " - << wrapper.value; - } - - auto blockArg = cast(wrapper.value); - os << "block argument #" << blockArg.getArgNumber(); - Block *parentBlock = blockArg.getParentBlock(); - Region *parentRegion = parentBlock->getParent(); - if (!parentRegion) { - os << " of a detached block:\n"; - parentBlock->print(os); - return os; - } - - os << " of block #" - << std::distance(parentRegion->begin(), parentBlock->getIterator()); - Operation *parentOp = parentRegion->getParentOp(); - if (!parentOp) { - os << " of a detached region:\n"; - for (Block &b : *parentRegion) { - b.print(os); - } - return os; - } - - os << " in region #" << parentRegion->getRegionNumber() << " of " - << *parentOp; - return os; -} -} // namespace - -bool transform_ext::CapturingValueMatcher::match(Value value) { - auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); - LLVM_DEBUG(DBGS() << "matching " << DebugPrintValueWrapper{value} << "\n"); - - if (getCaptured()) { - LLVM_DEBUG(DBGS() << "found an already captured value: "); - if (getCaptured() == value) { - LLVM_DEBUG(llvm::dbgs() << "same\n"); - return true; - } else { - 
LLVM_DEBUG(llvm::dbgs() << "different\n"); - return false; - } - } - - for (const PredicateFn &fn : predicates) { - bool result = fn(value); - LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); - if (!result) { - return false; - } - } - - captured = value; - return true; -} - -transform_ext::ShapedValueMatcher::ShapedValueMatcher() - : CapturingValueMatcher() { - addPredicate([](Value value) { - LLVM_DEBUG(DBGS() << "value is of shaped type"); - return value && isa(value.getType()); - }); -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::rank(transform_ext::CaptureRank capture) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing shaped value rank"); - capture.value = cast(value.getType()).getRank(); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::dim(int64_t dimension, CaptureDim capture) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing shaped value dimension " << dimension); - capture.value = cast(value.getType()).getDimSize(dimension); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::dim(AllDims tag, CaptureDims captures) { - (void)tag; - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing all shaped value dimensions"); - ArrayRef shape = cast(value.getType()).getShape(); - captures.value.assign(shape.begin(), shape.end()); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::elementType(CaptureElementType captures) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing elementType"); - captures.value = cast(value.getType()).getElementType(); - return true; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Constraints on op rank and dims. 
-//===---------------------------------------------------------------------===// - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank >= " << minRank.value); - return linalgOp.getNumLoops() >= minRank.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value); - return linalgOp.getNumLoops() <= maxRank.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank == " << exactRank.value); - return linalgOp.getNumLoops() == exactRank.value; - }); -} - -StringRef stringifyShapeKind(transform_ext::ShapeKind kind) { - switch (kind) { - case transform_ext::ShapeKind::Static: - return "static"; - case transform_ext::ShapeKind::Dynamic: - return "dynamic"; - } - llvm_unreachable("unhandled shape kind"); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, - ShapeKind kind) { - return addPredicate([dimensions = std::move(dimensions), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimensions ["; - llvm::interleaveComma(dimensions, llvm::dbgs()); - llvm::dbgs() << "] are " << stringifyShapeKind(kind)); - SmallVector shape = linalgOp.getStaticLoopRanges(); - for (auto dimension : dimensions) { - int64_t transformedDimension = - dimension >= 0 ? 
dimension : shape.size() + dimension; - if (transformedDimension < 0 || transformedDimension >= shape.size()) { - return false; - } - if (ShapedType::isDynamic(shape[transformedDimension]) ^ - (kind == ShapeKind::Static)) { - continue; - } - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind)); - SmallVector shape = linalgOp.getStaticLoopRanges(); - return llvm::all_of(shape, [=](int64_t dimension) { - return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); - }); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, - utils::IteratorType kind) { - return addPredicate([dimensions = std::move(dimensions), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimensions ["; - llvm::interleaveComma(dimensions, llvm::dbgs()); - llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); - int64_t rank = linalgOp.getNumLoops(); - for (auto dimension : dimensions) { - int64_t transformedDimension = - dimension >= 0 ? 
dimension : rank + dimension; - if (transformedDimension < 0 || transformedDimension >= rank) { - return false; - } - utils::IteratorType iteratorKind = - linalgOp.getIteratorTypesArray()[transformedDimension]; - if (iteratorKind == kind) { - continue; - } - return false; - } - return true; - }); -} -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) { - return dim(AllDimsExcept({}), kind); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims, - utils::IteratorType kind) { - return addPredicate([dimensions = std::move(dims), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all dimensions except ["; - llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs()); - llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); - int64_t rank = linalgOp.getNumLoops(); - llvm::SmallDenseSet excludedDims; - for (int64_t dim : dimensions.getExcluded()) { - excludedDims.insert(dim >= 0 ? dim : rank + dim); - } - - for (auto [index, type] : - llvm::enumerate(linalgOp.getIteratorTypesArray())) { - if (excludedDims.contains(index)) { - continue; - } - if (type == kind) { - continue; - } - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, - DivisibleBy divisibleBy) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by " - << divisibleBy.value); - int64_t rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? 
dimension : rank + dimension; - if (transformedDimension >= rank) { - return false; - } - - int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension]; - return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0); - }); -} - -//===---------------------------------------------------------------------===// -// Capture directives. -//===---------------------------------------------------------------------===// -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(CaptureRank capture) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture rank"); - capture.value = linalgOp.getNumLoops(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture dimension"); - int64_t rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? 
dimension : rank + dimension; - if (transformedDimension >= rank) { - return false; - } - - capture.value = linalgOp.getStaticLoopRanges()[transformedDimension]; - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture all dimensions"); - captures.value = linalgOp.getStaticLoopRanges(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::indexingMaps( - CaptureIndexingMaps indexingMaps) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture indexing maps"); - indexingMaps.value = linalgOp.getIndexingMapsArray(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::contractionDims( - CaptureContractionDims contractionDims) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture contraction dimensions"); - StringRef convMessage = linalg::detail::getMatchContractionMessage( - mlir::linalg::detail::isContractionInterfaceImpl( - linalgOp, &contractionDims.value)); - if (convMessage.empty()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); - return false; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture convolution dimensions"); - StringRef convMessage = linalg::detail::getMatchConvolutionMessage( - mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp, - &convDims.value)); - if (convMessage.empty()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); - return false; - }); -} - -transform_ext::StructuredOpMatcher::StructuredOpMatcher( - StructuredOpMatcher &A, StructuredOpMatcher &B) { 
- - addPredicate([&A, &B](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); - { - auto debugRAII = llvm::scope_exit( - [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (A.match(linalgOp)) { - return true; - } - } - LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); - { - auto debugRAII = llvm::scope_exit( - [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (B.match(linalgOp)) { - return true; - } - } - return false; - }); - recordNestedMatcher(A); - recordNestedMatcher(B); -} - -//===---------------------------------------------------------------------===// -// Constraints on input operands. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addInputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addInputMatcher( - position, - // No need to handle optional inside the lambda, the wrapper will do that. - [matcher = std::move(matcher)](Value value) { - Operation *definingOp = value.getDefiningOp(); - return definingOp && matcher(definingOp); - }, - optional); -} - -void transform_ext::StructuredOpMatcher::addInputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addPredicate([position, optional, matcher = std::move(matcher)]( - linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInputs() + position; - if (transformedPosition >= linalgOp.getNumDpsInputs()) { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " does not exist but match required"); - return false; - } - - LLVM_DEBUG(DBGS() << "input operand #" << position - << (optional.value ? " (optional match) " : " ") - << "is\n"); - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. 
- LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (matcher(linalgOp.getDpsInputOperand(transformedPosition)->get())) { - return true; - } - return optional.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have permutation maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, - IsProjectedPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation()) { - return false; - } - } - return true; - }); -} - -/// Helper to check if the map is an identity map with a projected dim. -static bool isProjectedMap(AffineMap map, int64_t projectedDim) { - if (!map.isProjectedPermutation()) { - return false; - } - int64_t dimCounter = 0; - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - // Skip the project dim. - if (dimCounter == projectedDim) { - dimCounter++; - } - if (map.getDimPosition(i) != dimCounter++) { - return false; - } - } - return true; -} - -/// Helper to turn a potentially negative index to positive within the range -/// [0, ub) and indicate whether the transformed index is in bounds. 
-static bool makeValidPositiveIndex(int64_t &index, int64_t ub) { - int64_t positiveIndex = index >= 0 ? index : ub + index; - if (positiveIndex < 0 || ub < positiveIndex) { - LLVM_DEBUG(DBGSNL() << " index out of range"); - return false; - } - index = positiveIndex; - return true; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(SmallVector &&positions, - IsProjected dim) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "operands "; - llvm::interleaveComma(positions, llvm::dbgs()); - llvm::dbgs() << " have a permutation maps with " << dim.value - << " projected"); - int64_t updatedDim = dim.value; - if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { - return false; - } - for (int64_t position : positions) { - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, - linalgOp.getNumDpsInputs())) { - return false; - } - OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); - if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), - updatedDim)) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have identity maps"); - // all_of with a lambda requires const-casting dance, so using a loop. 
- for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(SmallVector &&positions, - IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operands "; - llvm::interleaveComma(positions, llvm::dbgs()); - llvm::dbgs() << " have identity maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (int64_t position : positions) { - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, - linalgOp.getNumDpsInputs())) { - return false; - } - OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); - if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - ElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " has elemental type with bit width " << width.value); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - CaptureElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, 
linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { - return false; - } - width.value = shapedType.getElementType().getIntOrFloatBitWidth(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - CaptureElementType elem) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " capture element type"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - if (!shapedType) { - LLVM_DEBUG(DBGSNL() << " not a shaped type"); - return false; - } - elem.value = shapedType.getElementType(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(NumEqualsTo num) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "number of input operands == " << num.value); - return linalgOp.getNumDpsInputs() == num.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - ConstantFloatMinOrMinusInf) { - return input(position, [](llvm::APFloat f) { - return (f.isLargest() || f.isInfinity()) && f.isNegative(); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) { - return input(position, [](llvm::APFloat f) { return f.isZero(); }); -} - -transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input( - int64_t position, std::function floatValueFn) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " is a 
special floating point constant"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto cstOp = linalgOp.getDpsInputOperand(updatedPosition) - ->get() - .getDefiningOp(); - if (!cstOp) { - return false; - } - return floatValueFn(cstOp.value()); - }); -} - -//===---------------------------------------------------------------------===// -// Constraints on output operands. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addOutputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addPredicate([position, optional, matcher = std::move(matcher)]( - linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << (optional.value ? " (optional match) " - : " (mandatory match) ") - << "is produced by\n"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - Operation *definingOp = - linalgOp.getDpsInitOperand(updatedPosition)->get().getDefiningOp(); - if (!definingOp) { - return optional.value; - } - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. 
- LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (matcher(definingOp)) { - return true; - } - return optional.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, - IsProjectedPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isProjectedPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have a maps with projected"); - int64_t updatedDim = dim.value; - if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { - return false; - } - // all_of with a lambda requires const-casting dance, so using a loop. 
- for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!isProjectedMap(linalgOp.getMatchingIndexingMap(&operand), - updatedDim)) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - ElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " has elemental type with bit width " << width.value); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - CaptureElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { - LLVM_DEBUG(DBGSNL() << " could not infer element type"); - return false; - } - width.value = 
shapedType.getElementType().getIntOrFloatBitWidth(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - CaptureElementType elem) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " capture element type"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - if (!shapedType) { - LLVM_DEBUG(DBGSNL() << " not a shaped type"); - return false; - } - elem.value = shapedType.getElementType(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - SingleCombinerReduction tag) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " is populated by a single-combiner reduction"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - SmallVector combinerOps; - return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, - combinerOps) && - llvm::hasSingleElement(combinerOps); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(NumEqualsTo num) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "number of output operands == " << num.value); - return linalgOp.getNumDpsInits() == num.value; - }); -} - -//===---------------------------------------------------------------------===// -// Constraints on results. 
-//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addResultMatcher( - int64_t position, HasAnyUse tag, std::function matcher, - OptionalMatch optional) { - addPredicate([matcher = std::move(matcher), optional, - position](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "result #" << position - << (optional.value ? " (optional match) " - : " (mandatory match) ") - << "has a use\n"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp->getNumResults())) { - return false; - } - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (llvm::any_of(linalgOp->getResult(updatedPosition).getUsers(), - [&matcher](Operation *op) { return matcher(op); })) { - return true; - } - return optional.value; - }); -} - -//===-------------------------------------------------------------------===// -// Constraints on op region. 
-//===-------------------------------------------------------------------===// - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs( - StringRef opcode, bool commutative) { - return addPredicate([=](linalg::LinalgOp linalgOp) { - if (linalgOp.getBlock()->getOperations().size() != 2) { - return false; - } - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (innerOp->getName().getStringRef() != opcode || - innerOp->getNumResults() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if (yieldOp->getNumOperands() != 1) { - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) { - return false; - } - if (commutative && innerOp->getNumOperands() == 2) { - auto arg0 = dyn_cast(innerOp->getOperand(0)); - auto arg1 = dyn_cast(innerOp->getOperand(1)); - if (!arg0 || !arg1) { - return false; - } - if (arg0.getParentBlock() != linalgOp.getBlock() || - arg1.getParentBlock() != linalgOp.getBlock()) { - return false; - } - if (!((arg0.getArgNumber() == 0 && arg1.getArgNumber() == 1) || - (arg1.getArgNumber() == 0 && arg0.getArgNumber() == 1))) { - return false; - } - } else { - for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) { - auto arg = dyn_cast(operand); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != index) { - return false; - } - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::isFloatReciprocal() { - return addPredicate([=](linalg::LinalgOp linalgOp) { - LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation"); - if (linalgOp.getBlock()->getOperations().size() != 2) { - return false; - } - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (!isa(innerOp) || innerOp->getNumResults() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if 
(yieldOp->getNumOperands() != 1) { - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) { - return false; - } - auto cst = innerOp->getOperand(0).getDefiningOp(); - if (!cst || cst.value().convertToDouble() != 1.0) { - return false; - } - auto arg = dyn_cast(innerOp->getOperand(1)); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != 0) { - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::passThroughOp() { - return addPredicate([=](linalg::LinalgOp linalgOp) { - if (linalgOp.getBlock()->getOperations().size() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - for (auto [index, operand] : llvm::enumerate(yieldOp->getOperands())) { - auto arg = dyn_cast(operand); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != index) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::hasContractionBody( - function_ref isaElemOpTy, - function_ref isaReductionOpTy, StringRef elemOpName, - StringRef reductionOpName) { - return addPredicate([=](linalg::LinalgOp linalgOp) { - LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/" - << reductionOpName << " contraction ("); - auto scopeExitPrinter = - llvm::scope_exit([] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); }); - - Block *body = linalgOp.getBlock(); - if (!llvm::hasNItems(*body, 3)) { - LLVM_DEBUG(llvm::dbgs() << "three-operation body"); - return false; - } - if (body->getNumArguments() != 3) { - LLVM_DEBUG(llvm::dbgs() << "three-argument block"); - return false; - } - - Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin()); - Operation *reductionOp = elemOp->getNextNode(); - Operation *yieldOp = reductionOp->getNextNode(); - if (!isaElemOpTy(elemOp)) { - LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName); - return 
false; - } - if (!isaReductionOpTy(reductionOp)) { - LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName); - return false; - } - if (yieldOp->getNumOperands() != 1) { - LLVM_DEBUG(llvm::dbgs() << "one value yielded"); - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) { - LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op"); - return false; - } - if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) { - LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result"); - return false; - } - if (reductionOp->getNumOperands() != 2 || - reductionOp->getNumResults() != 1) { - LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result"); - return false; - } - - SmallVector expectedReductionOperands = {body->getArgument(2), - elemOp->getResult(0)}; - if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) && - !llvm::equal(llvm::reverse(expectedReductionOperands), - reductionOp->getOperands())) { - LLVM_DEBUG(llvm::dbgs() << "operands of the second op"); - return false; - } - - ValueRange expectedElemOperands = body->getArguments().take_front(2); - if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) && - !llvm::equal(llvm::reverse(expectedElemOperands), - elemOp->getOperands())) { - LLVM_DEBUG(llvm::dbgs() << "operands of the first op"); - return false; - } - - scopeExitPrinter.release(); - LLVM_DEBUG(llvm::dbgs() << "success)"); - return true; - }); -} - -void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor( - StringRef name) { - LLVM_DEBUG(DBGS() << "op is a " << name << "'"); -} - -//===---------------------------------------------------------------------===// -// TensorPadOpMatcher -//===---------------------------------------------------------------------===// - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::low(ArrayRef sizes) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG({ - 
DBGS() << "low pad sizes are "; - llvm::interleaveComma(sizes, llvm::dbgs()); - }); - for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedLowPad(), sizes)) { - if (isConstantIntValue(ofr, sz)) { - return false; - } - } - return true; - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "all low pad sizes are " << size); - return llvm::all_of(tensorPad.getMixedLowPad(), [&](OpFoldResult ofr) { - return isConstantIntValue(ofr, size); - }); - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::high(ArrayRef sizes) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG({ - DBGS() << "high pad sizes are "; - llvm::interleaveComma(sizes, llvm::dbgs()); - }); - for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedHighPad(), sizes)) { - if (isConstantIntValue(ofr, sz)) { - return false; - } - } - return true; - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "all high pad sizes are " << size); - return llvm::all_of(tensorPad.getMixedHighPad(), [&](OpFoldResult ofr) { - return isConstantIntValue(ofr, size); - }); - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::yieldsExternalValue() { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value"); - Block *body = tensorPad.getBody(); - if (!llvm::hasSingleElement(*body)) { - return false; - } - return llvm::all_of(body->getTerminator()->getOperands(), - [body](Value operand) { - auto arg = dyn_cast(operand); - return !arg || arg.getOwner() != body; - }); - }); -} - -//===---------------------------------------------------------------------===// -// MatchCallbackResult. 
-//===---------------------------------------------------------------------===// - -ArrayRef -transform_ext::MatchCallbackResult::getPayloadGroup(int64_t position) const { - assert(position < payloadGroupLengths.size()); - int64_t start = 0; - for (int64_t i = 0; i < position; ++i) { - start += payloadGroupLengths[i]; - } - return llvm::ArrayRef(payloadOperations) - .slice(start, payloadGroupLengths[position]); -} - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. -//===---------------------------------------------------------------------===// - -static constexpr int64_t kCudaWarpSize = 32; - -void transform_ext::makeReductionMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&reductionCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&leadingCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - MatchedReductionCaptures &captures, bool mustMatchEntireFunc) { - // The core part of the matcher is anchored on a particular reduction op. - auto &reduction = - m_StructuredOp(matcherContext) - // Op has at least a parallel a reduction dimension and at - // most 3 parallel dimensions. - // TODO: relax once we have global collapse/expand_shape. - // - .rank(NumGreaterEqualTo(2)) - .rank(NumLowerEqualTo(4)) - .rank(CaptureRank(captures.reductionRank)) - // Op has a single most-minor reduction. - .dim(-1, utils::IteratorType::reduction) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.reductionOpSizes)) - // All other dimensions are parallel. - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - // Single input for now, can be arbitrary projected permutations. - // TODO: Multiple inputs, can be arbitrary projected permutations. - // TODO: Watch out for multiple inputs though as a reduction turns - // into a contraction when mixed with projected - // permutations. 
A reduction is often bandwidth bound but - // contraction is a different beast that is compute bound - // and has a very different schedule. - // - .input(NumEqualsTo(1)) - .input(AllOperands(), IsProjectedPermutation()) - // Single output supported atm. - // TODO: Multiple outputs. - // - .output(NumEqualsTo(1)) - // A reduction output must be a projected permutation, match it but we - // could also drop this technically. - .output(AllOperands(), IsProjectedPermutation()) - // Only single combiner for now due to reduction warp - // distribution. - // TODO: relax this once reduction distribution is more powerful. - // - .output(0, CaptureElementTypeBitWidth( - captures.reductionOutputElementalTypeBitWidth)) - .output(0, SingleCombinerReduction()); - reductionCapture = &reduction; - - // Mandatory FillOp must create the unique output of the reduction. - // TODO: Relax this, as any map, broadcast, transpose should also work. - // - auto &fill = m_StructuredOp(matcherContext); - reduction = reduction.output(NumEqualsTo(1)).output(0, fill); - fillCapture = &fill; - - // Optional leading or trailing op can be any map, transpose, broadcast but - // not reduce or windowing operation for now. - // It must create the unique input for the reduction. - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - transform_ext::StructuredOpMatcher commonLeadingOrTrailing = - m_StructuredOp(matcherContext) - // All parallel dimensions. - .dim(AllDims(), utils::IteratorType::parallel) - // All inputs are any projected permutation. - .input(AllOperands(), IsProjectedPermutation()) - .output(AllOperands(), IsPermutation()) - // leading and trailing may have 0, 1 or more input as long as they do - // not come from unmatched ops. This extra constraint is taken care of - // separately. This is also a noop but we document it. - // TODO: Base and derived classes, atm this does not compile. 
- // .input(NumGreaterEqualTo(0)) - // Single output supported atm. - // TODO: extend this. - // - .output(NumEqualsTo(1)); - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - auto &leading = - m_StructuredOp(matcherContext, commonLeadingOrTrailing) - .rank(CaptureRank(captures.maybeLeadingRank)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.leadingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeLeadingOutputElementalTypeBitWidth)); - reduction = reduction.input(0, leading, OptionalMatch()); - leadingCapture = &leading; - - // Optional trailing can be any map, transpose, broadcast but not reduce or - // windowing operation for now. - // It must be fed by the unique input for the reduction. - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - auto &trailing = - m_StructuredOp(matcherContext, commonLeadingOrTrailing) - .rank(CaptureRank(captures.maybeTrailingRank)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) - // Capture output elemental type. 
- .output(0, CaptureElementTypeBitWidth( - captures.maybeTrailingOutputElementalTypeBitWidth)); - reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - reduction = reduction.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc) { - StructuredOpMatcher *fill; - StructuredOpMatcher *leading; - StructuredOpMatcher *trailing; - makeReductionMatcher(context, reductionCapture, fill, leading, trailing, - captures, mustMatchEntireFunc); -} - -void transform_ext::makeMatmulMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&matmulCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { - auto &matmul = transform_ext::m_StructuredOp(matcherContext) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) - // Capture input/output element types. - .input(0, CaptureElementType(captures.lhsElementType)) - .input(1, CaptureElementType(captures.rhsElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - matmulCapture = &matmul; - // Mandatory FillOp must create the unique output of the reduction. 
- auto &fill = transform_ext::m_StructuredOp(matcherContext); - matmul = matmul.output(transform_ext::NumEqualsTo(1)).output(0, fill); - fillCapture = &fill; - - auto &trailing = m_StructuredOp(matcherContext); - matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - matmul = matmul.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeBatchMatmulMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&bmmCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { - auto &bmm = - transform_ext::m_StructuredOp( - matcherContext) - .hasContractionBody() - .rank(NumEqualsTo(4)) - .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .dim(-1, utils::IteratorType::reduction) - .contractionDims(CaptureContractionDims(captures.contractionDims)) - .input(NumEqualsTo(2)) - .input(0, CaptureElementType(captures.lhsElementType)) - .input(1, CaptureElementType(captures.rhsElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - bmmCapture = &bmm; - - auto &fill = transform_ext::m_StructuredOp(matcherContext); - bmm = bmm.output(0, fill); - fillCapture = &fill; - - if (mustMatchEntireFunc) { - bmm = bmm.allTilableOpsCaptured(); - } -} - -/// Match sum(%src, broadcast(%reduction)) -static void -matchSubBroadcast(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher &maxReduction, - transform_ext::CapturingValueMatcher &softmaxSourceOperand, - transform_ext::StructuredOpMatcher *&sub) { - using namespace transform_ext; - - auto &broadcast = - transform_ext::m_StructuredOp(matcherContext) - .passThroughOp() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(0, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - 
broadcast = broadcast.input(0, maxReduction); - - auto &subParallel = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - subParallel = subParallel.input(0, softmaxSourceOperand); - subParallel = subParallel.input(1, broadcast); - - auto &subBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - subBroadcast = subBroadcast.input(0, softmaxSourceOperand); - subBroadcast = subBroadcast.input(1, maxReduction); - auto &subOr = transform_ext::m_StructuredOp_Or(matcherContext, subBroadcast, - subParallel); - sub = &subOr; -} - -/// Match sum(%exp, broadcast(%sum)) -static void matchdivBroadcast(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher &expOperand, - transform_ext::StructuredOpMatcher &sum, - transform_ext::StructuredOpMatcher *&div) { - using namespace transform_ext; - - auto &broadcast = - transform_ext::m_StructuredOp(matcherContext) - .passThroughOp() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(0, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - broadcast = broadcast.input(0, sum); - - auto &divNoBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - divNoBroadcast = divNoBroadcast.input(0, expOperand); - divNoBroadcast = divNoBroadcast.input(1, broadcast); - - auto 
&divBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - divBroadcast = divBroadcast.input(0, expOperand); - divBroadcast = divBroadcast.input(1, sum); - - auto &divMerge = transform_ext::m_StructuredOp_Or( - matcherContext, divNoBroadcast, divBroadcast); - div = &divMerge; -} - -void transform_ext::makeSoftmaxMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&maxReductionCapture, - transform_ext::StructuredOpMatcher *&softmaxRootCapture) { - auto &softmaxSourceOperand = m_Value(matcherContext); - - auto &fillMinusInf = m_StructuredOp(matcherContext) - .input(0, ConstantFloatMinOrMinusInf()); - auto &maxReduction = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - // Only handle most inner reduction for now. 
- .dim(-1, utils::IteratorType::reduction) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsProjected(-1)); - maxReduction = maxReduction.input(0, softmaxSourceOperand); - maxReduction = maxReduction.output(0, fillMinusInf); - maxReductionCapture = &maxReduction; - - transform_ext::StructuredOpMatcher *subOperand; - matchSubBroadcast(matcherContext, maxReduction, softmaxSourceOperand, - subOperand); - - auto &expOperand = m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)); - expOperand = expOperand.input(0, *subOperand); - - auto &fillZero = m_StructuredOp(matcherContext) - .input(0, ConstantFloatZero()); - auto &sum = - m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - // Only handle most inner reduction for now. 
- .dim(-1, utils::IteratorType::reduction) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsProjected(-1)) - .output(NumEqualsTo(1)); - sum = sum.input(0, expOperand); - sum = sum.output(0, fillZero); - - auto &rcpOperand = m_StructuredOp(matcherContext) - .isFloatReciprocal() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)); - rcpOperand = rcpOperand.input(0, sum); - - auto &mulOperand = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - mulOperand = mulOperand.input(0, expOperand); - mulOperand = mulOperand.input(1, rcpOperand); - - transform_ext::StructuredOpMatcher *divOperand; - matchdivBroadcast(matcherContext, expOperand, sum, divOperand); - - auto &softmaxRoot = - transform_ext::m_StructuredOp_Or(matcherContext, mulOperand, *divOperand); - softmaxRootCapture = &softmaxRoot; -} - -/// Matcher for convolutions. -void transform_ext::makeConvolutionMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&convolutionCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { - // The core part of the matcher is anchored on a particular convolution op. - auto &convolution = - m_StructuredOp( - matcherContext) - // Capture convolution dim classifications. - .convolutionDims(CaptureConvDims(captures.convolutionDims)) - // Capture op sizes. 
- .dim(AllDims(), CaptureDims(captures.convolutionOpSizes)) - // Capture convolution element types. - .input(0, CaptureElementType(captures.inputElementType)) - .input(1, CaptureElementType(captures.filterElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - convolutionCapture = &convolution; - - // Optional FillOp to create the unique output of the convolution. - auto &fill = m_StructuredOp(matcherContext) - .output(0, CaptureElementTypeBitWidth( - captures.maybeFillElementalTypeBitWidth)); - convolution = - convolution.output(NumEqualsTo(1)).output(0, fill, OptionalMatch()); - fillCapture = &fill; - - // Optional trailing op can be any map, transpose, broadcast but - // not reduce or windowing operation for now. - // It must create the unique input for the reduction. - auto &trailing = - m_StructuredOp(matcherContext) - // All parallel dimensions. - .dim(AllDims(), utils::IteratorType::parallel) - // All inputs are any projected permutation. - .input(AllOperands(), IsProjectedPermutation()) - .output(AllOperands(), IsPermutation()) - .output(NumEqualsTo(1)) - .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeTrailingOutputElementalTypeBitWidth)); - - // Optional trailing can be any map, transpose, broadcast but not reduce or - // windowing operation for now. 
- convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - convolution = - convolution.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeConvolutionMatcher( - transform_ext::MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { - StructuredOpMatcher *fill; - StructuredOpMatcher *trailing; - makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures, - mustMatchEntireFunc); -} - -void transform_ext::makePadMatcher(MatcherContext &context, - CapturingOpMatcher *&padCapture, - MatchedPadCaptures &captures, - bool mustMatchEntireFunc) { - auto &value = transform_ext::m_ShapedValue(context); - value.rank(transform_ext::CaptureRank(captures.rank)) - .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims)) - .elementType(CaptureElementType(captures.elementType)); - auto &opMatcher = transform_ext::m_tensorPad(context) - .result(0, value) - .low(AllDims(), 0) - .yieldsExternalValue(); - if (mustMatchEntireFunc) { - opMatcher = opMatcher.allTilableOpsCaptured(); - } - padCapture = &opMatcher; -} diff --git a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h b/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h deleted file mode 100644 index 49d33f83c07d..000000000000 --- a/compiler/src/iree/compiler/GlobalOptimization/TransformMatchers.h +++ /dev/null @@ -1,1201 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_GLOBALOPTIMIZATION_TRANSFORMMATCHERS_H_ -#define IREE_COMPILER_GLOBALOPTIMIZATION_TRANSFORMMATCHERS_H_ - -#include -#include -#include - -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/StringMap.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/IR/Matchers.h" - -namespace mlir { -namespace transform_ext { - -//===---------------------------------------------------------------------===// -// StructuredOpMatcher and predicates. -//===---------------------------------------------------------------------===// - -class StructuredOpMatcher; -class MatcherContext; -StructuredOpMatcher &m_StructuredOp(MatcherContext &); - -/// A tag indicating the shape being static or dynamic, for use with the -/// structured op matcher. -enum class ShapeKind { Static, Dynamic }; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all dimensions. -struct AllDims {}; - -/// A predicate indicating the structured op matcher to check the predicate for -/// all dimensions except the specified ones. -struct AllDimsExcept { - explicit AllDimsExcept(std::initializer_list range) { - llvm::append_range(exceptions, range); - } - ArrayRef getExcluded() const { return llvm::ArrayRef(exceptions); } - -private: - SmallVector exceptions; -}; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all operands of the relevant kind. -struct AllOperands {}; - -/// Base class for single-value captures. Concrete captures should inherit this -/// and forward the constructor via `using Base::Base`. -template -struct CaptureStaticValue { - using Base = CaptureStaticValue; - explicit CaptureStaticValue(T &value) : value(value) {} - T &value; -}; - -/// Captures the (static) size of the dimension. 
-struct CaptureDim : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the (static) sizes of multiple dimensions. -struct CaptureDims : public CaptureStaticValue> { - using Base::Base; -}; - -/// Captures the contraction dimensions of the target operation. -struct CaptureIndexingMaps : public CaptureStaticValue> { - using Base::Base; -}; - -/// Captures the contraction dimensions of the target operation. -struct CaptureContractionDims - : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the convolution dimensions of the target operation. -struct CaptureConvDims - : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the rank of the operation. -struct CaptureRank : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the bitwidth of an element type. -struct CaptureElementTypeBitWidth : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures element element type. -struct CaptureElementType : public CaptureStaticValue { - using Base::Base; -}; - -template -struct CaptureAttribute : public CaptureStaticValue { - static_assert(std::is_base_of_v, - "can only capture a subclass of Attribute"); - using CaptureStaticValue::CaptureStaticValue; -}; - -/// A tag indicating to look for any user of the operation's result that would -/// satisfy the predicate. -struct HasAnyUse {}; - -/// Base class for predicate parameters that can be described with the single -/// value. Concrete predicate parameters should inherit this and forward the -/// constructor via `using Base::Base`. -template -struct SingleValuePredicateParam { - using Base = SingleValuePredicateParam; - explicit SingleValuePredicateParam(T value) : value(value) {} - const T value; -}; - -/// Indicates that the dimension must be divisible by the given value. -struct DivisibleBy : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be equal to the given value. 
-struct NumEqualsTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be greater than the given value. -struct NumGreaterEqualTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be greater than the given value. -struct NumLowerEqualTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the bit width of the elemental type must be equal to the give -/// value. -struct ElementTypeBitWidth : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Predicate tag indicating that the affine map is a permutation. -struct IsPermutation {}; - -/// Predicate tag indicating that the affine map is a projected permutation. -struct IsProjectedPermutation {}; - -/// Predicate tag indicating that the affine map is a projection of given -/// dimension. -struct IsProjected : public SingleValuePredicateParam { - using Base::Base; -}; -/// Predicate tag indicating that the affine map is an identity. -struct IsIdentity {}; - -/// Predicate tag indicating that the operand is a special float constant. -struct ConstantFloatMinOrMinusInf {}; -struct ConstantFloatZero {}; -struct ConstantFloatOne {}; - -/// Indicates that the match optional. The matcher is still expected to run and -/// capture if successful. The parameter can be set to false -struct OptionalMatch : public SingleValuePredicateParam { - OptionalMatch() : Base(true) {} - explicit OptionalMatch(bool set) : Base(set) {} -}; - -/// Predicate tag indicating that the reduction is produced by a single combiner -/// operation. -struct SingleCombinerReduction {}; - -class CapturingOpMatcher; -class CapturingValueMatcher; - -/// Base class for capturing matchers that can be owned by the context. -class CapturingMatcherBase { -public: - // Virtual destructor so unique pointers are deallocated correctly. 
- // TODO: if efficiency is a problem, consider disallowing non-trivial - // destructors for subclasses. - virtual ~CapturingMatcherBase() = default; - -protected: - /// Informs the matcher that it has another, nested matcher. Derived classes - /// must call this to keep track of nested matchers for capture resetting - /// purposes. - template - void recordNestedMatcher(T &nested) { - if constexpr (std::is_base_of_v) { - nestedCapturingMatchers.push_back(&nested); - } - if constexpr (std::is_base_of_v) { - nestedCapturingValueMatchers.push_back(&nested); - } - } - - /// Appends all nested capturing matchers of a certain kind, excluding this - /// one, to `nested`. - void getAllNested(SmallVectorImpl &nested); - void - getAllNestedValueMatchers(SmallVectorImpl &nested); - - /// Resets nested capturing matchers but does NOT reset the current one. - void resetCapture(); - -private: - /// A list of (recursively) nested capturing matchers that should be reset - /// when the current matcher is. - SmallVector nestedCapturingMatchers; - SmallVector nestedCapturingValueMatchers; -}; - -/// A context object holding capturing matchers, must outlive any individual -/// matcher. When matching complex subgraphs, the caller often doesn't care -/// about all intermediate nodes (operations) in the graph and shouldn't need to -/// hold matcher objects for those. These matchers can be created in this -/// context. -class MatcherContext { -public: - /// Create a new matcher of the specified type owned by this context. - template - std::enable_if_t, T> & - allocate(Args &&...args) { - // Need to call "new" explicitly as make_unique wouldn't have access to the - // private constructor when this class would. - ownedMatchers.emplace_back( - std::unique_ptr(new T(std::forward(args)...))); - return *static_cast(ownedMatchers.back().get()); - } - -private: - /// Owning list of matchers. 
- // TODO: If this becomes inefficient, consider something like BumpPtrAllocator - // that derived classes can use to store their members as well. - SmallVector> ownedMatchers; -}; - -/// Base class for value matchers that capture the matched value. Stores a list -/// of predicates and requires all of them to match for the value to match. Once -/// a value matched, any repeated use just verifies that equality of the value. -class CapturingValueMatcher : public CapturingMatcherBase { - friend class CapturingMatcherBase; - friend class MatcherContext; - - using PredicateFn = std::function; - -public: - /// Resets the captured value to null. This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - CapturingMatcherBase::resetCapture(); - } - - /// Returns the matched value if the match was successful. - Value getCaptured() const { return captured; } - - /// Matches the given value, hook for `matchPattern`. - bool match(Value value); - -protected: - CapturingValueMatcher() = default; - - /// Adds a predicate to the end of the predicate list for this value matcher. - template - void addPredicate(Fn &&predicate) { - predicates.emplace_back(std::forward(predicate)); - } - - /// The captured value. - Value captured = nullptr; - -private: - /// Additional predicates to be checked on the value. - SmallVector predicates; -}; - -/// Creates a matcher of an arbitrary value. -inline CapturingValueMatcher &m_Value(MatcherContext &context) { - return context.allocate(); -} - -/// Matcher for typed values whose type implements the `ShapedType` interface. -/// Allows for matching the components of the shaped type such as rank and -/// dimensions. 
-class ShapedValueMatcher : public CapturingValueMatcher { - friend class MatcherContext; - - ShapedValueMatcher(); - -public: - /// Add an always-succeeding matcher predicate capturing the rank. - ShapedValueMatcher &rank(CaptureRank capture); - - /// Add an always-succeeding matcher predicate capturing the size of the - /// dimension identified by the first argument. - ShapedValueMatcher &dim(int64_t dimension, CaptureDim capture); - - /// Add an always-succeeding matcher predicate capturing the sizes of all - /// dimensions in order of appearance. - ShapedValueMatcher &dim(AllDims tag, CaptureDims captures); - - /// Add an always-succeeding matcher predicate capturing the element type of - /// the value. - ShapedValueMatcher &elementType(CaptureElementType captures); -}; - -/// Construct a new matcher of a value whose type is a `ShapedType`, owned by -/// the given context. -inline ShapedValueMatcher &m_ShapedValue(MatcherContext &context) { - return context.allocate(); -} - -/// Matcher for operations with additional predicates attachable through the -/// fluent, a.k.a. chainable, API. Note that public API must *not* accept -/// additional callbacks even; new predicates should be added instead when -/// necessary. Not only this decreases the depth of the callback stack and -/// increases readability, it also allows us to port the matcher to a -/// declarative format using PDL and/or Transform dialect in the future. The -/// latter will become impossible with arbitrary C++ callbacks. -class CapturingOpMatcher : public CapturingMatcherBase { - friend class CapturingMatcherBase; - friend class MatcherContext; - - template - friend CapturingOpMatcher &m_Operation(MatcherContext &matcherContext); - -public: - using PredicateFn = std::function; - - /// Matches the given operation, hook for `matchPattern`. - bool match(Operation *op); - - /// Resets the captured value to null. 
This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - CapturingMatcherBase::resetCapture(); - } - - /// Returns the matched operation if the match was successful. - Operation *getCaptured() const { return captured; } - - /// Adds alternative paths for predicates. In practice, this is just a - /// predicate that is satisfied when either the first or the second matcher is - /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., - /// the second alternative will not be processed, and therefore will not - /// capture values, if the first alternative succeeded. - CapturingOpMatcher &alternatives(CapturingOpMatcher &first, - CapturingOpMatcher &second); - - //===-------------------------------------------------------------------===// - // Constraints on adjacent ops. - //===-------------------------------------------------------------------===// - - /// Adds a predicate checking that all ops implementing TilingInterface in the - /// parent of the given type (e.g., a function or a module) were matched by - /// this or nested matchers. This is useful to ensure that the matcher covered - /// the entire parent region, not just a parent of it. This predicate **must** - /// be added *after* all the other predicates that capture. - template - CapturingOpMatcher &allTilableOpsCaptured() { - SmallVector copy; - copy.push_back(this); - getAllNested(copy); - addPredicate([copy = std::move(copy)](Operation *op) { - Operation *parent = op->getParentOfType(); - return checkAllTilableMatched(parent, op, copy); - }); - return *this; - } - - //-------------------------------------------------------------------------// - // Predicates for operands and results. 
- //-------------------------------------------------------------------------// - - /// Adds a predicate checking that the operation has exactly the given number - /// of operands. - CapturingOpMatcher &operand(NumEqualsTo num); - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by an operation that satisfies the given matcher. - CapturingOpMatcher &operand(int64_t pos, CapturingOpMatcher &nested); - - /// Adds a predicate checking that the `pos`-th operand of the operation - /// satisfies the given value matcher. - CapturingOpMatcher &operand(int64_t pos, CapturingValueMatcher &nested); - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by `arith.constant` with the value 1.0. - // TODO: better matching for attributes. - CapturingOpMatcher &operand(int64_t pos, ConstantFloatOne); - - /// Adds a predicate checking that the operation has exactly the given number - /// of results. - CapturingOpMatcher &result(NumEqualsTo num); - - /// Adds a predicate checking that the `pos`-th result of the operation - /// satisfies the given value matcher. - CapturingOpMatcher &result(int64_t pos, CapturingValueMatcher &nested); - -protected: - /// Constructs a default operation matcher accepting any operation. - CapturingOpMatcher() = default; - - /// Adds a predicate for the matched operation to satisfy. - template - void addPredicate(Fn &&predicate) { - predicates.emplace_back(std::forward(predicate)); - } - - /// Produce the debug output for `create` method in a non-templated way. - static void debugOutputForCreate(ArrayRef opNames); - -private: - /// A list of additional conditions for the operation to match. - SmallVector predicates; - - /// Checks that `matchers` captured all tilable ops nested in `parent` except - /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured. 
- static bool checkAllTilableMatched(Operation *parent, Operation *op, - ArrayRef matchers); - - /// Creates a matcher for an operation with one of the given types. - template - static CapturingOpMatcher create() { - CapturingOpMatcher matcher; - matcher.addPredicate([](Operation *op) { - debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); - return isa(op); - }); - return matcher; - } - - /// Common util for constant matcher. - CapturingOpMatcher &operand(int64_t position, - std::function floatValueFn); - -protected: - /// Matched value. - Operation *captured = nullptr; -}; - -namespace detail { -/// Prints the debug output from the ConcreteOpMatcher constructor. The -/// implementation must reside in the C++ file so we don't pollute the header -/// with debug includes, and ConcreteOpMatcher is a class template that can only -/// reside in the header. -void debugOutputForConcreteOpMatcherConstructor(StringRef name); -} // namespace detail - -/// Base class for matchers that match a specific op. Adds an initial predicate -/// checking if the op is indeed of the specified kind. -/// Derived classes specializing this for op interfaces MUST also define a -/// specialization of DebugOpKindDescription. -template -class ConcreteOpMatcher : public CapturingOpMatcher { -protected: - using Base = ConcreteOpMatcher; - - static StringRef getConcreteOpDescription() { - return OpTy::getOperationName(); - } - - /// Adds a predicate checking if the op is of the OpTy kind. - ConcreteOpMatcher() { - CapturingOpMatcher::addPredicate([](Operation *op) { - detail::debugOutputForConcreteOpMatcherConstructor( - Derived::getConcreteOpDescription()); - return isa(op); - }); - } - - /// Adds a predicate for the matched operation to satisfy. - template - Derived &addPredicate(FnTy &&predicate) { - // Dispatch to the callback. 
- CapturingOpMatcher::addPredicate( - [inner = std::move(predicate)](Operation *op) { - return inner(cast(op)); - }); - return static_cast(*this); - } - -public: - /// Adds alternative paths for predicates. In practice, this is just a - /// predicate that is satisfied when either the first or the second matcher is - /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., - /// the second alternative will not be processed, and therefore will not - /// capture values, if the first alternative succeeded. - Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) { - return static_cast( - CapturingOpMatcher::alternatives(first, second)); - } - - /// Adds a predicate checking that all ops implementing TilingInterface in the - /// parent of the given type (e.g., a function or a module) were matched by - /// this or nested matchers. This is useful to ensure that the matcher covered - /// the entire parent region, not just a parent of it. This predicate **must** - /// be added *after* all the other predicates that capture. - template - Derived &allTilableOpsCaptured() { - return static_cast( - CapturingOpMatcher::allTilableOpsCaptured()); - } - - //-------------------------------------------------------------------------// - // Predicates for operands and results. - //-------------------------------------------------------------------------// - - /// Adds a predicate checking that the operation has exactly the given number - /// of operands. - Derived &operand(NumEqualsTo num) { - return static_cast(CapturingOpMatcher::operand(num)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by an operation that satisfies the given matcher. - Derived &operand(int64_t pos, CapturingOpMatcher &nested) { - return static_cast(CapturingOpMatcher::operand(pos, nested)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation - /// satisfies the given value matcher. 
- Derived &operand(int64_t pos, CapturingValueMatcher &nested) { - return static_cast(CapturingOpMatcher::operand(pos, nested)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by `arith.constant` with the value 1.0. - // TODO: better matching for attributes. - Derived &operand(int64_t pos, ConstantFloatOne c) { - return static_cast(CapturingOpMatcher::operand(pos, c)); - } - - /// Adds a predicate checking that the operation has exactly the given number - /// of results. - Derived &result(NumEqualsTo num) { - return static_cast(CapturingOpMatcher::result(num)); - } - - /// Adds a predicate checking that the `pos`-th result of the operation - /// satisfies the given value matcher. - Derived &result(int64_t pos, CapturingValueMatcher &nested) { - return static_cast(CapturingOpMatcher::result(pos, nested)); - } -}; - -/// Matcher for the `tensor.pad` operation. -class TensorPadOpMatcher - : public ConcreteOpMatcher { - friend class MatcherContext; - - TensorPadOpMatcher() = default; - -public: - /// Adds a predicate checking that the low padding sizes are exactly the given - /// values. - TensorPadOpMatcher &low(ArrayRef sizes); - - /// Adds a predicate checking that the low padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &low(AllDims tag, int64_t size); - - /// Adds a predicate checking that the high padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &high(ArrayRef sizes); - - /// Adds a predicate checking that the high padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &high(AllDims tag, int64_t size); - - /// Adds a predicate checking that the body of the pad only yields values - /// defined outside the pad region. 
- TensorPadOpMatcher &yieldsExternalValue(); -}; - -inline TensorPadOpMatcher &m_tensorPad(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates a default operation matcher in the given context that accepts any -/// operation. -inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates an operation matcher in the given context that accepts only -/// operations of the kinds provided as template arguments. -template -inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { - return matcherContext.allocate( - CapturingOpMatcher::create()); -} - -/// Matcher for structured aka Linalg operations. -class StructuredOpMatcher - : public ConcreteOpMatcher { - friend class MatcherContext; - - StructuredOpMatcher() = default; - -public: - static StringRef getConcreteOpDescription() { - return "linalg interface implementation"; - } - - /// Creates a matcher for a structured operation with one of the given types. - template - static StructuredOpMatcher create() { - StructuredOpMatcher matcher; - matcher.addPredicate([](Operation *op) { - debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); - return isa(op) && isa(op); - }); - return matcher; - } - - /// Matches a structured operation if either patterns A or B match. - StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B); - - //===-------------------------------------------------------------------===// - // Constraints on op rank and dims. - //===-------------------------------------------------------------------===// - /// Adds a predicate checking that the given rank must be greater than some - /// constant value. - StructuredOpMatcher &rank(NumGreaterEqualTo minRank); - StructuredOpMatcher &rank(NumLowerEqualTo maxRank); - StructuredOpMatcher &rank(NumEqualsTo exactRank); - - /// Adds a predicate checking that the given iteration space dimension is - /// static/dynamic. 
The dimension index may be negative, in which case - /// dimensions are counted from the last one (i.e. Python-style), or be an - /// AllDims tag, in which case all dimensions are checked. This may be - /// eventually extended to slices and/or lists of dimensions. - StructuredOpMatcher &dim(int64_t dimension, ShapeKind kind) { - return dim(SmallVector{dimension}, kind); - } - StructuredOpMatcher &dim(SmallVector &&dimensions, ShapeKind kind); - StructuredOpMatcher &dim(AllDims tag, ShapeKind kind); - - /// Adds a predicate checking that the given iteration space dimension has the - /// given iterator type, e.g., parallel or reduction. The dimension index may - /// be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style), or be an AllDims tag, in which case all dimensions - /// are checked. This may be eventually extended to slices and/or lists of - /// dimensions. - StructuredOpMatcher &dim(int64_t dimension, utils::IteratorType kind) { - return dim(SmallVector{dimension}, kind); - } - // Ownership may get tricky here so we wrap in an explicit vector. - StructuredOpMatcher &dim(SmallVector &&dimensions, - utils::IteratorType kind); - StructuredOpMatcher &dim(AllDims tag, utils::IteratorType kind); - StructuredOpMatcher &dim(AllDimsExcept &&dimensions, - utils::IteratorType kind); - - /// Adds a predicate checking that the given iteration space dimension is - /// statically known to be divisible by the given value. The dimension index - /// may be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style). - StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy); - - //===-------------------------------------------------------------------===// - // Capture directives. 
- //===-------------------------------------------------------------------===// - StructuredOpMatcher &rank(CaptureRank capture); - StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture); - StructuredOpMatcher &dim(AllDims tag, CaptureDims captures); - StructuredOpMatcher &indexingMaps(CaptureIndexingMaps indexingMaps); - StructuredOpMatcher &contractionDims(CaptureContractionDims contractionDims); - StructuredOpMatcher &convolutionDims(CaptureConvDims convDims); - - //===-------------------------------------------------------------------===// - // Constraints on input operands. - //===-------------------------------------------------------------------===// - /// Adds a predicate checking that the structured op has the given number of - /// inputs. - StructuredOpMatcher &input(NumEqualsTo num); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the `position`-th operand. The position may be - /// negative, in which case positions are counted from the last one - /// (i.e. Python-style). When the match is optional, the predicate check - /// succeeds as long as the `position` is in bounds. The matcher is executed - /// if there is a defining operation for the input operand. 
- template - std::enable_if_t::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addInputMatcher( - position, - [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - template - std::enable_if_t::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addInputMatcher( - position, - [&operandMatcher](Value v) { return operandMatcher.match(v); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - - /// Adds a predicate checking that all input operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that all input operands of the structured op - /// have a projected permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation); - - /// Adds a predicate checking that all input operands of the structured op - /// are projected along the given dimension. - StructuredOpMatcher &input(SmallVector &&positions, IsProjected dim); - StructuredOpMatcher &input(int64_t position, IsProjected dim) { - return input(SmallVector{position}, dim); - } - - /// Adds a predicate checking that all input operands of the structured op - /// have identity indexing map. - StructuredOpMatcher &input(AllOperands tag, IsIdentity); - StructuredOpMatcher &input(SmallVector &&positions, IsIdentity); - StructuredOpMatcher &input(int64_t position, IsIdentity) { - return input(SmallVector{position}, IsIdentity()); - } - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op input at the given position is equal to the given value. 
- StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width); - - /// Capture the elemental type bitwidth of input operand `position`. - StructuredOpMatcher &input(int64_t position, - CaptureElementTypeBitWidth width); - - /// Capture the elemental type of input operand `position`. - StructuredOpMatcher &input(int64_t position, CaptureElementType elem); - - /// Check if input is equal to a known constant. - // TODO: Support matching for constant ops. - StructuredOpMatcher &input(int64_t position, ConstantFloatMinOrMinusInf); - StructuredOpMatcher &input(int64_t position, ConstantFloatZero); - - //===-------------------------------------------------------------------===// - // Constraints on output operands. - //===-------------------------------------------------------------------===// - - /// Adds a predicate checking that the structured op has the given number of - /// outputs. - StructuredOpMatcher &output(NumEqualsTo num); - - /// Adds a predicate checking that all output operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that all output operands of the structured op - /// have a projected permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation); - - /// Adds a predicate checking that all output operands of the structured op - /// have a - StructuredOpMatcher &output(AllOperands tag, IsProjected dim); - - /// Adds a predicate checking that all output operands of the structured op - /// have identity indexing map. - StructuredOpMatcher &output(AllOperands tag, IsIdentity); - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op output at the given position is equal to the given value. - StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width); - - /// Capture the elemental type bitwidth of output operand `position`. 
- StructuredOpMatcher &output(int64_t position, - CaptureElementTypeBitWidth width); - - /// Capture the elemental type of output operand `position`. - StructuredOpMatcher &output(int64_t position, CaptureElementType elem); - - /// Adds a predicate checking that the output of the structured op is produced - /// by a reduction with a single-operation combinator (such as addf or mulf, - /// but not a compare+select pair). - StructuredOpMatcher &output(int64_t position, SingleCombinerReduction tag); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the init/out operand corresponding to `position`-th - /// output. The position may be negative, in which case positions are counted - /// from the last one (i.e. Python-style). When the match is optional, the - /// predicate check succeeds as long as the `position` is in bounds. The - /// matcher executed if there is a defining operation for the output operand. - template - std::enable_if_t::value, - StructuredOpMatcher &> - output(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addOutputMatcher( - position, - [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - - //===-------------------------------------------------------------------===// - // Constraints on results. - //===-------------------------------------------------------------------===// - - /// Adds a predicate that recursively applies to users of the `position`-th - /// result of the structured op. Succeeds if any user matches the predicate. - /// When the match is optional, the predicate check succeeds as long as the - /// `position` is in bounds, after running the given matcher. 
- template - std::enable_if_t::value, - StructuredOpMatcher &> - result(int64_t position, HasAnyUse tag, T &resultUserMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addResultMatcher( - position, tag, - [&resultUserMatcher](Operation *op) { - return resultUserMatcher.match(op); - }, - optional); - recordNestedMatcher(resultUserMatcher); - return *this; - } - - //===-------------------------------------------------------------------===// - // Constraints on op region. - //===-------------------------------------------------------------------===// - - /// Return true if the linalg op only contains a single ops and the arguments - /// of the operation match the order of the linalg operand. - /// Example: - /// linalg.generic - /// ins(%0, %1 : tensor, tensor) - /// outs(%2 : tensor) { - /// ^bb0(%arg0: f32, %arg1: f32): - /// %3 = arith.maxf %arg0, %arg1 : f32 - /// linalg.yield %3 : f32 - /// } -> tensor - /// If commutative is set binary operations can have their operands swapped. - template - StructuredOpMatcher &singleOpWithCanonicaleArgs(bool commutative = false) { - return singleOpWithCanonicaleArgs(OpType::getOperationName(), commutative); - } - StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname, - bool commutative); - /// Check if the op is a linalg of with a single float reciprocal op. - StructuredOpMatcher &isFloatReciprocal(); - /// Check if the op is a linalg of with a region containing only a yield op - /// using block arguments in order. 
- StructuredOpMatcher &passThroughOp(); - - /// Check if the body of the linalg op implements a contraction of the kind - /// result = input1 input2 - template - StructuredOpMatcher &hasContractionBody() { - return hasContractionBody( - [](Operation *op) { return isa(op); }, - [](Operation *op) { return isa(op); }, - ElemOpTy::getOperationName(), ReductionOpTy::getOperationName()); - } - -private: - /// Non-template implementations of nested predicate builders for inputs, - /// outputs and results. Should not be called directly. - void addInputMatcher(int64_t position, - std::function matcher, - OptionalMatch optional); - void addInputMatcher(int64_t position, std::function matcher, - OptionalMatch optional); - void addOutputMatcher(int64_t position, - std::function matcher, - OptionalMatch optional); - void addResultMatcher(int64_t position, HasAnyUse tag, - std::function matcher, - OptionalMatch optional); - - // Common util for constant matcher. - StructuredOpMatcher &input(int64_t position, - std::function floatValueFn); - - /// Non-template implementation of hasContractionBody. Takes callbacks for - /// checking operation kinds and names for error reporting. - StructuredOpMatcher & - hasContractionBody(function_ref isaElemOpTy, - function_ref isaReductionOpTy, - StringRef elemOpName, StringRef reductionOpName); -}; - -/// Creates a matcher of an arbitrary structured op. -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates a matcher that is a copy of the given matcher. -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext, - const StructuredOpMatcher &other) { - return matcherContext.allocate(other); -} - -/// Creates a matcher that accepts as disjunction of the two given matchers. 
-inline StructuredOpMatcher &m_StructuredOp_Or(MatcherContext &matcherContext, - StructuredOpMatcher &A, - StructuredOpMatcher &B) { - return matcherContext.allocate(A, B); -} - -/// Creates a matcher of a structured op with kinds provided as template -/// arguments. -template -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { - return matcherContext.allocate( - StructuredOpMatcher::create()); -} - -//===---------------------------------------------------------------------===// -// MatchCallback functionality. -//===---------------------------------------------------------------------===// - -/// Additional results of the C++ callback usable in the `match_callback` -/// transform operation. Conceptually, a list of lists of payload operations to -/// be associated with each result handle. -class MatchCallbackResult { -public: - /// Returns the number of lists of payload operations. - int64_t getNumPayloadGroups() const { return payloadGroupLengths.size(); } - - /// Returns the `position`-th list of payload operations. - ArrayRef getPayloadGroup(int64_t position) const; - - /// Adds a new list of payload operations to the list of lists. The new list - /// must not contain null operations. - template - int64_t addPayloadGroup(Range operations) { - int64_t originalLength = payloadOperations.size(); - assert(llvm::all_of(operations, [](Operation *op) -> bool { return op; }) && - "null operation"); - llvm::append_range(payloadOperations, operations); - payloadGroupLengths.push_back(payloadOperations.size() - originalLength); - return payloadGroupLengths.size() - 1; - } - void addPayloadGroup(ArrayRef operations) { - addPayloadGroup>(operations); - } - - /// Adds a new singleton list of payload operation to the list of lists if the - /// operation is non-null, adds an empty list otherwise. Useful for results of - /// optional matches. 
- void addPotentiallyEmptyPayloadGroup(Operation *op) { - if (!op) { - addPayloadGroup(ArrayRef()); - } else { - addPayloadGroup(ArrayRef(op)); - } - } - -private: - /// The flat list of all payload operations. `payloadGroupLengths` can be used - /// to compute the sublist that corresponds to one nested list. - // TODO: if somebody implements such a flattened vector generically, use it. - SmallVector payloadOperations; - SmallVector payloadGroupLengths; -}; - -/// A transform state extension that maintains the mapping between callback -/// names as strings usable in `match_callback` and their implementations. -class MatchCallbacksRegistry : public transform::TransformState::Extension { -public: - using MatchCallbackFn = std::function; - - /// Constructs the extension. - MatchCallbacksRegistry(transform::TransformState &state) - : transform::TransformState::Extension(state) {} - - /// Registers the given function as a callback with the given name. The name - /// must not be already present in the registry. The callback must be - /// convertible to MatchCallbackFn. - template - void registerCallback(StringRef name, Fn &&fn) { - bool succeeded = callbacks.try_emplace(name, std::forward(fn)).second; - (void)succeeded; - assert(succeeded && "adding a callback with a repeated name"); - } - - /// Returns a pointer to the implementation of the callback with the given - /// name, or null if it is not present in the registry. - const MatchCallbackFn *get(StringRef name) const { - auto iter = callbacks.find(name); - if (iter == callbacks.end()) { - return nullptr; - } - return &iter->getValue(); - } - -private: - llvm::StringMap callbacks; -}; - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. 
-//===---------------------------------------------------------------------===// - -struct MatchedReductionCaptures { - int64_t reductionRank = 0; - int64_t maybeLeadingRank = 0; - int64_t maybeTrailingRank = 0; - SmallVector leadingOpSizes = {}; - SmallVector reductionOpSizes = {}; - SmallVector trailingOpSizes = {}; - int64_t reductionOutputElementalTypeBitWidth = 0; - int64_t maybeLeadingOutputElementalTypeBitWidth = 0; - int64_t maybeTrailingOutputElementalTypeBitWidth = 0; -}; - -struct MatchedMatmulCaptures { - linalg::ContractionDimensions contractionDims = {}; - Type lhsElementType, rhsElementType, outputElementType; - SmallVector matmulOpSizes = {}; - SmallVector indexingMaps; - - /// Helper functions. - int64_t rank() const { return matmulOpSizes.size(); } - /// Return all batches. - ArrayRef batches() const { return contractionDims.batch; } - /// Return the most minor candidate dimension for `m`. - int64_t m() const { return contractionDims.m.back(); } - /// Return the most minor candidate dimension for `n`. - int64_t n() const { return contractionDims.n.back(); } - /// Return the most minor candidate dimension for `k`. - int64_t k() const { return contractionDims.k.back(); } - /// AffineMap for indexing into the LHS. - AffineMap lhsIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[0]; - } - /// AffineMap for indexing into the RHS. - AffineMap rhsIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[1]; - } - /// AffineMap for indexing into the RES. - AffineMap resIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[2]; - } -}; - -/// Creates a group of matchers for: -/// -/// trailing(reduction(leading(), fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. 
If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeReductionMatcher(MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&leadingCapture, - StructuredOpMatcher *&trailingCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc); -void makeReductionMatcher(MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc); -/// -/// trailing(matmul(*, *, fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeMatmulMatcher(MatcherContext &matcherContext, - StructuredOpMatcher *&matmulCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&trailingCapture, - MatchedMatmulCaptures &captures, - bool mustMatchEntireFunc); - -/// Create a group of matchers of batch mamtul with a fill: -/// -/// batch_matmul(*, *, fill()) -/// -/// and capture various useful quantities. If `mustMatchEntireFunc` is set, the -/// matcher additionally checks if all tileable operations in the functions are -/// captured. -void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&bmmCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::MatchedMatmulCaptures &captures, - bool mustMatchEntireFunc); - -/// Create a group of matchers for a different code sequence of operations -/// matching exactly a softmax operation. 
-/// -/// %red = reduce_max(%0) -/// %sub = sub(%0, %red) -/// %exp = exp(%sub) -/// %sum = reduce_sum(%exp) -/// %mul = div(%exp, %%sum) -void makeSoftmaxMatcher(MatcherContext &context, - StructuredOpMatcher *&maxReductionCapture, - StructuredOpMatcher *&softmaxRootCapture); - -struct MatchedConvolutionCaptures { - Type inputElementType, filterElementType, outputElementType; - mlir::linalg::ConvolutionDimensions convolutionDims = {}; - SmallVector convolutionOpSizes = {}; - SmallVector trailingOpSizes = {}; - int64_t maybeTrailingOutputElementalTypeBitWidth = 0; - int64_t maybeFillElementalTypeBitWidth = 0; -}; - -/// Creates a group of matchers for: -/// -/// trailing(convolution(input, filter, fill())) -/// -/// where fill is a FillOp and trailing is an elementwise operation, both of -/// which is optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeConvolutionMatcher(MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&trailingCapture, - MatchedConvolutionCaptures &captures, - bool mustMatchEntireFunc); -void makeConvolutionMatcher(MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - MatchedConvolutionCaptures &captures, - bool mustMatchEntireFunc); - -struct MatchedPadCaptures { - int64_t rank = 0; - Type elementType; - SmallVector dims = {}; -}; - -/// Create a matcher for tensor.pad(*) without leading or trailing ops atm. -/// If `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makePadMatcher(MatcherContext &context, CapturingOpMatcher *&padCapture, - MatchedPadCaptures &captures, bool mustMatchEntireFunc); - -/// Wraps the given matcher callback to indicate that it must capture all -/// tilable ops in the parent function. 
Expects the callback to accept the same -/// arguments as what is expected by MatchCallbacksRegistry::register, followed -/// by a bool. -template -auto wrapAsEntireFuncMatch(Fn &&fn) { - return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - return fn(res, loc, state, handles, true); - }; -} - -/// Wraps the given matcher callback to indicate that it can match subgraphs. -/// Expects the callback to accept the same arguments as what is expected by -/// MatchCallbacksRegistry::register, followed by a bool. -template -auto wrapAsPartialMatch(Fn &&fn) { - return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - return fn(res, loc, state, handles, false); - }; -} - -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_COMPILER_GLOBALOPTIMIZATION_TRANSFORMMATCHERS_H_ diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index fc1b4e23e8a2..b75de6ce3cec 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -168,6 +168,115 @@ util.func public @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x // ----- +// Negative test: the max reduction is initialized with 0.0 instead of -inf, so +// this is not a numerically-stabilized softmax and must not be raised. 
+// CHECK-LABEL: @not_softmax_wrong_max_init +// CHECK-NOT: linalg.softmax +util.func public @not_softmax_wrong_max_init(%src : tensor) -> (tensor) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %c_0_index = arith.constant 0 : index + %c_1_index = arith.constant 1 : index + %c_2_index = arith.constant 2 : index + %dim_0 = tensor.dim %src, %c_0_index : tensor + %dim_1 = tensor.dim %src, %c_1_index : tensor + %dim_2 = tensor.dim %src, %c_2_index : tensor + %1 = tensor.empty(%dim_0, %dim_1) : tensor + // Wrong init: 0.0 rather than -inf / lowest. + %2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor) -> tensor + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor) outs(%2 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.maximumf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %4 = tensor.empty(%dim_0, %dim_1, %dim_2) : tensor + %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%src, %3 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.subf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = math.exp %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor) -> tensor + %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : 
tensor) outs(%7 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.addf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor) outs(%1 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.divf %cst, %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.mulf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + util.return %10 : tensor +} + +// ----- + +// Negative test: the max reduction reduces %src but the subtraction reads a +// different tensor %other, so the captured source is inconsistent and the +// pattern must not be raised. +// CHECK-LABEL: @not_softmax_mismatched_source +// CHECK-NOT: linalg.softmax +util.func public @not_softmax_mismatched_source(%src : tensor, %other : tensor) -> (tensor) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant -3.40282347E+38 : f32 + %c_0_index = arith.constant 0 : index + %c_1_index = arith.constant 1 : index + %c_2_index = arith.constant 2 : index + %dim_0 = tensor.dim %src, %c_0_index : tensor + %dim_1 = tensor.dim %src, %c_1_index : tensor + %dim_2 = tensor.dim %src, %c_2_index : tensor + %1 = tensor.empty(%dim_0, %dim_1) : tensor + %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor) -> tensor + // Reduces %src ... 
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%src : tensor) outs(%2 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.maximumf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %4 = tensor.empty(%dim_0, %dim_1, %dim_2) : tensor + // ... but subtracts from %other. + %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%other, %3 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.subf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = math.exp %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor) -> tensor + %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor) outs(%7 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.addf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor) outs(%1 : tensor) { + ^bb0(%arg0: f32, %arg1: f32): + %11 = arith.divf %cst, %arg0 : f32 + linalg.yield %11 : f32 + } -> tensor + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = 
["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor, tensor) outs(%4 : tensor) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.mulf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor + util.return %10 : tensor +} + +// ----- + util.func public @aTransposeBMatmul(%arg0 : tensor<10x20xf32>, %arg1 : tensor<40x20xf32>) -> tensor<10x40xf32> { %0 = tensor.empty() : tensor<20x40xf32> From 18d5f9673eadd271c3791ab1a88e66fde60fb933 Mon Sep 17 00:00:00 2001 From: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:58:02 -0400 Subject: [PATCH 5/8] [GlobalOpt] Accept arith.maxnumf in the softmax matcher (#24466) The stabilizing max of a softmax can be spelled with either arith.maximumf (NaN-propagating, as emitted by e.g. StableHLO frontends) or arith.maxnumf (NaN-ignoring). The latter is what linalg.softmax itself decomposes to (SoftmaxOp::decomposeOperation and the iree-codegen-decompose-softmax pass both use arith.maxnumf), so matching only arith.maximumf made the matcher narrower than the op it raises to and missed that form. Accept both ops for the max reduction. This is strictly safer: maxnumf is the form linalg.softmax decomposes to, so raising it introduces no NaN behavior change. Add a positive lit test for the maxnumf spelling. 
Signed-off-by: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com> --- .../GlobalOptimization/RaiseSpecialOps.cpp | 17 ++++++-- .../test/raise_special_ops.mlir | 43 +++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index cb9a999d17dd..4077d7604254 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -639,12 +639,21 @@ static FailureOr matchSoftmax(linalg::LinalgOp rootOp) { return failure(); } - // max = reduce_max(src), reducing the same source the subtraction reads. + // max = reduce_max(src), reducing the same source the subtraction reads. The + // init must be -inf or the lowest finite value. Accept both arith.maximumf + // (NaN-propagating, as emitted by e.g. StableHLO frontends) and arith.maxnumf + // (NaN-ignoring, which is what linalg.softmax itself decomposes to), since + // both denote the stabilizing max of a softmax. 
+ auto isNegInfOrLowest = [](APFloat f) { + return (f.isLargest() || f.isInfinity()) && f.isNegative(); + }; Value source = subSource->get(); Value reducedValue = - matchInnermostReduction(maxValue, [](APFloat f) { - return (f.isLargest() || f.isInfinity()) && f.isNegative(); - }); + matchInnermostReduction(maxValue, isNegInfOrLowest); + if (!reducedValue) { + reducedValue = + matchInnermostReduction(maxValue, isNegInfOrLowest); + } if (reducedValue != source) { return failure(); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index b75de6ce3cec..f2c9ebd525dd 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -168,6 +168,49 @@ util.func public @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x // ----- +// The stabilizing max may use arith.maxnumf (NaN-ignoring) instead of +// arith.maximumf -- this is the form linalg.softmax itself decomposes to. 
+// CHECK-LABEL: @softmax_maxnumf +// CHECK-SAME: %[[ARG:.+]]: tensor<2x4xf32> +// CHECK: %[[S:.+]] = linalg.softmax dimension(1) ins(%[[ARG]] : tensor<2x4xf32>) +// CHECK: util.return %[[S]] +util.func public @softmax_maxnumf(%src : tensor<2x4xf32>) -> (tensor<2x4xf32>) { + %cst0 = arith.constant 0.000000e+00 : f32 + %cstlow = arith.constant -3.40282347E+38 : f32 + %e1 = tensor.empty() : tensor<2xf32> + %e2 = tensor.empty() : tensor<2x4xf32> + %fillmax = linalg.fill ins(%cstlow : f32) outs(%e1 : tensor<2xf32>) -> tensor<2xf32> + %max = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%src : tensor<2x4xf32>) outs(%fillmax : tensor<2xf32>) { + ^bb0(%a: f32, %b: f32): + %m = arith.maxnumf %a, %b : f32 + linalg.yield %m : f32 + } -> tensor<2xf32> + %sub = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%src, %max : tensor<2x4xf32>, tensor<2xf32>) outs(%e2 : tensor<2x4xf32>) { + ^bb0(%a: f32, %b: f32, %c: f32): + %s = arith.subf %a, %b : f32 + linalg.yield %s : f32 + } -> tensor<2x4xf32> + %exp = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%sub : tensor<2x4xf32>) outs(%e2 : tensor<2x4xf32>) { + ^bb0(%a: f32, %b: f32): + %e = math.exp %a : f32 + linalg.yield %e : f32 + } -> tensor<2x4xf32> + %fillsum = linalg.fill ins(%cst0 : f32) outs(%e1 : tensor<2xf32>) -> tensor<2xf32> + %sum = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%exp : tensor<2x4xf32>) outs(%fillsum : tensor<2xf32>) { + ^bb0(%a: f32, %b: f32): + %s = arith.addf %a, %b : f32 + linalg.yield %s : f32 + } -> tensor<2xf32> + %div = linalg.generic {indexing_maps = [affine_map<(d0, d1) 
-> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%exp, %sum : tensor<2x4xf32>, tensor<2xf32>) outs(%e2 : tensor<2x4xf32>) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.divf %a, %b : f32 + linalg.yield %d : f32 + } -> tensor<2x4xf32> + util.return %div : tensor<2x4xf32> +} + +// ----- + // Negative test: the max reduction is initialized with 0.0 instead of -inf, so // this is not a numerically-stabilized softmax and must not be raised. // CHECK-LABEL: @not_softmax_wrong_max_init From 5056147400519b40af54f6905b28d50245eec3aa Mon Sep 17 00:00:00 2001 From: Alex-Wengg Date: Thu, 18 Jun 2026 10:26:39 -0400 Subject: [PATCH 6/8] Remove redundant transform_ext code from iree-dialects (#24466) The softmax StructuredOpMatcher infrastructure and ErrorCheckingTrackingListener were relocated into GlobalOptimization/ and Codegen/Common/ in prior commits. This deletes the now-dead originals in iree-dialects: - Transforms/TransformMatchers.{h,cpp} - the LinalgTransform StructuredTransformOpsExt transform-dialect extension (.td/.h/.cpp + TableGen) and its embedded ErrorCheckingTrackingListener - the iree-dialects CAPI (Dialects.{h,cpp}) and the Python bindings for the extension No compiler code includes any iree-dialects header anymore, so the deps on IREELinalgTransformDialect / IREEDialectsTransforms / CAPI across the compiler were stale; they are removed and the mirrored CMakeLists regenerated via bazel_to_cmake. The ireeRegisterTransformExtensions symbol is dropped from the public C API exports accordingly. Progress toward retiring the iree-dialects dependency. 
Signed-off-by: Alex-Wengg --- .../bazel_to_cmake/bazel_to_cmake_targets.py | 4 - compiler/plugins/target/LLVMCPU/BUILD.bazel | 1 - .../plugins/target/LLVMCPU/CMakeLists.txt | 1 - compiler/src/iree/compiler/API/BUILD.bazel | 1 - compiler/src/iree/compiler/API/CMakeLists.txt | 2 - compiler/src/iree/compiler/API/api_exports.c | 2 - .../src/iree/compiler/API/api_exports.def | 1 - compiler/src/iree/compiler/API/api_exports.ld | 1 - .../iree/compiler/API/api_exports.macos.lst | 1 - .../src/iree/compiler/API/generate_exports.py | 14 - .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 1 - .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 - .../LLVMCPU/TransformExtensions/BUILD.bazel | 2 - .../TransformExtensions/CMakeLists.txt | 2 - .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel | 1 - .../compiler/Codegen/LLVMGPU/CMakeLists.txt | 1 - .../iree/compiler/Codegen/SPIRV/BUILD.bazel | 1 - .../compiler/Codegen/SPIRV/CMakeLists.txt | 1 - .../Flow/TransformExtensions/BUILD.bazel | 2 - .../Flow/TransformExtensions/CMakeLists.txt | 2 - .../Dialect/Flow/Transforms/BUILD.bazel | 2 - .../Dialect/Flow/Transforms/CMakeLists.txt | 2 - .../LinalgExt/TransformExtensions/BUILD.bazel | 2 - .../TransformExtensions/CMakeLists.txt | 2 - .../TransformExtensions/BUILD.bazel | 2 - .../TransformExtensions/CMakeLists.txt | 2 - compiler/src/iree/compiler/Tools/BUILD.bazel | 1 - .../src/iree/compiler/Tools/CMakeLists.txt | 1 - docs/website/generate_extra_files.sh | 1 - .../iree-dialects/BUILD.bazel | 211 +- .../include/iree-dialects-c/Dialects.h | 30 - .../include/iree-dialects/CMakeLists.txt | 1 - .../iree-dialects/Dialect/CMakeLists.txt | 1 - .../Dialect/LinalgTransform/CMakeLists.txt | 24 - .../StructuredTransformOpsExt.h | 97 - .../StructuredTransformOpsExt.td | 132 -- .../Transforms/TransformMatchers.h | 1201 ----------- .../iree-dialects/lib/CAPI/CMakeLists.txt | 10 - .../iree-dialects/lib/CAPI/Dialects.cpp | 32 - .../iree-dialects/lib/CMakeLists.txt | 3 - .../iree-dialects/lib/Dialect/CMakeLists.txt | 
1 - .../Dialect/LinalgTransform/CMakeLists.txt | 1 - .../Dialect/LinalgTransform/IR/CMakeLists.txt | 38 - .../IR/StructuredTransformOpsExt.cpp | 999 --------- .../lib/Transforms/CMakeLists.txt | 19 - .../lib/Transforms/TransformMatchers.cpp | 1845 ----------------- .../iree-dialects/python/CMakeLists.txt | 12 +- .../python/IREEDialectsModule.cpp | 3 +- .../dialects/IreeStructuredTransformOps.td | 12 - .../_iree_structured_transform_ops_ext.py | 84 - .../dialects/transform/iree_structured.py | 7 - .../tools/iree-dialects-opt/CMakeLists.txt | 2 - .../iree-dialects-opt/iree-dialects-opt.cpp | 2 - 53 files changed, 3 insertions(+), 4821 deletions(-) delete mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h delete mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt delete mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt delete mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h delete mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td delete mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h delete mode 100644 llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt delete mode 100644 llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp delete mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt delete mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt delete mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt delete mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp delete mode 100644 llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt delete mode 100644 
llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp delete mode 100644 llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td delete mode 100644 llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py delete mode 100644 llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index bebad4e0ea34..297e05f69b8c 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -27,10 +27,6 @@ def __init__(self, repo_map: Dict[str, str]): f"{iree_core_repo}//compiler/src/iree/compiler/API:CAPI": [ "IREECompilerCAPILib" ], - # IREE llvm-external-projects - f"{iree_core_repo}//llvm-external-projects/iree-dialects:CAPI": [ - "IREEDialectsCAPI" - ], # Disable all hard-coded codegen targets (they are expanded dynamically # in CMake). 
"@llvm-project//llvm:AArch64AsmParser": ["IREELLVMCPUTargetDeps"], diff --git a/compiler/plugins/target/LLVMCPU/BUILD.bazel b/compiler/plugins/target/LLVMCPU/BUILD.bazel index 2b931e7df565..8e25d6be7a3a 100644 --- a/compiler/plugins/target/LLVMCPU/BUILD.bazel +++ b/compiler/plugins/target/LLVMCPU/BUILD.bazel @@ -47,7 +47,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:AArch64AsmParser", "@llvm-project//llvm:AArch64CodeGen", "@llvm-project//llvm:ARMAsmParser", diff --git a/compiler/plugins/target/LLVMCPU/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/CMakeLists.txt index b0cb849fa713..a70f61efd8a9 100644 --- a/compiler/plugins/target/LLVMCPU/CMakeLists.txt +++ b/compiler/plugins/target/LLVMCPU/CMakeLists.txt @@ -31,7 +31,6 @@ iree_cc_library( ::LinkerTool ::StaticLibraryGenerator IREELLVMCPUTargetDeps - IREELinalgTransformDialect LLVMAnalysis LLVMBitReader LLVMBitWriter diff --git a/compiler/src/iree/compiler/API/BUILD.bazel b/compiler/src/iree/compiler/API/BUILD.bazel index 9ca3ea751001..f10c9226d0ed 100644 --- a/compiler/src/iree/compiler/API/BUILD.bazel +++ b/compiler/src/iree/compiler/API/BUILD.bazel @@ -37,7 +37,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/API/Internal:IREEOptToolEntryPoint", "//compiler/src/iree/compiler/API/Internal:IREEReduceToolEntryPoint", "//compiler/src/iree/compiler/API/Internal:LLDToolEntryPoint", - "//llvm-external-projects/iree-dialects:CAPI", "@llvm-project//mlir:CAPIAMDGPU", "@llvm-project//mlir:CAPIDebug", "@llvm-project//mlir:CAPIGPU", diff --git a/compiler/src/iree/compiler/API/CMakeLists.txt b/compiler/src/iree/compiler/API/CMakeLists.txt index 8d76da8211cb..d0a0a697a71f 100644 --- a/compiler/src/iree/compiler/API/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/CMakeLists.txt @@ -14,7 +14,6 @@ 
iree_cc_library( NAME StaticImpl DEPS - IREEDialectsCAPI MLIRCAPIAMDGPU MLIRCAPIDebug MLIRCAPIExportSMTLIB @@ -82,7 +81,6 @@ set(_EXPORT_OBJECT_LIBS iree_compiler_API_Internal_IREEOptToolEntryPoint.objects iree_compiler_API_Internal_IREEReduceToolEntryPoint.objects iree_compiler_API_Internal_LLDToolEntryPoint.objects - obj.IREEDialectsCAPI obj.MLIRCAPIAMDGPU obj.MLIRCAPIDebug obj.MLIRCAPIExportSMTLIB diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index 31a69d22d4a9..731a366c9faf 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -154,7 +154,6 @@ extern void ireeLinkRunMain(); extern void ireeMlirLspServerRunMain(); extern void ireeOptRunMain(); extern void ireeReduceRunMain(); -extern void ireeRegisterTransformExtensions(); extern void mlirAffineAddExprGet(); extern void mlirAffineBinaryOpExprGetLHS(); extern void mlirAffineBinaryOpExprGetRHS(); @@ -1330,7 +1329,6 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&ireeMlirLspServerRunMain; x += (uintptr_t)&ireeOptRunMain; x += (uintptr_t)&ireeReduceRunMain; - x += (uintptr_t)&ireeRegisterTransformExtensions; x += (uintptr_t)&mlirAffineAddExprGet; x += (uintptr_t)&mlirAffineBinaryOpExprGetLHS; x += (uintptr_t)&mlirAffineBinaryOpExprGetRHS; diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index c62f0082631f..a1eb623ceae8 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -144,7 +144,6 @@ EXPORTS ireeMlirLspServerRunMain ireeOptRunMain ireeReduceRunMain - ireeRegisterTransformExtensions mlirAffineAddExprGet mlirAffineBinaryOpExprGetLHS mlirAffineBinaryOpExprGetRHS diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index 902f3c9c7986..8916a9d471b4 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld 
+++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -145,7 +145,6 @@ VER_0 { ireeMlirLspServerRunMain; ireeOptRunMain; ireeReduceRunMain; - ireeRegisterTransformExtensions; mlirAffineAddExprGet; mlirAffineBinaryOpExprGetLHS; mlirAffineBinaryOpExprGetRHS; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index fc7d0fa15962..3daa67377cac 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -143,7 +143,6 @@ _ireeLinkRunMain _ireeMlirLspServerRunMain _ireeOptRunMain _ireeReduceRunMain -_ireeRegisterTransformExtensions _mlirAffineAddExprGet _mlirAffineBinaryOpExprGetLHS _mlirAffineBinaryOpExprGetRHS diff --git a/compiler/src/iree/compiler/API/generate_exports.py b/compiler/src/iree/compiler/API/generate_exports.py index 31b6bd0e741a..05d79fef99e3 100755 --- a/compiler/src/iree/compiler/API/generate_exports.py +++ b/compiler/src/iree/compiler/API/generate_exports.py @@ -67,10 +67,6 @@ "Dialect/PDL.h", ] -IREE_DIALECTS_HEADER_FILES = [ - "Dialects.h", -] - IREE_COMPILER_DIALECTS_HEADER_FILES = [ "iree_codegen.h", "iree_gpu.h", @@ -101,16 +97,6 @@ def main(repo_root: Path, api_root: Path): for local_name in LOCAL_HEADER_FILES: export_symbols.extend(collect_header_exports(api_root / local_name)) - # Collect symbols from iree-dialects header files. - for local_name in IREE_DIALECTS_HEADER_FILES: - export_symbols.extend( - collect_header_exports( - repo_root - / "llvm-external-projects/iree-dialects/include/iree-dialects-c" - / local_name - ) - ) - # Collect symbols from iree compiler dialect header files. 
for local_name in IREE_COMPILER_DIALECTS_HEADER_FILES: export_symbols.extend( diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 2e82813eb75b..10537c9023df 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -117,7 +117,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "//runtime/src/iree/schemas:cpu_data", "//runtime/src/iree/schemas/instruments", "@llvm-project//llvm:BinaryFormat", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 21ee96022063..b49e7dbce7a7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -82,7 +82,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREELinalgTransformDialect LLVMBinaryFormat LLVMSupport LLVMTargetParser diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel index 522ae9359b0e..367ec8a8ea34 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel @@ -58,8 +58,6 @@ iree_compiler_cc_library( ], deps = [ ":LLVMCPUExtensionsOpGen", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:TransformDialect", ], diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt 
b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt index fa06b1419552..10f5eae8ce47 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "LLVMCPUExtensionsOps.cpp.inc" DEPS ::LLVMCPUExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect MLIRIR MLIRTransformDialect PUBLIC diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index e8e27a7cb247..dc6a9d4f8fa8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -193,7 +193,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AMDGPUToROCDL", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 8897d7253611..a38916616b89 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -129,7 +129,6 @@ iree_cc_library( ::PassesIncGen ::ROCDLPassHeaders ::ROCDLPassesIncGen - IREELinalgTransformDialect LLVMSupport MLIRAMDGPUDialect MLIRAMDGPUToROCDL diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index 3f0cb389343a..201158fcfff9 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -110,7 +110,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/Transforms", 
"//compiler/src/iree/compiler/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:AffineDialect", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 3a72e9982213..ec2e172ef876 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -79,7 +79,6 @@ iree_cc_library( DEPS ::PassHeaders ::PassesIncGen - IREELinalgTransformDialect LLVMSupport MLIRAffineAnalysis MLIRAffineDialect diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel index b528e6b198d2..98eb5921c6c9 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel @@ -61,8 +61,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt index e698587c89b7..66fe9c2cb575 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "FlowExtensionsOps.cpp.inc" DEPS ::FlowExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect 
LLVMSupport MLIRAnalysis MLIRArithDialect diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel index 3b367579d3ec..da95c44b97ee 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel @@ -84,8 +84,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineUtils", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 136de136cb2b..69c8b8185038 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt @@ -53,8 +53,6 @@ iree_cc_library( "VerifyInputLegality.cpp" DEPS ::PassesIncGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRAffineDialect MLIRAffineUtils diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel index 71ccf89cd779..11a03e6a7151 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel @@ -63,8 +63,6 @@ iree_compiler_cc_library( ":LinalgExtExtensionsOpGen", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", 
"@llvm-project//llvm:Support", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt index a634d0e869b5..0e2dd24b5cc9 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "LinalgExtExtensionsOps.cpp.inc" DEPS ::LinalgExtExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect LLVMSupport MLIRIR MLIRLinalgDialect diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel index 5ddf952d8fe9..7982c8feaf87 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel @@ -60,8 +60,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", "//compiler/src/iree/compiler/Utils", - "//llvm-external-projects/iree-dialects:IREEDialectsTransforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgInterfaces", diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt index 21977dd066e9..633f4e874ab1 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt @@ -31,8 +31,6 @@ iree_cc_library( "PreprocessingExtensionsOps.cpp.inc" DEPS ::PreprocessingExtensionsOpGen - IREEDialectsTransforms - IREELinalgTransformDialect 
LLVMSupport MLIRIR MLIRLinalgInterfacesIncGenLib diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 3ddd96e03d39..5a8acd213473 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -76,7 +76,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Preprocessing:Passes", "//compiler/src/iree/compiler/Preprocessing/TransformExtensions:PreprocessingExtensions", "//compiler/src/iree/compiler/Transforms", - "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", "@llvm-project//mlir:IR", ], ) diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index b8be6ff36a3d..de2f6848f95c 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -24,7 +24,6 @@ iree_cc_library( "init_iree_dialects.h" "init_iree_passes.h" DEPS - IREELinalgTransformDialect MLIRIR iree::compiler::Bindings::Native::Transforms iree::compiler::Bindings::TFLite::Transforms diff --git a/docs/website/generate_extra_files.sh b/docs/website/generate_extra_files.sh index 25be565b4fb1..c1014c69724b 100755 --- a/docs/website/generate_extra_files.sh +++ b/docs/website/generate_extra_files.sh @@ -52,7 +52,6 @@ cp -r "${BUILD_PASSES_ORIGINAL_DIR}/." "${BUILD_PASSES_PROCESSED_DIR}" # Delete any dialect docs we don't want to publish (yet?). rm "${BUILD_DIALECTS_PROCESSED_DIR}/SimpleIODialect.md" # Sample dialect, just ignore -rm "${BUILD_DIALECTS_PROCESSED_DIR}/StructuredTransformOpsExt.md" # Dialect extensions # Trim "Dialect"/"Passes" suffix from file names e.g. FlowDialect.md -> Flow.md. 
for f in ${BUILD_DIALECTS_PROCESSED_DIR}/*Dialect.md; do diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel index c2dce5810f34..811772bff3fe 100644 --- a/llvm-external-projects/iree-dialects/BUILD.bazel +++ b/llvm-external-projects/iree-dialects/BUILD.bazel @@ -1,5 +1,4 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") +load("@rules_cc//cc:defs.bzl", "cc_binary") package( default_visibility = ["//visibility:public"], @@ -25,212 +24,6 @@ filegroup( ), ) -################################################################################ -# Tablegen exports -################################################################################ - -td_library( - name = "TdFiles", - srcs = glob( - [ - "include/iree-dialects/Dialect/Input/*.td", - "include/iree-dialects/Dialect/LinalgTransform/*.td", - "python/iree/compiler/dialects/*.td", - ], - allow_empty = True, - ), - includes = ["include"], - deps = [ - "@llvm-project//mlir:BuiltinDialectTdFiles", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:PDLDialectTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - "@llvm-project//mlir:TransformDialectTdFiles", - ], -) - -################################################################################ -# IREELinalgTransform Dialect -################################################################################ - -cc_library( - name = "IREEDialectsTransforms", - srcs = glob( - [ - "lib/Transforms/*.cpp", - ], - allow_empty = True, - ), - hdrs = glob( - [ - "include/iree-dialects/Transforms/*.h", - ], - allow_empty = True, - ), - includes = ["include"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - 
"@llvm-project//mlir:FunctionInterfaces", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgInterfaces", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:TransformDialectInterfaces", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - ], -) - -gentbl_cc_library( - name = "IREELinalgTransformStructuredIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["--gen-op-decls"], - "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc", - ), - ( - ["--gen-op-defs"], - "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td", - deps = [ - ":TdFiles", - ], -) - -cc_library( - name = "IREELinalgTransformDialect", - srcs = glob( - [ - "lib/Dialect/LinalgTransform/IR/*.cpp", - "lib/Dialect/LinalgTransform/IR/*.h", - ], - allow_empty = True, - ), - hdrs = glob( - [ - "include/iree-dialects/Dialect/LinalgTransform/*.h", - ], - allow_empty = True, - ), - includes = ["include"], - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgTransformStructuredIncGen", - "@llvm-project//llvm:Support", - - # Dialects - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:AsyncDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FunctionInterfaces", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:PDLDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:Support", - 
"@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TilingInterface", - "@llvm-project//mlir:TransformDialect", - "@llvm-project//mlir:TransformDialectInterfaces", - "@llvm-project//mlir:TransformPDLExtension", - - # IR - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - - # Interfaces - "@llvm-project//mlir:ControlFlowInterfaces", - - # Transforms - "@llvm-project//mlir:AffineToStandard", - "@llvm-project//mlir:AsyncTransforms", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:TensorTransformOps", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:VectorToSCF", - - # Utils - "@llvm-project//mlir:ArithUtils", - "@llvm-project//mlir:DialectUtils", - - # Conversions - "@llvm-project//mlir:AsyncToLLVM", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:IndexToLLVM", - "@llvm-project//mlir:LinalgToStandard", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:VectorToLLVM", - ], -) - -################################################################################ -# CAPI -################################################################################ - -cc_library( - name = "CAPI", - srcs = glob( - ["lib/CAPI/*.cpp"], - allow_empty = True, - ), - hdrs = glob( - ["include/iree-dialects-c/*.h"], - allow_empty = True, - ), - includes = ["include"], - deps = [ - ":IREELinalgTransformDialect", - "@llvm-project//mlir:CAPIIR", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgTransformOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformDialect", - ], -) - -################################################################################ -# Test lib 
-################################################################################ - -cc_library( - name = "IREEDialectsTest", - deps = [ - ":IREEDialectsTransforms", - ":IREELinalgTransformDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Rewrite", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - ], -) - ################################################################################ # Tools ################################################################################ @@ -242,8 +35,6 @@ cc_binary( ], tags = ["hostonly"], deps = [ - ":IREEDialectsTest", - ":IREELinalgTransformDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h deleted file mode 100644 index c729238d8f63..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_C_DIALECTS_H -#define IREE_DIALECTS_C_DIALECTS_H - -#include "mlir-c/IR.h" -#include "mlir-c/Pass.h" -#include "mlir-c/RegisterEverything.h" - -#ifdef __cplusplus -extern "C" { -#endif - -//===--------------------------------------------------------------------===// -// TransformDialect -//===--------------------------------------------------------------------===// - -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); - -MLIR_CAPI_EXPORTED void ireeRegisterTransformExtensions(MlirContext context); - -#ifdef __cplusplus -} -#endif - -#endif // IREE_DIALECTS_C_DIALECTS_H diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt index 0ca0f41c5af4..e69de29bb2d1 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/CMakeLists.txt @@ -1 +0,0 @@ -add_subdirectory(Dialect) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt deleted file mode 100644 index 1da2860785ea..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(LinalgTransform) diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index c0ebeb41f586..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -function(_add_transform_dialect_extension) - set(LLVM_TARGET_DEFINITIONS StructuredTransformOpsExt.td) - 
mlir_tablegen(StructuredTransformOpsExt.h.inc -gen-op-decls) - mlir_tablegen(StructuredTransformOpsExt.cpp.inc -gen-op-defs) - add_public_tablegen_target(IREELinalgTransformExtIncGen) - add_dependencies(mlir-headers IREELinalgTransformExtIncGen) -endfunction() - -function(_add_structured_transform_doc) - set(LLVM_TARGET_DEFINITIONS StructuredTransformOpsExt.td) - mlir_tablegen(StructuredTransformOpsExt.md -gen-dialect-doc) - set(GEN_DOC_FILE ${IREE_DIALECTS_BINARY_DIR}/docs/Dialects/StructuredTransformOpsExt.md) - add_custom_command( - OUTPUT ${GEN_DOC_FILE} - COMMAND ${CMAKE_COMMAND} -E copy - ${CMAKE_CURRENT_BINARY_DIR}/StructuredTransformOpsExt.md - ${GEN_DOC_FILE} - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/StructuredTransformOpsExt.md) - add_custom_target(StructuredTransformOpsExtDocGen DEPENDS ${GEN_DOC_FILE}) - add_dependencies(iree-dialects-doc StructuredTransformOpsExtDocGen) -endfunction() - -_add_transform_dialect_extension() -_add_structured_transform_doc() diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h deleted file mode 100644 index c7b6d0dc750b..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H -#define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H - -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/Dialect/Transform/IR/TransformTypes.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace linalg { -class LinalgOp; -} // namespace linalg -namespace scf { -class ForOp; -} // namespace scf -namespace transform_ext { -class MatchCallbackOp; -} // namespace transform_ext - -/// Matches a C++ callback previously registered under `callbackName` and -/// taking arguments `args`. -/// Unpacks a number of handles `N` (asserts there are exactly `N` matched -/// ops but this could be relaxed if needed). Returns the tuple of handles. -template -auto unpackRegisteredMatchCallback(ImplicitLocOpBuilder &b, - StringRef callbackName, - MatchingArgs... args) { - SmallVector matchedTypes(N, transform::AnyOpType::get(b.getContext())); - auto matchOp = b.create( - matchedTypes, callbackName, std::forward(args)...); - assert(matchOp->getNumResults() == N && "Unexpected number of results"); - std::array a; - for (int64_t i = 0; i < N; ++i) { - a[i] = matchOp->getResult(i); - } - return std::tuple_cat(a); -} - -/// A tracking listener for tensor IR that checks for payload replacement -/// errors. -class ErrorCheckingTrackingListener : public transform::TrackingListener { -public: - using transform::TrackingListener::TrackingListener; - - ~ErrorCheckingTrackingListener() override { - assert(status.succeeded() && "must check listener error state"); - } - - /// Return "true" if this tracking listener had a failure. 
- bool failed() const { return !status.succeeded(); } - - /// Check and return the current error state of this listener. In case of a - /// failure state, only the most recent error is returned. Afterwards, resets - /// the error state. - DiagnosedSilenceableFailure checkAndResetError() { - DiagnosedSilenceableFailure result(std::move(status)); - status = DiagnosedSilenceableFailure::success(); - return result; - } - -private: - void - notifyPayloadReplacementNotFound(Operation *op, ValueRange values, - DiagnosedSilenceableFailure &&diag) override; - - /// The error state of this listener. "Success" indicates that no error - /// happened so far. Otherwise, the status contains the most recent error. - DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); -}; - -} // namespace mlir - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc" - -namespace mlir { -namespace transform_ext { -class StructuredTransformOpsExtension - : public mlir::transform::TransformDialectExtension< - StructuredTransformOpsExtension> { -public: - StructuredTransformOpsExtension(); -}; - -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td deleted file mode 100644 index 81beb062615d..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef STRUCTURED_TRANSFORM_OPS_EXT -#define STRUCTURED_TRANSFORM_OPS_EXT - -include "mlir/Dialect/Transform/IR/TransformAttrs.td" -include "mlir/Dialect/Transform/IR/TransformDialect.td" -include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/OpBase.td" - -def RegisterMatchCallbacksOp : - Op, - DeclareOpInterfaceMethods, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Registers named structured op matcher callbacks specific for IREE to use - with `transform.iree.match_callback`. This should be called before first - `match_callback` may be executed following the transform dialect control - flow. - - The callbacks must have a unique name and a signature compatible with - `MatchCallbacksRegistry::MatchCallbackFn`, which currently means - `DiagnosedSilenceableFailure(MatchCallbackResult &, Location, - const TransformState &, ValueRange)`. The callback receives a "result", - followed by a location at which errors should be reported, a transform - state at the moment of the _match_ (not registration) and a list of - handle values passed as operands to the `match_callback` operation. - It is expected to populate the "result" object with lists of payload - operations that will be bound to the handles produced by the - `match_callback` operation. The callback may fail, at which point - it should produce a silenceable error. The callback currently is not - allowed to modify the payload IR (though this may be revised in the - future for the purpose of communicating the properties of the IR - captured by the match). Therefore, it should not have a reason to - produce a definite error. 
- }]; - - let arguments = (ins); - let results = (outs); - let assemblyFormat = "attr-dict"; - let cppNamespace = "mlir::transform_ext"; -} - -def MatchCallbackOp : - Op, - DeclareOpInterfaceMethods, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Performs payload IR matching using a C++ callback registered beforehand. - The callback is identified by name and is passed the current transform - state and the list of handle operands, along with information necessary - for error propagation. See `register_match_callbacks` for the description - of the callback contract. - - If `failure_propagation_mode` is set to `suppress`, any silenceable errors - in the callback (typically, "failure to match") will be ignored and the - resulting handles will be associated with empty lists of payload - operations. Otherwise, silenceable failures are propagated. - }]; - - let arguments = (ins StrAttr:$callback_name, - FailurePropagationMode:$failure_propagation_mode, - Variadic:$inputs); - let results = (outs Variadic:$outputs); - let assemblyFormat = "`failures` `(` $failure_propagation_mode `)` " - "$callback_name `(` $inputs `)` attr-dict " - "`:` functional-type($inputs, $outputs)"; - let cppNamespace = "mlir::transform_ext"; -} - -def TakeFirstOp : - Op, - DeclareOpInterfaceMethods, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Given an arbitrary list of handles associated with potentially empty lists - of payload operations, produces two new handles: - - - a handle pointing to the same payload operations as the first operand - handle with a non-empty list of payload operations; - - a handle pointing to the concatenated list of payload operations - associated with any other handle. - - Note that this does not perform any deduplication. - - This operation is useful to select a single target after some potentially - unsuccessful matches. 
- }]; - - let arguments = (ins Variadic:$inputs); - let results = (outs TransformHandleTypeInterface:$first, - TransformHandleTypeInterface:$rest); - let assemblyFormat = - "$inputs attr-dict `:` functional-type($inputs, results)"; - let cppNamespace = "mlir::transform_ext"; -} - -def EmitRemarkOp : - Op, - TransformOpInterface, TransformEachOpTrait, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Emits a diagnostic remark with the given message located at payload ops - associated with the given handle. This can be used, e.g., for debugging. - }]; - - let arguments = (ins TransformHandleTypeInterface:$handle, - StrAttr:$message); - let assemblyFormat = "$message `at` $handle attr-dict `:` type($handle)"; - let cppNamespace = "mlir::transform_ext"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - -#endif // STRUCTURED_TRANSFORM_OPS_EXT diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h deleted file mode 100644 index 4b7d7169bb41..000000000000 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h +++ /dev/null @@ -1,1201 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ -#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ - -#include -#include -#include - -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/IR/Matchers.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/StringMap.h" - -namespace mlir { -namespace transform_ext { - -//===---------------------------------------------------------------------===// -// StructuredOpMatcher and predicates. -//===---------------------------------------------------------------------===// - -class StructuredOpMatcher; -class MatcherContext; -StructuredOpMatcher &m_StructuredOp(MatcherContext &); - -/// A tag indicating the shape being static or dynamic, for use with the -/// structured op matcher. -enum class ShapeKind { Static, Dynamic }; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all dimensions. -struct AllDims {}; - -/// A predicate indicating the structured op matcher to check the predicate for -/// all dimensions except the specified ones. -struct AllDimsExcept { - explicit AllDimsExcept(std::initializer_list range) { - llvm::append_range(exceptions, range); - } - ArrayRef getExcluded() const { return llvm::ArrayRef(exceptions); } - -private: - SmallVector exceptions; -}; - -/// A placeholder indicating the structured op matcher to check the predicate -/// for all operands of the relevant kind. -struct AllOperands {}; - -/// Base class for single-value captures. Concrete captures should inherit this -/// and forward the constructor via `using Base::Base`. 
-template -struct CaptureStaticValue { - using Base = CaptureStaticValue; - explicit CaptureStaticValue(T &value) : value(value) {} - T &value; -}; - -/// Captures the (static) size of the dimension. -struct CaptureDim : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the (static) sizes of multiple dimensions. -struct CaptureDims : public CaptureStaticValue> { - using Base::Base; -}; - -/// Captures the contraction dimensions of the target operation. -struct CaptureIndexingMaps : public CaptureStaticValue> { - using Base::Base; -}; - -/// Captures the contraction dimensions of the target operation. -struct CaptureContractionDims - : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the convolution dimensions of the target operation. -struct CaptureConvDims - : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the rank of the operation. -struct CaptureRank : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures the bitwidth of an element type. -struct CaptureElementTypeBitWidth : public CaptureStaticValue { - using Base::Base; -}; - -/// Captures element element type. -struct CaptureElementType : public CaptureStaticValue { - using Base::Base; -}; - -template -struct CaptureAttribute : public CaptureStaticValue { - static_assert(std::is_base_of_v, - "can only capture a subclass of Attribute"); - using CaptureStaticValue::CaptureStaticValue; -}; - -/// A tag indicating to look for any user of the operation's result that would -/// satisfy the predicate. -struct HasAnyUse {}; - -/// Base class for predicate parameters that can be described with the single -/// value. Concrete predicate parameters should inherit this and forward the -/// constructor via `using Base::Base`. 
-template -struct SingleValuePredicateParam { - using Base = SingleValuePredicateParam; - explicit SingleValuePredicateParam(T value) : value(value) {} - const T value; -}; - -/// Indicates that the dimension must be divisible by the given value. -struct DivisibleBy : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be equal to the given value. -struct NumEqualsTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be greater than the given value. -struct NumGreaterEqualTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the number of entities must be greater than the given value. -struct NumLowerEqualTo : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Indicates that the bit width of the elemental type must be equal to the give -/// value. -struct ElementTypeBitWidth : public SingleValuePredicateParam { - using Base::Base; -}; - -/// Predicate tag indicating that the affine map is a permutation. -struct IsPermutation {}; - -/// Predicate tag indicating that the affine map is a projected permutation. -struct IsProjectedPermutation {}; - -/// Predicate tag indicating that the affine map is a projection of given -/// dimension. -struct IsProjected : public SingleValuePredicateParam { - using Base::Base; -}; -/// Predicate tag indicating that the affine map is an identity. -struct IsIdentity {}; - -/// Predicate tag indicating that the operand is a special float constant. -struct ConstantFloatMinOrMinusInf {}; -struct ConstantFloatZero {}; -struct ConstantFloatOne {}; - -/// Indicates that the match optional. The matcher is still expected to run and -/// capture if successful. 
The parameter can be set to false -struct OptionalMatch : public SingleValuePredicateParam { - OptionalMatch() : Base(true) {} - explicit OptionalMatch(bool set) : Base(set) {} -}; - -/// Predicate tag indicating that the reduction is produced by a single combiner -/// operation. -struct SingleCombinerReduction {}; - -class CapturingOpMatcher; -class CapturingValueMatcher; - -/// Base class for capturing matchers that can be owned by the context. -class CapturingMatcherBase { -public: - // Virtual destructor so unique pointers are deallocated correctly. - // TODO: if efficiency is a problem, consider disallowing non-trivial - // destructors for subclasses. - virtual ~CapturingMatcherBase() = default; - -protected: - /// Informs the matcher that it has another, nested matcher. Derived classes - /// must call this to keep track of nested matchers for capture resetting - /// purposes. - template - void recordNestedMatcher(T &nested) { - if constexpr (std::is_base_of_v) { - nestedCapturingMatchers.push_back(&nested); - } - if constexpr (std::is_base_of_v) { - nestedCapturingValueMatchers.push_back(&nested); - } - } - - /// Appends all nested capturing matchers of a certain kind, excluding this - /// one, to `nested`. - void getAllNested(SmallVectorImpl &nested); - void - getAllNestedValueMatchers(SmallVectorImpl &nested); - - /// Resets nested capturing matchers but does NOT reset the current one. - void resetCapture(); - -private: - /// A list of (recursively) nested capturing matchers that should be reset - /// when the current matcher is. - SmallVector nestedCapturingMatchers; - SmallVector nestedCapturingValueMatchers; -}; - -/// A context object holding capturing matchers, must outlive any individual -/// matcher. When matching complex subgraphs, the caller often doesn't care -/// about all intermediate nodes (operations) in the graph and shouldn't need to -/// hold matcher objects for those. These matchers can be created in this -/// context. 
-class MatcherContext { -public: - /// Create a new matcher of the specified type owned by this context. - template - std::enable_if_t, T> & - allocate(Args &&...args) { - // Need to call "new" explicitly as make_unique wouldn't have access to the - // private constructor when this class would. - ownedMatchers.emplace_back( - std::unique_ptr(new T(std::forward(args)...))); - return *static_cast(ownedMatchers.back().get()); - } - -private: - /// Owning list of matchers. - // TODO: If this becomes inefficient, consider something like BumpPtrAllocator - // that derived classes can use to store their members as well. - SmallVector> ownedMatchers; -}; - -/// Base class for value matchers that capture the matched value. Stores a list -/// of predicates and requires all of them to match for the value to match. Once -/// a value matched, any repeated use just verifies that equality of the value. -class CapturingValueMatcher : public CapturingMatcherBase { - friend class CapturingMatcherBase; - friend class MatcherContext; - - using PredicateFn = std::function; - -public: - /// Resets the captured value to null. This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - CapturingMatcherBase::resetCapture(); - } - - /// Returns the matched value if the match was successful. - Value getCaptured() const { return captured; } - - /// Matches the given value, hook for `matchPattern`. - bool match(Value value); - -protected: - CapturingValueMatcher() = default; - - /// Adds a predicate to the end of the predicate list for this value matcher. - template - void addPredicate(Fn &&predicate) { - predicates.emplace_back(std::forward(predicate)); - } - - /// The captured value. - Value captured = nullptr; - -private: - /// Additional predicates to be checked on the value. 
- SmallVector predicates; -}; - -/// Creates a matcher of an arbitrary value. -inline CapturingValueMatcher &m_Value(MatcherContext &context) { - return context.allocate(); -} - -/// Matcher for typed values whose type implements the `ShapedType` interface. -/// Allows for matching the components of the shaped type such as rank and -/// dimensions. -class ShapedValueMatcher : public CapturingValueMatcher { - friend class MatcherContext; - - ShapedValueMatcher(); - -public: - /// Add an always-succeeding matcher predicate capturing the rank. - ShapedValueMatcher &rank(CaptureRank capture); - - /// Add an always-succeeding matcher predicate capturing the size of the - /// dimension identified by the first argument. - ShapedValueMatcher &dim(int64_t dimension, CaptureDim capture); - - /// Add an always-succeeding matcher predicate capturing the sizes of all - /// dimensions in order of appearance. - ShapedValueMatcher &dim(AllDims tag, CaptureDims captures); - - /// Add an always-succeeding matcher predicate capturing the element type of - /// the value. - ShapedValueMatcher &elementType(CaptureElementType captures); -}; - -/// Construct a new matcher of a value whose type is a `ShapedType`, owned by -/// the given context. -inline ShapedValueMatcher &m_ShapedValue(MatcherContext &context) { - return context.allocate(); -} - -/// Matcher for operations with additional predicates attachable through the -/// fluent, a.k.a. chainable, API. Note that public API must *not* accept -/// additional callbacks even; new predicates should be added instead when -/// necessary. Not only this decreases the depth of the callback stack and -/// increases readability, it also allows us to port the matcher to a -/// declarative format using PDL and/or Transform dialect in the future. The -/// latter will become impossible with arbitrary C++ callbacks. 
-class CapturingOpMatcher : public CapturingMatcherBase { - friend class CapturingMatcherBase; - friend class MatcherContext; - - template - friend CapturingOpMatcher &m_Operation(MatcherContext &matcherContext); - -public: - using PredicateFn = std::function; - - /// Matches the given operation, hook for `matchPattern`. - bool match(Operation *op); - - /// Resets the captured value to null. This should be called if the same - /// pattern needs to be applied more than once as it may keep captured values - /// for optional nested predicates from the previous application. - void resetCapture() { - captured = nullptr; - CapturingMatcherBase::resetCapture(); - } - - /// Returns the matched operation if the match was successful. - Operation *getCaptured() const { return captured; } - - /// Adds alternative paths for predicates. In practice, this is just a - /// predicate that is satisfied when either the first or the second matcher is - /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., - /// the second alternative will not be processed, and therefore will not - /// capture values, if the first alternative succeeded. - CapturingOpMatcher &alternatives(CapturingOpMatcher &first, - CapturingOpMatcher &second); - - //===-------------------------------------------------------------------===// - // Constraints on adjacent ops. - //===-------------------------------------------------------------------===// - - /// Adds a predicate checking that all ops implementing TilingInterface in the - /// parent of the given type (e.g., a function or a module) were matched by - /// this or nested matchers. This is useful to ensure that the matcher covered - /// the entire parent region, not just a parent of it. This predicate **must** - /// be added *after* all the other predicates that capture. 
- template - CapturingOpMatcher &allTilableOpsCaptured() { - SmallVector copy; - copy.push_back(this); - getAllNested(copy); - addPredicate([copy = std::move(copy)](Operation *op) { - Operation *parent = op->getParentOfType(); - return checkAllTilableMatched(parent, op, copy); - }); - return *this; - } - - //-------------------------------------------------------------------------// - // Predicates for operands and results. - //-------------------------------------------------------------------------// - - /// Adds a predicate checking that the operation has exactly the given number - /// of operands. - CapturingOpMatcher &operand(NumEqualsTo num); - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by an operation that satisfies the given matcher. - CapturingOpMatcher &operand(int64_t pos, CapturingOpMatcher &nested); - - /// Adds a predicate checking that the `pos`-th operand of the operation - /// satisfies the given value matcher. - CapturingOpMatcher &operand(int64_t pos, CapturingValueMatcher &nested); - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by `arith.constant` with the value 1.0. - // TODO: better matching for attributes. - CapturingOpMatcher &operand(int64_t pos, ConstantFloatOne); - - /// Adds a predicate checking that the operation has exactly the given number - /// of results. - CapturingOpMatcher &result(NumEqualsTo num); - - /// Adds a predicate checking that the `pos`-th result of the operation - /// satisfies the given value matcher. - CapturingOpMatcher &result(int64_t pos, CapturingValueMatcher &nested); - -protected: - /// Constructs a default operation matcher accepting any operation. - CapturingOpMatcher() = default; - - /// Adds a predicate for the matched operation to satisfy. 
- template - void addPredicate(Fn &&predicate) { - predicates.emplace_back(std::forward(predicate)); - } - - /// Produce the debug output for `create` method in a non-templated way. - static void debugOutputForCreate(ArrayRef opNames); - -private: - /// A list of additional conditions for the operation to match. - SmallVector predicates; - - /// Checks that `matchers` captured all tilable ops nested in `parent` except - /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured. - static bool checkAllTilableMatched(Operation *parent, Operation *op, - ArrayRef matchers); - - /// Creates a matcher for an operation with one of the given types. - template - static CapturingOpMatcher create() { - CapturingOpMatcher matcher; - matcher.addPredicate([](Operation *op) { - debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); - return isa(op); - }); - return matcher; - } - - /// Common util for constant matcher. - CapturingOpMatcher &operand(int64_t position, - std::function floatValueFn); - -protected: - /// Matched value. - Operation *captured = nullptr; -}; - -namespace detail { -/// Prints the debug output from the ConcreteOpMatcher constructor. The -/// implementation must reside in the C++ file so we don't pollute the header -/// with debug includes, and ConcreteOpMatcher is a class template that can only -/// reside in the header. -void debugOutputForConcreteOpMatcherConstructor(StringRef name); -} // namespace detail - -/// Base class for matchers that match a specific op. Adds an initial predicate -/// checking if the op is indeed of the specified kind. -/// Derived classes specializing this for op interfaces MUST also define a -/// specialization of DebugOpKindDescription. -template -class ConcreteOpMatcher : public CapturingOpMatcher { -protected: - using Base = ConcreteOpMatcher; - - static StringRef getConcreteOpDescription() { - return OpTy::getOperationName(); - } - - /// Adds a predicate checking if the op is of the OpTy kind. 
- ConcreteOpMatcher() { - CapturingOpMatcher::addPredicate([](Operation *op) { - detail::debugOutputForConcreteOpMatcherConstructor( - Derived::getConcreteOpDescription()); - return isa(op); - }); - } - - /// Adds a predicate for the matched operation to satisfy. - template - Derived &addPredicate(FnTy &&predicate) { - // Dispatch to the callback. - CapturingOpMatcher::addPredicate( - [inner = std::move(predicate)](Operation *op) { - return inner(cast(op)); - }); - return static_cast(*this); - } - -public: - /// Adds alternative paths for predicates. In practice, this is just a - /// predicate that is satisfied when either the first or the second matcher is - /// satisfied. The alternative satisfaction is eager and short-cutting, i.e., - /// the second alternative will not be processed, and therefore will not - /// capture values, if the first alternative succeeded. - Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) { - return static_cast( - CapturingOpMatcher::alternatives(first, second)); - } - - /// Adds a predicate checking that all ops implementing TilingInterface in the - /// parent of the given type (e.g., a function or a module) were matched by - /// this or nested matchers. This is useful to ensure that the matcher covered - /// the entire parent region, not just a parent of it. This predicate **must** - /// be added *after* all the other predicates that capture. - template - Derived &allTilableOpsCaptured() { - return static_cast( - CapturingOpMatcher::allTilableOpsCaptured()); - } - - //-------------------------------------------------------------------------// - // Predicates for operands and results. - //-------------------------------------------------------------------------// - - /// Adds a predicate checking that the operation has exactly the given number - /// of operands. 
- Derived &operand(NumEqualsTo num) { - return static_cast(CapturingOpMatcher::operand(num)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by an operation that satisfies the given matcher. - Derived &operand(int64_t pos, CapturingOpMatcher &nested) { - return static_cast(CapturingOpMatcher::operand(pos, nested)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation - /// satisfies the given value matcher. - Derived &operand(int64_t pos, CapturingValueMatcher &nested) { - return static_cast(CapturingOpMatcher::operand(pos, nested)); - } - - /// Adds a predicate checking that the `pos`-th operand of the operation is - /// defined by `arith.constant` with the value 1.0. - // TODO: better matching for attributes. - Derived &operand(int64_t pos, ConstantFloatOne c) { - return static_cast(CapturingOpMatcher::operand(pos, c)); - } - - /// Adds a predicate checking that the operation has exactly the given number - /// of results. - Derived &result(NumEqualsTo num) { - return static_cast(CapturingOpMatcher::result(num)); - } - - /// Adds a predicate checking that the `pos`-th result of the operation - /// satisfies the given value matcher. - Derived &result(int64_t pos, CapturingValueMatcher &nested) { - return static_cast(CapturingOpMatcher::result(pos, nested)); - } -}; - -/// Matcher for the `tensor.pad` operation. -class TensorPadOpMatcher - : public ConcreteOpMatcher { - friend class MatcherContext; - - TensorPadOpMatcher() = default; - -public: - /// Adds a predicate checking that the low padding sizes are exactly the given - /// values. - TensorPadOpMatcher &low(ArrayRef sizes); - - /// Adds a predicate checking that the low padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &low(AllDims tag, int64_t size); - - /// Adds a predicate checking that the high padding sizes for all dimensions - /// are exactly the same given value. 
- TensorPadOpMatcher &high(ArrayRef sizes); - - /// Adds a predicate checking that the high padding sizes for all dimensions - /// are exactly the same given value. - TensorPadOpMatcher &high(AllDims tag, int64_t size); - - /// Adds a predicate checking that the body of the pad only yields values - /// defined outside the pad region. - TensorPadOpMatcher &yieldsExternalValue(); -}; - -inline TensorPadOpMatcher &m_tensorPad(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates a default operation matcher in the given context that accepts any -/// operation. -inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates an operation matcher in the given context that accepts only -/// operations of the kinds provided as template arguments. -template -inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) { - return matcherContext.allocate( - CapturingOpMatcher::create()); -} - -/// Matcher for structured aka Linalg operations. -class StructuredOpMatcher - : public ConcreteOpMatcher { - friend class MatcherContext; - - StructuredOpMatcher() = default; - -public: - static StringRef getConcreteOpDescription() { - return "linalg interface implementation"; - } - - /// Creates a matcher for a structured operation with one of the given types. - template - static StructuredOpMatcher create() { - StructuredOpMatcher matcher; - matcher.addPredicate([](Operation *op) { - debugOutputForCreate(ArrayRef{OpType::getOperationName()...}); - return isa(op) && isa(op); - }); - return matcher; - } - - /// Matches a structured operation if either patterns A or B match. - StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B); - - //===-------------------------------------------------------------------===// - // Constraints on op rank and dims. 
- //===-------------------------------------------------------------------===// - /// Adds a predicate checking that the given rank must be greater than some - /// constant value. - StructuredOpMatcher &rank(NumGreaterEqualTo minRank); - StructuredOpMatcher &rank(NumLowerEqualTo maxRank); - StructuredOpMatcher &rank(NumEqualsTo exactRank); - - /// Adds a predicate checking that the given iteration space dimension is - /// static/dynamic. The dimension index may be negative, in which case - /// dimensions are counted from the last one (i.e. Python-style), or be an - /// AllDims tag, in which case all dimensions are checked. This may be - /// eventually extended to slices and/or lists of dimensions. - StructuredOpMatcher &dim(int64_t dimension, ShapeKind kind) { - return dim(SmallVector{dimension}, kind); - } - StructuredOpMatcher &dim(SmallVector &&dimensions, ShapeKind kind); - StructuredOpMatcher &dim(AllDims tag, ShapeKind kind); - - /// Adds a predicate checking that the given iteration space dimension has the - /// given iterator type, e.g., parallel or reduction. The dimension index may - /// be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style), or be an AllDims tag, in which case all dimensions - /// are checked. This may be eventually extended to slices and/or lists of - /// dimensions. - StructuredOpMatcher &dim(int64_t dimension, utils::IteratorType kind) { - return dim(SmallVector{dimension}, kind); - } - // Ownership may get tricky here so we wrap in an explicit vector. - StructuredOpMatcher &dim(SmallVector &&dimensions, - utils::IteratorType kind); - StructuredOpMatcher &dim(AllDims tag, utils::IteratorType kind); - StructuredOpMatcher &dim(AllDimsExcept &&dimensions, - utils::IteratorType kind); - - /// Adds a predicate checking that the given iteration space dimension is - /// statically known to be divisible by the given value. 
The dimension index - /// may be negative, in which case dimensions are counted from the last one - /// (i.e. Python-style). - StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy); - - //===-------------------------------------------------------------------===// - // Capture directives. - //===-------------------------------------------------------------------===// - StructuredOpMatcher &rank(CaptureRank capture); - StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture); - StructuredOpMatcher &dim(AllDims tag, CaptureDims captures); - StructuredOpMatcher &indexingMaps(CaptureIndexingMaps indexingMaps); - StructuredOpMatcher &contractionDims(CaptureContractionDims contractionDims); - StructuredOpMatcher &convolutionDims(CaptureConvDims convDims); - - //===-------------------------------------------------------------------===// - // Constraints on input operands. - //===-------------------------------------------------------------------===// - /// Adds a predicate checking that the structured op has the given number of - /// inputs. - StructuredOpMatcher &input(NumEqualsTo num); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the `position`-th operand. The position may be - /// negative, in which case positions are counted from the last one - /// (i.e. Python-style). When the match is optional, the predicate check - /// succeeds as long as the `position` is in bounds. The matcher is executed - /// if there is a defining operation for the input operand. 
- template - std::enable_if_t::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addInputMatcher( - position, - [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - template - std::enable_if_t::value, - StructuredOpMatcher &> - input(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addInputMatcher( - position, - [&operandMatcher](Value v) { return operandMatcher.match(v); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - - /// Adds a predicate checking that all input operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that all input operands of the structured op - /// have a projected permutation indexing map. - StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation); - - /// Adds a predicate checking that all input operands of the structured op - /// are projected along the given dimension. - StructuredOpMatcher &input(SmallVector &&positions, IsProjected dim); - StructuredOpMatcher &input(int64_t position, IsProjected dim) { - return input(SmallVector{position}, dim); - } - - /// Adds a predicate checking that all input operands of the structured op - /// have identity indexing map. - StructuredOpMatcher &input(AllOperands tag, IsIdentity); - StructuredOpMatcher &input(SmallVector &&positions, IsIdentity); - StructuredOpMatcher &input(int64_t position, IsIdentity) { - return input(SmallVector{position}, IsIdentity()); - } - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op input at the given position is equal to the given value. 
- StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width); - - /// Capture the elemental type bitwidth of input operand `position`. - StructuredOpMatcher &input(int64_t position, - CaptureElementTypeBitWidth width); - - /// Capture the elemental type of input operand `position`. - StructuredOpMatcher &input(int64_t position, CaptureElementType elem); - - /// Check if input is equal to a known constant. - // TODO: Support matching for constant ops. - StructuredOpMatcher &input(int64_t position, ConstantFloatMinOrMinusInf); - StructuredOpMatcher &input(int64_t position, ConstantFloatZero); - - //===-------------------------------------------------------------------===// - // Constraints on output operands. - //===-------------------------------------------------------------------===// - - /// Adds a predicate checking that the structured op has the given number of - /// outputs. - StructuredOpMatcher &output(NumEqualsTo num); - - /// Adds a predicate checking that all output operands of the structured op - /// have a permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsPermutation); - - /// Adds a predicate checking that all output operands of the structured op - /// have a projected permutation indexing map. - StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation); - - /// Adds a predicate checking that all output operands of the structured op - /// have a - StructuredOpMatcher &output(AllOperands tag, IsProjected dim); - - /// Adds a predicate checking that all output operands of the structured op - /// have identity indexing map. - StructuredOpMatcher &output(AllOperands tag, IsIdentity); - - /// Adds a predicate checking that the bit width of the elemental type of the - /// structured op output at the given position is equal to the given value. - StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width); - - /// Capture the elemental type bitwidth of output operand `position`. 
- StructuredOpMatcher &output(int64_t position, - CaptureElementTypeBitWidth width); - - /// Capture the elemental type of output operand `position`. - StructuredOpMatcher &output(int64_t position, CaptureElementType elem); - - /// Adds a predicate checking that the output of the structured op is produced - /// by a reduction with a single-operation combinator (such as addf or mulf, - /// but not a compare+select pair). - StructuredOpMatcher &output(int64_t position, SingleCombinerReduction tag); - - /// Adds a predicate that recursively applies other predicates to the - /// operation defining the init/out operand corresponding to `position`-th - /// output. The position may be negative, in which case positions are counted - /// from the last one (i.e. Python-style). When the match is optional, the - /// predicate check succeeds as long as the `position` is in bounds. The - /// matcher executed if there is a defining operation for the output operand. - template - std::enable_if_t::value, - StructuredOpMatcher &> - output(int64_t position, T &operandMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addOutputMatcher( - position, - [&operandMatcher](Operation *op) { return operandMatcher.match(op); }, - optional); - recordNestedMatcher(operandMatcher); - return *this; - } - - //===-------------------------------------------------------------------===// - // Constraints on results. - //===-------------------------------------------------------------------===// - - /// Adds a predicate that recursively applies to users of the `position`-th - /// result of the structured op. Succeeds if any user matches the predicate. - /// When the match is optional, the predicate check succeeds as long as the - /// `position` is in bounds, after running the given matcher. 
- template - std::enable_if_t::value, - StructuredOpMatcher &> - result(int64_t position, HasAnyUse tag, T &resultUserMatcher, - OptionalMatch optional = OptionalMatch(false)) { - addResultMatcher( - position, tag, - [&resultUserMatcher](Operation *op) { - return resultUserMatcher.match(op); - }, - optional); - recordNestedMatcher(resultUserMatcher); - return *this; - } - - //===-------------------------------------------------------------------===// - // Constraints on op region. - //===-------------------------------------------------------------------===// - - /// Return true if the linalg op only contains a single ops and the arguments - /// of the operation match the order of the linalg operand. - /// Example: - /// linalg.generic - /// ins(%0, %1 : tensor, tensor) - /// outs(%2 : tensor) { - /// ^bb0(%arg0: f32, %arg1: f32): - /// %3 = arith.maxf %arg0, %arg1 : f32 - /// linalg.yield %3 : f32 - /// } -> tensor - /// If commutative is set binary operations can have their operands swapped. - template - StructuredOpMatcher &singleOpWithCanonicaleArgs(bool commutative = false) { - return singleOpWithCanonicaleArgs(OpType::getOperationName(), commutative); - } - StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname, - bool commutative); - /// Check if the op is a linalg of with a single float reciprocal op. - StructuredOpMatcher &isFloatReciprocal(); - /// Check if the op is a linalg of with a region containing only a yield op - /// using block arguments in order. 
- StructuredOpMatcher &passThroughOp(); - - /// Check if the body of the linalg op implements a contraction of the kind - /// result = input1 input2 - template - StructuredOpMatcher &hasContractionBody() { - return hasContractionBody( - [](Operation *op) { return isa(op); }, - [](Operation *op) { return isa(op); }, - ElemOpTy::getOperationName(), ReductionOpTy::getOperationName()); - } - -private: - /// Non-template implementations of nested predicate builders for inputs, - /// outputs and results. Should not be called directly. - void addInputMatcher(int64_t position, - std::function matcher, - OptionalMatch optional); - void addInputMatcher(int64_t position, std::function matcher, - OptionalMatch optional); - void addOutputMatcher(int64_t position, - std::function matcher, - OptionalMatch optional); - void addResultMatcher(int64_t position, HasAnyUse tag, - std::function matcher, - OptionalMatch optional); - - // Common util for constant matcher. - StructuredOpMatcher &input(int64_t position, - std::function floatValueFn); - - /// Non-template implementation of hasContractionBody. Takes callbacks for - /// checking operation kinds and names for error reporting. - StructuredOpMatcher & - hasContractionBody(function_ref isaElemOpTy, - function_ref isaReductionOpTy, - StringRef elemOpName, StringRef reductionOpName); -}; - -/// Creates a matcher of an arbitrary structured op. -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { - return matcherContext.allocate(); -} - -/// Creates a matcher that is a copy of the given matcher. -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext, - const StructuredOpMatcher &other) { - return matcherContext.allocate(other); -} - -/// Creates a matcher that accepts as disjunction of the two given matchers. 
-inline StructuredOpMatcher &m_StructuredOp_Or(MatcherContext &matcherContext, - StructuredOpMatcher &A, - StructuredOpMatcher &B) { - return matcherContext.allocate(A, B); -} - -/// Creates a matcher of a structured op with kinds provided as template -/// arguments. -template -inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) { - return matcherContext.allocate( - StructuredOpMatcher::create()); -} - -//===---------------------------------------------------------------------===// -// MatchCallback functionality. -//===---------------------------------------------------------------------===// - -/// Additional results of the C++ callback usable in the `match_callback` -/// transform operation. Conceptually, a list of lists of payload operations to -/// be associated with each result handle. -class MatchCallbackResult { -public: - /// Returns the number of lists of payload operations. - int64_t getNumPayloadGroups() const { return payloadGroupLengths.size(); } - - /// Returns the `position`-th list of payload operations. - ArrayRef getPayloadGroup(int64_t position) const; - - /// Adds a new list of payload operations to the list of lists. The new list - /// must not contain null operations. - template - int64_t addPayloadGroup(Range operations) { - int64_t originalLength = payloadOperations.size(); - assert(llvm::all_of(operations, [](Operation *op) -> bool { return op; }) && - "null operation"); - llvm::append_range(payloadOperations, operations); - payloadGroupLengths.push_back(payloadOperations.size() - originalLength); - return payloadGroupLengths.size() - 1; - } - void addPayloadGroup(ArrayRef operations) { - addPayloadGroup>(operations); - } - - /// Adds a new singleton list of payload operation to the list of lists if the - /// operation is non-null, adds an empty list otherwise. Useful for results of - /// optional matches. 
- void addPotentiallyEmptyPayloadGroup(Operation *op) { - if (!op) { - addPayloadGroup(ArrayRef()); - } else { - addPayloadGroup(ArrayRef(op)); - } - } - -private: - /// The flat list of all payload operations. `payloadGroupLengths` can be used - /// to compute the sublist that corresponds to one nested list. - // TODO: if somebody implements such a flattened vector generically, use it. - SmallVector payloadOperations; - SmallVector payloadGroupLengths; -}; - -/// A transform state extension that maintains the mapping between callback -/// names as strings usable in `match_callback` and their implementations. -class MatchCallbacksRegistry : public transform::TransformState::Extension { -public: - using MatchCallbackFn = std::function; - - /// Constructs the extension. - MatchCallbacksRegistry(transform::TransformState &state) - : transform::TransformState::Extension(state) {} - - /// Registers the given function as a callback with the given name. The name - /// must not be already present in the registry. The callback must be - /// convertible to MatchCallbackFn. - template - void registerCallback(StringRef name, Fn &&fn) { - bool succeeded = callbacks.try_emplace(name, std::forward(fn)).second; - (void)succeeded; - assert(succeeded && "adding a callback with a repeated name"); - } - - /// Returns a pointer to the implementation of the callback with the given - /// name, or null if it is not present in the registry. - const MatchCallbackFn *get(StringRef name) const { - auto iter = callbacks.find(name); - if (iter == callbacks.end()) { - return nullptr; - } - return &iter->getValue(); - } - -private: - llvm::StringMap callbacks; -}; - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. 
-//===---------------------------------------------------------------------===// - -struct MatchedReductionCaptures { - int64_t reductionRank = 0; - int64_t maybeLeadingRank = 0; - int64_t maybeTrailingRank = 0; - SmallVector leadingOpSizes = {}; - SmallVector reductionOpSizes = {}; - SmallVector trailingOpSizes = {}; - int64_t reductionOutputElementalTypeBitWidth = 0; - int64_t maybeLeadingOutputElementalTypeBitWidth = 0; - int64_t maybeTrailingOutputElementalTypeBitWidth = 0; -}; - -struct MatchedMatmulCaptures { - linalg::ContractionDimensions contractionDims = {}; - Type lhsElementType, rhsElementType, outputElementType; - SmallVector matmulOpSizes = {}; - SmallVector indexingMaps; - - /// Helper functions. - int64_t rank() const { return matmulOpSizes.size(); } - /// Return all batches. - ArrayRef batches() const { return contractionDims.batch; } - /// Return the most minor candidate dimension for `m`. - int64_t m() const { return contractionDims.m.back(); } - /// Return the most minor candidate dimension for `n`. - int64_t n() const { return contractionDims.n.back(); } - /// Return the most minor candidate dimension for `k`. - int64_t k() const { return contractionDims.k.back(); } - /// AffineMap for indexing into the LHS. - AffineMap lhsIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[0]; - } - /// AffineMap for indexing into the RHS. - AffineMap rhsIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[1]; - } - /// AffineMap for indexing into the RES. - AffineMap resIndexing() const { - assert(indexingMaps.size() == 3 && "expected 3 indexing maps"); - return indexingMaps[2]; - } -}; - -/// Creates a group of matchers for: -/// -/// trailing(reduction(leading(), fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. 
If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeReductionMatcher(MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&leadingCapture, - StructuredOpMatcher *&trailingCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc); -void makeReductionMatcher(MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc); -/// -/// trailing(matmul(*, *, fill())) -/// -/// where trailing and leading are elementwise operations whose presence is -/// optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeMatmulMatcher(MatcherContext &matcherContext, - StructuredOpMatcher *&matmulCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&trailingCapture, - MatchedMatmulCaptures &captures, - bool mustMatchEntireFunc); - -/// Create a group of matchers of batch matmul with a fill: -/// -/// batch_matmul(*, *, fill()) -/// -/// and capture various useful quantities. If `mustMatchEntireFunc` is set, the -/// matcher additionally checks if all tileable operations in the functions are -/// captured. -void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&bmmCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::MatchedMatmulCaptures &captures, - bool mustMatchEntireFunc); - -/// Create a group of matchers for a different code sequence of operations -/// matching exactly a softmax operation. 
-/// -/// %red = reduce_max(%0) -/// %sub = sub(%0, %red) -/// %exp = exp(%sub) -/// %sum = reduce_sum(%exp) -/// %mul = div(%exp, %%sum) -void makeSoftmaxMatcher(MatcherContext &context, - StructuredOpMatcher *&maxReductionCapture, - StructuredOpMatcher *&softmaxRootCapture); - -struct MatchedConvolutionCaptures { - Type inputElementType, filterElementType, outputElementType; - mlir::linalg::ConvolutionDimensions convolutionDims = {}; - SmallVector convolutionOpSizes = {}; - SmallVector trailingOpSizes = {}; - int64_t maybeTrailingOutputElementalTypeBitWidth = 0; - int64_t maybeFillElementalTypeBitWidth = 0; -}; - -/// Creates a group of matchers for: -/// -/// trailing(convolution(input, filter, fill())) -/// -/// where fill is a FillOp and trailing is an elementwise operation, both of -/// which is optional. Each matcher will capture the corresponding operation. If -/// `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makeConvolutionMatcher(MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - StructuredOpMatcher *&fillCapture, - StructuredOpMatcher *&trailingCapture, - MatchedConvolutionCaptures &captures, - bool mustMatchEntireFunc); -void makeConvolutionMatcher(MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - MatchedConvolutionCaptures &captures, - bool mustMatchEntireFunc); - -struct MatchedPadCaptures { - int64_t rank = 0; - Type elementType; - SmallVector dims = {}; -}; - -/// Create a matcher for tensor.pad(*) without leading or trailing ops atm. -/// If `mustMatchEntireFunc` is set, the matcher additionally checks if all -/// tileable operations in the functions are captured. -void makePadMatcher(MatcherContext &context, CapturingOpMatcher *&padCapture, - MatchedPadCaptures &captures, bool mustMatchEntireFunc); - -/// Wraps the given matcher callback to indicate that it must capture all -/// tilable ops in the parent function. 
Expects the callback to accept the same -/// arguments as what is expected by MatchCallbacksRegistry::register, followed -/// by a bool. -template -auto wrapAsEntireFuncMatch(Fn &&fn) { - return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - return fn(res, loc, state, handles, true); - }; -} - -/// Wraps the given matcher callback to indicate that it can match subgraphs. -/// Expects the callback to accept the same arguments as what is expected by -/// MatchCallbacksRegistry::register, followed by a bool. -template -auto wrapAsPartialMatch(Fn &&fn) { - return [fn = std::move(fn)](MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - return fn(res, loc, state, handles, false); - }; -} - -} // namespace transform_ext -} // namespace mlir - -#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_ diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt deleted file mode 100644 index 9ed9512dd331..000000000000 --- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -add_mlir_public_c_api_library(IREEDialectsCAPI - Dialects.cpp - LINK_LIBS PUBLIC - IREELinalgTransformDialect - MLIRIR - MLIRLinalgTransformOps - MLIRTransformDialect -) - -iree_dialects_target_includes(IREEDialectsCAPI) diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp deleted file mode 100644 index 3bcb0b0d0b3e..000000000000 --- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects-c/Dialects.h" - -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Pass.h" -#include "mlir/CAPI/Registration.h" -#include "mlir/CAPI/Support.h" -#include "mlir/CAPI/Utils.h" -#include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Support/LLVM.h" - -using namespace mlir; - -//===--------------------------------------------------------------------===// -// TransformDialect -//===--------------------------------------------------------------------===// - -void ireeRegisterTransformExtensions(MlirContext context) { - MLIRContext *ctx = unwrap(context); - DialectRegistry registry; - registry - .addExtensions(); - ctx->appendDialectRegistry(registry); -} diff --git a/llvm-external-projects/iree-dialects/lib/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CMakeLists.txt index 76b98c3aea72..e69de29bb2d1 100644 --- a/llvm-external-projects/iree-dialects/lib/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/CMakeLists.txt @@ -1,3 +0,0 @@ -add_subdirectory(CAPI) -add_subdirectory(Dialect) -add_subdirectory(Transforms) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt deleted file mode 100644 index 1da2860785ea..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(LinalgTransform) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt deleted file mode 100644 index f33061b2d87c..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ 
-add_subdirectory(IR) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt deleted file mode 100644 index 6eb51d339055..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -add_mlir_library(IREELinalgTransformDialect - StructuredTransformOpsExt.cpp - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - IREEDialectsTransforms - MLIRIR - - MLIRAsyncDialect - MLIRControlFlowInterfaces - MLIRLinalgDialect - MLIRPDLDialect - MLIRRewrite - MLIRTransformDialect - MLIRTransformDialectInterfaces - MLIRTransformPDLExtension - - # Transforms - MLIRAffineToStandard - MLIRAsyncTransforms - MLIRLinalgTransforms - MLIRMemRefTransforms - MLIRReconcileUnrealizedCasts - MLIRTensorTransformOps - MLIRTransforms - MLIRVectorToSCF - - # Conversions - MLIRAsyncToLLVM - MLIRIndexToLLVM - MLIRMathToLLVM - MLIRMemRefToLLVM - MLIRSCFToControlFlow - MLIRVectorToLLVM - MLIRVectorToLLVMPass -) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp deleted file mode 100644 index 6f9ae25b15fc..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp +++ /dev/null @@ -1,999 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" - -#include "iree-dialects/Transforms/TransformMatchers.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" - 
-#define DEBUG_TYPE "transform-ops-ext" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// Additional constraints for PDLMatchOp. -//===----------------------------------------------------------------------===// - -/// Hook for PDL driver to check if an operation (`pdlValues[0]`) is directly -/// nested in a function with the name provided by an attribute -/// (`pdlValues[1]`). -/// TODO: PDL needs user-defined "questions". -static LogicalResult nestedInFunc(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - Operation *operation = pdlValues[0].cast(); - Attribute attr = pdlValues[1].cast(); - - auto func = operation->getParentOfType(); - if (!func) { - return rewriter.notifyMatchFailure(operation, "not nested in a function"); - } - auto functionSymbol = dyn_cast(attr); - if (!functionSymbol) { - return rewriter.notifyMatchFailure(operation, "not a function identifier"); - } - return success(functionSymbol.getLeafReference() == func.getName()); -} - -/// Construct a IRMapping from `linalgOp` to `genericLinalgModelOp`. -/// Walk both ops and check whether all subops are the same. 
-static LogicalResult -haveIdenticalBodiesImpl(linalg::LinalgOp linalgOp, - linalg::LinalgOp genericLinalgModelOp) { - IRMapping bvm; - bvm.map(linalgOp.getBlock()->getArguments(), - genericLinalgModelOp.getBlock()->getArguments()); - SmallVector linalgBodyOps; - linalgOp.getBlock()->walk( - [&](Operation *op) { linalgBodyOps.push_back(op); }); - - unsigned idx = 0; - WalkResult res = genericLinalgModelOp.getBlock()->walk([&](Operation *op) { - Operation *linalgSubOp = linalgBodyOps[idx++]; - if (op->getName() != linalgSubOp->getName()) { - return WalkResult::interrupt(); - } - if (op->getAttrs() != linalgSubOp->getAttrs()) { - return WalkResult::interrupt(); - } - for (auto it : llvm::zip(op->getOperands(), linalgSubOp->getOperands())) { - if (std::get<0>(it) != bvm.lookupOrNull(std::get<1>(it))) { - return WalkResult::interrupt(); - } - } - bvm.map(linalgSubOp->getResults(), op->getResults()); - return WalkResult::advance(); - }); - - return success(!res.wasInterrupted()); -} - -/// Dispatch body equivalence check depending on case. -static LogicalResult haveEquivalentBodies(linalg::LinalgOp linalgOp, - linalg::LinalgOp genericLinalgModelOp, - PatternRewriter &rewriter) { - if (succeeded(haveIdenticalBodiesImpl(linalgOp, genericLinalgModelOp))) { - return success(); - } - // TODO: haveEquivalentBodiesImpl, see e.g. - // https://gist.github.com/nicolasvasilache/39e89e18c46e02335c16db6ec20a07e3 - return failure(); -} - -/// Succeed when `linalgOp` and `linalgModelOp` are deemed equivalent. -static LogicalResult isEquivalentToOpImpl(PatternRewriter &rewriter, - linalg::LinalgOp linalgOp, - linalg::LinalgOp linalgModelOp) { - // If basic properties do not match, return failure. 
- { - SmallVector opInputs = linalgOp.getDpsInputs(); - SmallVector modelInputs = linalgModelOp.getDpsInputs(); - ValueRange opOutputs = linalgOp.getDpsInits(); - ValueRange modelOutputs = linalgModelOp.getDpsInits(); - auto notEqualFn = [](std::tuple in) -> bool { - return std::get<0>(in) != std::get<1>(in); - }; - - if (opInputs.size() != modelInputs.size() || - opOutputs.size() != modelOutputs.size() || - llvm::any_of(llvm::zip(opInputs, modelInputs), notEqualFn) || - llvm::any_of(llvm::zip(opOutputs, modelOutputs), notEqualFn) || - linalgOp.getIndexingMaps() != linalgModelOp.getIndexingMaps() || - linalgOp.getIteratorTypesArray() != - linalgModelOp.getIteratorTypesArray()) { - return failure(); - } - } - - // Build the block and go perform a body comparison. - { - // createBlock moves the insertion point, scope it in an RAII block. - OpBuilder::InsertionGuard guard(rewriter); - Region &r = linalgModelOp->getRegion(0); - Block *bodyBlock = rewriter.createBlock( - &r, r.end(), linalgOp.getBlock()->getArgumentTypes(), - llvm::map_to_vector<4>(linalgOp.getBlock()->getArguments(), - [](Value v) { return v.getLoc(); })); - ImplicitLocOpBuilder b(linalgModelOp.getLoc(), rewriter); - auto regionBuilder = linalgModelOp.getRegionBuilder(); - llvm::ArrayRef attrs = {}; - regionBuilder(b, *bodyBlock, attrs, /*emitError=*/{}); - } - - return haveEquivalentBodies(linalgOp, linalgModelOp, rewriter); -} - -/// Check whether the unique Operation* stored in `pdlValues[0]` (assumed) is -/// equivalent to the unique StringRefAttr passed in `pdlValues[1]` (assumed). -/// Equivalence is achieved when either: -/// 1. `pdlValues[0]` has the name stored in `pdlValues[1]`. -/// 2. `pdlValues[0]` and `pdlValues[1]` are both linalg ops and their -/// structured interfaces as well as their bodies are equivalent. -/// Structured interfaces equivalence is a simple attribute level check. -/// Body equivalence is more involved and currently limited: -/// a. 
the current impl constructs an instance of the op whose name is -/// specified in `pdlValues[1]` and checks for exact body equality. -/// b. a more advanced version would "subtract" the bodies and fold, cse -/// and canonicalize to fixed point. If the result is "all zeros", -/// then the bodies would be equivalent (really isomorphic). -/// 3. other cases TBD (e.g. vector.generic when available). -static LogicalResult isEquivalentToOp(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - Operation *operation = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto modelOpNameAttr = dyn_cast(attribute); - if (!modelOpNameAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - auto modelOpName = modelOpNameAttr.strref(); - - // 1. If op has name `modelOpName`, the match is trivial. - if (operation->getName().getStringRef() == modelOpName) { - return success(); - } - - // 2. Linalg vs Linalg. - // Create op from `modelOpName`. - OperationState modelOpState( - operation->getLoc(), modelOpName, operation->getOperands(), - operation->getResultTypes(), operation->getAttrs()); - modelOpState.addRegion(); - Operation *modelOp = rewriter.create(modelOpState); - auto g1 = llvm::scope_exit([&]() { rewriter.eraseOp(modelOp); }); - linalg::LinalgOp linalgOp = dyn_cast(operation); - linalg::LinalgOp linalgModelOp = dyn_cast(modelOp); - if (linalgOp && linalgModelOp) { - return isEquivalentToOpImpl(rewriter, linalgOp, linalgModelOp); - } - - // 3. TBD - return failure(); -} - -/// Assume that: -/// 1. `pdlValues[0]` is an operands range -/// 2. `pdlValues[1]` contains a DictAttr with `operand_number`, `dim` and -/// `divisor` IntegerAttr entries. -/// Succeed if `operands`[`operand_number`] is a ranked type whose `dim` is a -/// multiple of `divisor`. 
-/// Note: 0 is the convention to express "do not tile", it is considered to -/// divide everything. -static LogicalResult isDimMultipleOf(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - ValueRange operands = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto dict = dyn_cast(attribute); - if (!dict) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - dim = dimAttr.getInt(); - - int64_t divisor; - auto divisorAttr = dict.getAs("divisor"); - if (!divisorAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - divisor = divisorAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) { - shapedType = dyn_cast(operands[operandNumber].getType()); - } - if (!shapedType || shapedType.getRank() <= dim) { - return failure(); - } - return success(divisor == 0 || (shapedType.getShape()[dim] > 0 && - shapedType.getShape()[dim] % divisor == 0)); -} - -/// Assume that: -/// 1. `pdlValues[0]` is an operands range -/// 2. `pdlValues[1]` contains a DictAttr with `operand_number` and `dim` -/// IntegerAttr entries. -/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is -/// dynamic. 
-static LogicalResult isDimStatic(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - ValueRange operands = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto dict = dyn_cast(attribute); - if (!dict) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - dim = dimAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) { - shapedType = dyn_cast(operands[operandNumber].getType()); - } - return success(shapedType && !shapedType.isDynamicDim(dim)); -} - -/// Assume that: -/// 1. `pdlValues[0]` is an operands range -/// 2. `pdlValues[1]` contains a DictAttr with `operand_number` and `dim` -/// IntegerAttr entries. -/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is -/// dynamic. -static LogicalResult isDimDynamic(PatternRewriter &rewriter, - PDLResultList &pdlResults, - ArrayRef pdlValues) { - assert(pdlValues.size() == 2 && "expected 2 PDL values"); - ValueRange operands = pdlValues[0].cast(); - Attribute attribute = pdlValues[1].cast(); - - auto dict = dyn_cast(attribute); - if (!dict) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - - int64_t dim; - auto dimAttr = dict.getAs("dim"); - if (!dimAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. 
- } - dim = dimAttr.getInt(); - - int64_t operandNumber; - auto operandNumberAttr = dict.getAs("operand_number"); - if (!operandNumberAttr) { - return failure(); // TODO: notifyMatchFailure needs an Operation* handle. - } - operandNumber = operandNumberAttr.getInt(); - - ShapedType shapedType; - if (static_cast(operands.size()) > operandNumber) { - shapedType = dyn_cast(operands[operandNumber].getType()); - } - return success(shapedType && shapedType.isDynamicDim(dim)); -} - -//===----------------------------------------------------------------------===// -// StructuredTransformOpsExtension -//===----------------------------------------------------------------------===// - -mlir::transform_ext::StructuredTransformOpsExtension:: - StructuredTransformOpsExtension() { - registerTransformOps< -#define GET_OP_LIST -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc" - >(); - - addDialectDataInitializer( - [&](transform::PDLMatchHooks &hooks) { - llvm::StringMap constraints; - constraints.try_emplace("nestedInFunc", nestedInFunc); - constraints.try_emplace("isDimDynamic", isDimDynamic); - constraints.try_emplace("isDimMultipleOf", isDimMultipleOf); - constraints.try_emplace("isDimStatic", isDimStatic); - constraints.try_emplace("isEquivalentToOp", isEquivalentToOp); - hooks.mergeInPDLMatchHooks(std::move(constraints)); - }); - - declareDependentDialect(); - declareDependentDialect(); -} - -#define GET_OP_CLASSES -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc" - -//===----------------------------------------------------------------------===// -// ErrorCheckingTrackingListener -//===----------------------------------------------------------------------===// - -void ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( - Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { - // Certain ops can dropped safely. 
- if (isa(op)) { - LLVM_DEBUG(DBGS() << "Silently dropping scf.for op mapping\n"); - return; - } - - SmallVector diags; - diag.takeDiagnostics(diags); - if (!status.succeeded()) { - status.takeDiagnostics(diags); - } - status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); - - status = emitSilenceableFailure( - getTransformOp(), "!!! tracking listener failed to find replacement op"); - status.attachNote(op->getLoc()) << "replaced op"; - for (Value v : values) { - status.attachNote(v.getLoc()) << "replacement value"; - } -} - -//===---------------------------------------------------------------------===// -// MatchCallbackOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::MatchCallbackOp::apply( - mlir::transform::TransformRewriter &rewriter, - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - auto setEmptyResults = [&results, this] { - for (OpResult value : getResults()) { - results.set(value, {}); - } - }; - auto errorOut = [this, &setEmptyResults] { - setEmptyResults(); - return emitSilenceableError(); - }; - - auto *registry = state.getExtension(); - if (!registry) { - return errorOut() << "match registry not available"; - } - - const transform_ext::MatchCallbacksRegistry::MatchCallbackFn *callback = - registry->get(getCallbackName()); - if (!callback) { - return errorOut() << "callback '" << getCallbackName() - << "' not found in the registry"; - } - - MatchCallbackResult result; - DiagnosedSilenceableFailure status = - (*callback)(result, getLoc(), state, getInputs()); - if (!status.succeeded()) { - setEmptyResults(); - if (status.isDefiniteFailure()) { - return status; - } - if (getFailurePropagationMode() == - mlir::transform::FailurePropagationMode::Propagate) { - return emitSilenceableError() << "failed to match"; - } else { - return DiagnosedSilenceableFailure::success(); - } - } - if (getNumResults() != 
result.getNumPayloadGroups()) { - return errorOut() - << "callback produced a different number of handles than expected ( " - << result.getNumPayloadGroups() << " vs " << getNumResults() << " )"; - } - - for (OpResult value : getResults()) { - results.set(value, result.getPayloadGroup(value.getResultNumber())); - } - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::MatchCallbackOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getInputsMutable(), effects); - mlir::transform::producesHandle(getOutputs(), effects); - // TODO: it doesn't really modify the payload, we need a separate resource for - // this mapping. - mlir::transform::modifiesPayload(effects); -} - -//===---------------------------------------------------------------------===// -// Callbacks for tests driven by RegisterMatchCallbacksOp -//===---------------------------------------------------------------------===// - -/// Match callback for "_test_match_callback" hook. Matches any payload -/// operations associated with operand handles unless they have the -/// "test.iree_transform_do_not_match" attribute, in which case produces a -/// silenceable failure. 
-static DiagnosedSilenceableFailure -testMatchCallbackCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - bool hadFailures = false; - for (Value handle : handles) { - if (llvm::any_of(state.getPayloadOps(handle), [](Operation *op) { - return op->hasAttr("test.iree_transform_do_not_match"); - })) { - res.addPayloadGroup(ArrayRef()); - hadFailures = true; - } else { - res.addPayloadGroup(state.getPayloadOps(handle)); - } - } - if (hadFailures) { - return emitSilenceableFailure(loc) << "failed to match"; - } - return DiagnosedSilenceableFailure::success(); -} - -static DiagnosedSilenceableFailure testRepeatedMatcherUseCallback( - transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - transform_ext::MatcherContext matcherContext; - auto &operand = transform_ext::m_StructuredOp(matcherContext); - auto &first = transform_ext::m_StructuredOp(matcherContext).input(0, operand); - auto &second = transform_ext::m_StructuredOp(matcherContext) - .input(0, operand) - .input(1, first); - - WalkResult walkResult = root->walk([&](Operation *op) { - second.resetCapture(); - if (!matchPattern(op, second)) { - return WalkResult::advance(); - } - - res.addPayloadGroup({first.getCaptured()}); - res.addPayloadGroup({second.getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -static DiagnosedSilenceableFailure -testValueMatcherCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - 
ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - transform_ext::MatcherContext matcherContext; - auto &operand = transform_ext::m_Value(matcherContext); - auto &first = transform_ext::m_StructuredOp(matcherContext).input(0, operand); - auto &second = transform_ext::m_StructuredOp(matcherContext) - .input(0, operand) - .input(1, first); - - WalkResult walkResult = root->walk([&](Operation *op) { - second.resetCapture(); - if (!matchPattern(op, second)) { - return WalkResult::advance(); - } - - res.addPayloadGroup({first.getCaptured()}); - res.addPayloadGroup({second.getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -static DiagnosedSilenceableFailure testShapedValueMatcherCallback( - transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - int64_t rank; - SmallVector dims; - transform_ext::MatcherContext matcherContext; - auto &value = transform_ext::m_ShapedValue(matcherContext); - value.rank(transform_ext::CaptureRank(rank)) - .dim(transform_ext::AllDims(), transform_ext::CaptureDims(dims)); - auto &opMatcher = - transform_ext::m_Operation(matcherContext); - opMatcher.result(0, value); - - WalkResult walkResult = root->walk([&](Operation *op) { - opMatcher.resetCapture(); - if (!matchPattern(op, opMatcher)) { - return WalkResult::advance(); - } - - op->emitRemark() << "rank: " << rank; - 
std::string message; - llvm::raw_string_ostream os(message); - llvm::interleaveComma(dims, os); - os.flush(); - op->emitRemark() << "dimensions: " << message; - - res.addPayloadGroup({opMatcher.getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -//===---------------------------------------------------------------------===// -// Callbacks for codegen driven by RegisterMatchCallbacksOp. -//===---------------------------------------------------------------------===// - -/// Match callback for a convolution with optional fill and trailing -/// elementwise operations. Matches *the first* occurrence of such a convolution -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the "fill" op preceding the convolution, if present; -/// - convolution op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -convolutionCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill, *trailing; - transform_ext::MatchedConvolutionCaptures ignore; - transform_ext::MatcherContext matcherContext; - makeConvolutionMatcher(matcherContext, pattern, fill, trailing, ignore, - /*mustMatchEntireFunc=*/true); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. 
- Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. - LLVM_DEBUG({ - DBGS() << "fill:\n"; - if (fill->getCaptured()) { - DBGS() << fill->getCaptured() << "\n"; - } - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - DBGS() << "trailing:\n"; - if (trailing->getCaptured()) { - DBGS() << trailing->getCaptured() << "\n"; - } - }); - - res.addPotentiallyEmptyPayloadGroup(fill->getCaptured()); - res.addPayloadGroup({pattern->getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for a reduction with optional leading and trailing -/// elementwise operations. Matches *the first* occurrence of such a reduction -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - leading elementwise op, if any; -/// - the "fill" op preceding the reduction; -/// - reduction op; -/// - trailing elementwise op, if any. 
-static DiagnosedSilenceableFailure -reductionCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles, bool mustMatchEntireFunc) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill, *leading, *trailing; - transform_ext::MatchedReductionCaptures ignore; - transform_ext::MatcherContext matcherContext; - makeReductionMatcher(matcherContext, pattern, fill, leading, trailing, ignore, - mustMatchEntireFunc); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. - LLVM_DEBUG({ - DBGS() << "leading:\n"; - if (leading->getCaptured()) { - DBGS() << leading->getCaptured() << "\n"; - } - DBGS() << "fill: " << fill->getCaptured() << "\n"; - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - DBGS() << "trailing:\n"; - if (trailing->getCaptured()) { - DBGS() << trailing->getCaptured() << "\n"; - } - }); - - res.addPotentiallyEmptyPayloadGroup(leading->getCaptured()); - res.addPayloadGroup({fill->getCaptured()}); - res.addPayloadGroup({pattern->getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for a matmul with fill and optional trailing -/// elementwise operations. 
Matches *the first* occurrence of such a convolution -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the "fill" op preceding the convolution, if present; -/// - convolution op; -/// - trailing elementwise op, if any. -static DiagnosedSilenceableFailure -matmulCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill, *trailing; - transform_ext::MatchedMatmulCaptures ignore; - transform_ext::MatcherContext matcherContext; - makeMatmulMatcher(matcherContext, pattern, fill, trailing, ignore, - /*mustMatchEntireFunc=*/true); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. 
- LLVM_DEBUG({ - DBGS() << "fill:\n"; - if (fill->getCaptured()) { - DBGS() << fill->getCaptured() << "\n"; - } - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - DBGS() << "trailing:\n"; - if (trailing->getCaptured()) { - DBGS() << trailing->getCaptured() << "\n"; - } - }); - - res.addPayloadGroup({fill->getCaptured()}); - res.addPayloadGroup({pattern->getCaptured()}); - res.addPotentiallyEmptyPayloadGroup(trailing->getCaptured()); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -/// Match callback for linalg.batch_matmul and its linalg.generic equivalent fed -/// by a linalg.fill. -/// -/// Input handles: -/// -/// - the container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the fill op initializing the output; -/// - the main compute op. -static DiagnosedSilenceableFailure -batchMatmulCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, - ValueRange handles) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::StructuredOpMatcher *pattern, *fill; - transform_ext::MatchedMatmulCaptures ignore; - transform_ext::MatcherContext matcherContext; - transform_ext::makeBatchMatmulMatcher(matcherContext, pattern, fill, ignore, - /*mustMatchEntireFunc*/ true); - - // TODO: need a mechanism for this to go around the entire IR, - // potentially with list matches for each group. 
- Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly - LLVM_DEBUG({ - DBGS() << "fill:" << fill->getCaptured() << "\n"; - DBGS() << "pattern: " << pattern->getCaptured() << "\n"; - }); - - res.addPayloadGroup({fill->getCaptured()}); - res.addPayloadGroup({pattern->getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match batch matmul"; -} - -/// Match callback for a tensor.pad. Matches *the first* occurrence of such pad -/// within an op associated with the given handle. -/// -/// Input handles: -/// -/// - the container op, must be associated with one operation. -/// -/// Output handles: -/// -/// - the pad op. -static DiagnosedSilenceableFailure -padCallback(transform_ext::MatchCallbackResult &res, Location loc, - const mlir::transform::TransformState &state, ValueRange handles, - bool mustMatchEntireFunc) { - if (handles.size() != 1 || - !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) { - return emitSilenceableFailure(loc) - << "expected one handle to one operation"; - } - - transform_ext::CapturingOpMatcher *pattern; - transform_ext::MatchedPadCaptures ignore; - transform_ext::MatcherContext matcherContext; - makePadMatcher(matcherContext, pattern, ignore, mustMatchEntireFunc); - - Operation *root = *state.getPayloadOps(handles[0]).begin(); - - WalkResult walkResult = root->walk([&](Operation *op) { - pattern->resetCapture(); - if (!matchPattern(op, *pattern)) { - return WalkResult::advance(); - } - - // TODO: notify properly. 
- LLVM_DEBUG({ - DBGS() << "pad:\n"; - if (pattern->getCaptured()) { - DBGS() << pattern->getCaptured() << "\n"; - } - }); - - res.addPayloadGroup({pattern->getCaptured()}); - return WalkResult::interrupt(); - }); - - if (walkResult.wasInterrupted()) { - return DiagnosedSilenceableFailure::success(); - } - return emitSilenceableFailure(loc) << "failed to match"; -} - -//===---------------------------------------------------------------------===// -// RegisterMatchCallbacksOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::RegisterMatchCallbacksOp::apply( - mlir::transform::TransformRewriter &rewriter, - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - auto ®istry = state.addExtension(); - registry.registerCallback("_test_match_callback", testMatchCallbackCallback); - registry.registerCallback("_test_repeated_matcher_use_callback", - testRepeatedMatcherUseCallback); - registry.registerCallback("_test_value_matcher_callback", - testValueMatcherCallback); - registry.registerCallback("_test_shaped_value_matcher_callback", - testShapedValueMatcherCallback); - registry.registerCallback("convolution", convolutionCallback); - registry.registerCallback("matmul", matmulCallback); - registry.registerCallback("batch_matmul", batchMatmulCallback); - registry.registerCallback("pad", wrapAsEntireFuncMatch(padCallback)); - registry.registerCallback("reduction", - wrapAsEntireFuncMatch(reductionCallback)); - registry.registerCallback("reduction_partial", - wrapAsPartialMatch(reductionCallback)); - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::RegisterMatchCallbacksOp::getEffects( - SmallVectorImpl &effects) { - // TODO: it doesn't really modify the payload, we need a separate resource for - // this mapping. 
- mlir::transform::modifiesPayload(effects); -} - -//===---------------------------------------------------------------------===// -// TakeFirstOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform_ext::TakeFirstOp::apply(mlir::transform::TransformRewriter &rewriter, - mlir::transform::TransformResults &results, - mlir::transform::TransformState &state) { - SmallVector concatenated; - bool found = false; - for (Value handle : getInputs()) { - auto payloads = state.getPayloadOps(handle); - if (payloads.empty()) { - continue; - } - if (!found) { - results.set(cast(getFirst()), payloads); - found = true; - } else { - llvm::append_range(concatenated, payloads); - } - } - - if (!found) { - results.set(cast(getFirst()), {}); - } - results.set(cast(getRest()), concatenated); - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::TakeFirstOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getInputsMutable(), effects); - mlir::transform::producesHandle(getOperation()->getOpResults(), effects); -} - -//===---------------------------------------------------------------------===// -// EmitRemarkOp -//===---------------------------------------------------------------------===// - -DiagnosedSilenceableFailure transform_ext::EmitRemarkOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, - mlir::transform::ApplyToEachResultList &results, - mlir::transform::TransformState &state) { - for (Operation *payload : state.getPayloadOps(getHandle())) { - payload->emitRemark(getMessage()); - } - return DiagnosedSilenceableFailure::success(); -} - -void transform_ext::EmitRemarkOp::getEffects( - SmallVectorImpl &effects) { - mlir::transform::onlyReadsHandle(getHandleMutable(), effects); - mlir::transform::onlyReadsPayload(effects); -} diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt 
b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt deleted file mode 100644 index da2cbd97ae82..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ - -add_mlir_library(IREEDialectsTransforms - TransformMatchers.cpp - - LINK_LIBS PRIVATE - # TODO: break dialect dependency by implementing the transformation separately - # and registering it. - MLIRArithDialect - MLIRAsyncDialect - MLIRFuncDialect - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRMathDialect - - DEPENDS - mlir-headers -) - -iree_dialects_target_includes(IREEDialectsTransforms) diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp deleted file mode 100644 index a90ff9b2c32f..000000000000 --- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp +++ /dev/null @@ -1,1845 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-dialects/Transforms/TransformMatchers.h" - -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" - -using namespace mlir; - -#define DEBUG_TYPE "transform-matchers" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " -#define DBGSNL() llvm::dbgs() << "\n[" DEBUG_TYPE "] " - -//===---------------------------------------------------------------------===// -// CapturingMatcherBase -//===---------------------------------------------------------------------===// - -void transform_ext::CapturingMatcherBase::getAllNested( - SmallVectorImpl &nested) { - - SetVector found; - found.insert(nested.begin(), nested.end()); - int64_t start = found.size(); - - auto appendOne = [&found](CapturingMatcherBase &one) { - found.insert(one.nestedCapturingMatchers.begin(), - one.nestedCapturingMatchers.end()); - for (CapturingValueMatcher *valueMatcher : - one.nestedCapturingValueMatchers) { - found.insert(valueMatcher->nestedCapturingMatchers.begin(), - valueMatcher->nestedCapturingMatchers.end()); - } - }; - - appendOne(*this); - for (int64_t position = start; position < found.size(); ++position) { - appendOne(*found[position]); - } - - llvm::append_range(nested, found.getArrayRef()); -} - -void transform_ext::CapturingMatcherBase::getAllNestedValueMatchers( - SmallVectorImpl &nested) { - - SetVector found; - found.insert(nested.begin(), nested.end()); - int64_t start = found.size(); - - auto appendOne = [&found](CapturingMatcherBase &one) { - found.insert(one.nestedCapturingValueMatchers.begin(), - one.nestedCapturingValueMatchers.end()); - for (CapturingOpMatcher *opMatcher : 
one.nestedCapturingMatchers) { - found.insert(opMatcher->nestedCapturingValueMatchers.begin(), - opMatcher->nestedCapturingValueMatchers.end()); - } - }; - - appendOne(*this); - for (int64_t position = start; position < found.size(); ++position) { - appendOne(*found[position]); - } - - llvm::append_range(nested, found.getArrayRef()); -} - -void transform_ext::CapturingMatcherBase::resetCapture() { - SmallVector nested; - getAllNested(nested); - for (CapturingOpMatcher *matcher : nested) { - matcher->captured = nullptr; - } - SmallVector nestedValue; - getAllNestedValueMatchers(nestedValue); - for (CapturingValueMatcher *matcher : nestedValue) { - matcher->captured = nullptr; - } -} - -//===---------------------------------------------------------------------===// -// CapturingOpMatcher -//===---------------------------------------------------------------------===// - -bool transform_ext::CapturingOpMatcher::checkAllTilableMatched( - Operation *parent, Operation *op, - ArrayRef matchers) { - LLVM_DEBUG(DBGS() << "all tilable ops captured"); - int64_t numTilableOps = 0; - if (!parent) { - return false; - } - parent->walk([&](TilingInterface Op) { ++numTilableOps; }); - - llvm::SmallPtrSet matched; - for (CapturingOpMatcher *nested : matchers) { - if (Operation *captured = nested->getCaptured()) { - matched.insert(captured); - } - } - - // Don't forget to include the root matcher. 
- matched.insert(op); - return numTilableOps == matched.size(); -} - -bool transform_ext::CapturingOpMatcher::match(Operation *op) { - auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); - LLVM_DEBUG(DBGS() << "matching: " << *op << "\n"); - - if (getCaptured()) { - LLVM_DEBUG(DBGS() << "found an already captured op: "); - if (getCaptured() == op) { - LLVM_DEBUG(llvm::dbgs() << "same\n"); - return true; - } else { - LLVM_DEBUG(llvm::dbgs() << "different\n"); - return false; - } - } - - if (!llvm::all_of(predicates, [op](const PredicateFn &fn) { - bool result = fn(op); - LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); - return result; - })) { - return false; - } - - captured = op; - return true; -} - -void transform_ext::CapturingOpMatcher::debugOutputForCreate( - ArrayRef opNames) { - LLVM_DEBUG(DBGS() << "operation type is one of {"; - llvm::interleaveComma(opNames, llvm::dbgs()); llvm::dbgs() << "}"); -} - -/// Apply the given matcher to the given object, produce debug messages. -template ::template args<0>> -static bool recursiveMatch(Matcher &matcher, Object &object, - StringRef extraMessage = "") { - LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "] " << "start recursive match (" - << extraMessage << ") {\n"); - bool result = matcher.match(object); - LLVM_DEBUG(DBGS() << "} end recursive match"); - return result; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::alternatives( - transform_ext::CapturingOpMatcher &first, - transform_ext::CapturingOpMatcher &second) { - addPredicate([&first, &second](Operation *op) { - LLVM_DEBUG(DBGS() << "matching alternatives\n"); - return recursiveMatch(first, op, "alternative 1") || - recursiveMatch(second, op, "alternative 2"); - }); - return *this; -} - -//---------------------------------------------------------------------------// -// Predicates for operands and results. 
-//---------------------------------------------------------------------------// - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(transform_ext::NumEqualsTo num) { - addPredicate([=](Operation *op) { - LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " operands"); - return num.value == op->getNumOperands(); - }); - return *this; -} - -/// If `pos` is negative, returns the number of the operand in op starting from -/// the last. For example, -1 means the last operand, -2 means the -/// second-to-last, etc. Returns nullopt if pos is out-of-bounds, both positive -/// and negative. -static std::optional remapNegativeOperandNumber(int64_t pos, - Operation *op) { - int64_t updated = pos < 0 ? op->getNumOperands() + pos : pos; - if (updated < 0 || updated >= op->getNumOperands()) { - LLVM_DEBUG(DBGS() << "match operand #" << pos - << "that does not exist in the operation"); - return std::nullopt; - } - return updated; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t pos, - CapturingOpMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - std::optional operandNo = remapNegativeOperandNumber(pos, op); - if (!operandNo) { - return false; - } - LLVM_DEBUG(DBGS() << "operand #" << pos << " is defined by an operation"); - Operation *definingOp = op->getOperand(*operandNo).getDefiningOp(); - if (!definingOp) { - return false; - } - return recursiveMatch(nested, definingOp); - }); - recordNestedMatcher(nested); - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t pos, - CapturingValueMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - std::optional operandNo = remapNegativeOperandNumber(pos, op); - if (!operandNo) { - return false; - } - LLVM_DEBUG(DBGS() << "operand #" << pos << " is"); - Value operand = op->getOperand(*operandNo); - return recursiveMatch(nested, operand); - }); - 
recordNestedMatcher(nested); - return *this; -} - -transform_ext::CapturingOpMatcher &transform_ext::CapturingOpMatcher::operand( - int64_t position, std::function floatValueFn) { - addPredicate([position, - floatValueFn = std::move(floatValueFn)](Operation *op) -> bool { - std::optional operandNo = remapNegativeOperandNumber(position, op); - if (!operandNo) { - return false; - } - - LLVM_DEBUG(DBGS() << "operand #" << *operandNo - << " is a special floating point constant"); - auto cstOp = - op->getOperand(*operandNo).getDefiningOp(); - if (!cstOp) { - return false; - } - return floatValueFn(cstOp.value()); - }); - - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::operand(int64_t position, ConstantFloatOne) { - return operand(position, - [](llvm::APFloat value) { return value.isExactlyValue(1.0); }); -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::result(transform_ext::NumEqualsTo num) { - addPredicate([=](Operation *op) { - LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " results"); - return num.value == op->getNumResults(); - }); - return *this; -} - -transform_ext::CapturingOpMatcher & -transform_ext::CapturingOpMatcher::result(int64_t pos, - CapturingValueMatcher &nested) { - addPredicate([pos, &nested](Operation *op) { - int64_t updated = pos < 0 ? 
op->getNumResults() + pos : pos; - if (updated < 0 || updated >= op->getNumResults()) { - LLVM_DEBUG(DBGS() << "matching result #" << pos - << " that does not exist in the operation"); - return false; - } - LLVM_DEBUG(DBGS() << "result #" << pos << " is"); - Value result = op->getResult(updated); - return recursiveMatch(nested, result); - }); - recordNestedMatcher(nested); - return *this; -} - -//===---------------------------------------------------------------------===// -// CapturingValueMatcher -//===---------------------------------------------------------------------===// - -namespace { -struct DebugPrintValueWrapper { - Value value; -}; - -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const DebugPrintValueWrapper &wrapper) { - if (auto opResult = dyn_cast(wrapper.value)) { - return os << "op result #" << opResult.getResultNumber() << " in " - << wrapper.value; - } - - auto blockArg = cast(wrapper.value); - os << "block argument #" << blockArg.getArgNumber(); - Block *parentBlock = blockArg.getParentBlock(); - Region *parentRegion = parentBlock->getParent(); - if (!parentRegion) { - os << " of a detached block:\n"; - parentBlock->print(os); - return os; - } - - os << " of block #" - << std::distance(parentRegion->begin(), parentBlock->getIterator()); - Operation *parentOp = parentRegion->getParentOp(); - if (!parentOp) { - os << " of a detached region:\n"; - for (Block &b : *parentRegion) { - b.print(os); - } - return os; - } - - os << " in region #" << parentRegion->getRegionNumber() << " of " - << *parentOp; - return os; -} -} // namespace - -bool transform_ext::CapturingValueMatcher::match(Value value) { - auto debugRAII = llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); - LLVM_DEBUG(DBGS() << "matching " << DebugPrintValueWrapper{value} << "\n"); - - if (getCaptured()) { - LLVM_DEBUG(DBGS() << "found an already captured value: "); - if (getCaptured() == value) { - LLVM_DEBUG(llvm::dbgs() << "same\n"); - return true; - } else { - 
LLVM_DEBUG(llvm::dbgs() << "different\n"); - return false; - } - } - - for (const PredicateFn &fn : predicates) { - bool result = fn(value); - LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); - if (!result) { - return false; - } - } - - captured = value; - return true; -} - -transform_ext::ShapedValueMatcher::ShapedValueMatcher() - : CapturingValueMatcher() { - addPredicate([](Value value) { - LLVM_DEBUG(DBGS() << "value is of shaped type"); - return value && isa(value.getType()); - }); -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::rank(transform_ext::CaptureRank capture) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing shaped value rank"); - capture.value = cast(value.getType()).getRank(); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::dim(int64_t dimension, CaptureDim capture) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing shaped value dimension " << dimension); - capture.value = cast(value.getType()).getDimSize(dimension); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::dim(AllDims tag, CaptureDims captures) { - (void)tag; - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing all shaped value dimensions"); - ArrayRef shape = cast(value.getType()).getShape(); - captures.value.assign(shape.begin(), shape.end()); - return true; - }); - return *this; -} - -transform_ext::ShapedValueMatcher & -transform_ext::ShapedValueMatcher::elementType(CaptureElementType captures) { - addPredicate([=](Value value) { - LLVM_DEBUG(DBGS() << "capturing elementType"); - captures.value = cast(value.getType()).getElementType(); - return true; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Constraints on op rank and dims. 
-//===---------------------------------------------------------------------===// - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank >= " << minRank.value); - return linalgOp.getNumLoops() >= minRank.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value); - return linalgOp.getNumLoops() <= maxRank.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "rank == " << exactRank.value); - return linalgOp.getNumLoops() == exactRank.value; - }); -} - -StringRef stringifyShapeKind(transform_ext::ShapeKind kind) { - switch (kind) { - case transform_ext::ShapeKind::Static: - return "static"; - case transform_ext::ShapeKind::Dynamic: - return "dynamic"; - } - llvm_unreachable("unhandled shape kind"); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, - ShapeKind kind) { - return addPredicate([dimensions = std::move(dimensions), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimensions ["; - llvm::interleaveComma(dimensions, llvm::dbgs()); - llvm::dbgs() << "] are " << stringifyShapeKind(kind)); - SmallVector shape = linalgOp.getStaticLoopRanges(); - for (auto dimension : dimensions) { - int64_t transformedDimension = - dimension >= 0 ? 
dimension : shape.size() + dimension; - if (transformedDimension < 0 || transformedDimension >= shape.size()) { - return false; - } - if (ShapedType::isDynamic(shape[transformedDimension]) ^ - (kind == ShapeKind::Static)) { - continue; - } - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind)); - SmallVector shape = linalgOp.getStaticLoopRanges(); - return llvm::all_of(shape, [=](int64_t dimension) { - return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); - }); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(SmallVector &&dimensions, - utils::IteratorType kind) { - return addPredicate([dimensions = std::move(dimensions), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimensions ["; - llvm::interleaveComma(dimensions, llvm::dbgs()); - llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); - int64_t rank = linalgOp.getNumLoops(); - for (auto dimension : dimensions) { - int64_t transformedDimension = - dimension >= 0 ? 
dimension : rank + dimension; - if (transformedDimension < 0 || transformedDimension >= rank) { - return false; - } - utils::IteratorType iteratorKind = - linalgOp.getIteratorTypesArray()[transformedDimension]; - if (iteratorKind == kind) { - continue; - } - return false; - } - return true; - }); -} -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) { - return dim(AllDimsExcept({}), kind); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims, - utils::IteratorType kind) { - return addPredicate([dimensions = std::move(dims), - kind](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all dimensions except ["; - llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs()); - llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); - int64_t rank = linalgOp.getNumLoops(); - llvm::SmallDenseSet excludedDims; - for (int64_t dim : dimensions.getExcluded()) { - excludedDims.insert(dim >= 0 ? dim : rank + dim); - } - - for (auto [index, type] : - llvm::enumerate(linalgOp.getIteratorTypesArray())) { - if (excludedDims.contains(index)) { - continue; - } - if (type == kind) { - continue; - } - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, - DivisibleBy divisibleBy) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by " - << divisibleBy.value); - int64_t rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? 
dimension : rank + dimension; - if (transformedDimension >= rank) { - return false; - } - - int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension]; - return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0); - }); -} - -//===---------------------------------------------------------------------===// -// Capture directives. -//===---------------------------------------------------------------------===// -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::rank(CaptureRank capture) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture rank"); - capture.value = linalgOp.getNumLoops(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture dimension"); - int64_t rank = linalgOp.getNumLoops(); - int64_t transformedDimension = - dimension >= 0 ? 
dimension : rank + dimension; - if (transformedDimension >= rank) { - return false; - } - - capture.value = linalgOp.getStaticLoopRanges()[transformedDimension]; - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture all dimensions"); - captures.value = linalgOp.getStaticLoopRanges(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::indexingMaps( - CaptureIndexingMaps indexingMaps) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture indexing maps"); - indexingMaps.value = linalgOp.getIndexingMapsArray(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::contractionDims( - CaptureContractionDims contractionDims) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture contraction dimensions"); - StringRef convMessage = linalg::detail::getMatchContractionMessage( - mlir::linalg::detail::isContractionInterfaceImpl( - linalgOp, &contractionDims.value)); - if (convMessage.empty()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); - return false; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "capture convolution dimensions"); - StringRef convMessage = linalg::detail::getMatchConvolutionMessage( - mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp, - &convDims.value)); - if (convMessage.empty()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); - return false; - }); -} - -transform_ext::StructuredOpMatcher::StructuredOpMatcher( - StructuredOpMatcher &A, StructuredOpMatcher &B) { 
- - addPredicate([&A, &B](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); - { - auto debugRAII = llvm::scope_exit( - [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (A.match(linalgOp)) { - return true; - } - } - LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); - { - auto debugRAII = llvm::scope_exit( - [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (B.match(linalgOp)) { - return true; - } - } - return false; - }); - recordNestedMatcher(A); - recordNestedMatcher(B); -} - -//===---------------------------------------------------------------------===// -// Constraints on input operands. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addInputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addInputMatcher( - position, - // No need to handle optional inside the lambda, the wrapper will do that. - [matcher = std::move(matcher)](Value value) { - Operation *definingOp = value.getDefiningOp(); - return definingOp && matcher(definingOp); - }, - optional); -} - -void transform_ext::StructuredOpMatcher::addInputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addPredicate([position, optional, matcher = std::move(matcher)]( - linalg::LinalgOp linalgOp) -> bool { - int64_t transformedPosition = - position >= 0 ? position : linalgOp.getNumDpsInputs() + position; - if (transformedPosition >= linalgOp.getNumDpsInputs()) { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " does not exist but match required"); - return false; - } - - LLVM_DEBUG(DBGS() << "input operand #" << position - << (optional.value ? " (optional match) " : " ") - << "is\n"); - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. 
- LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (matcher(linalgOp.getDpsInputOperand(transformedPosition)->get())) { - return true; - } - return optional.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have permutation maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, - IsProjectedPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation()) { - return false; - } - } - return true; - }); -} - -/// Helper to check if the map is an identity map with a projected dim. -static bool isProjectedMap(AffineMap map, int64_t projectedDim) { - if (!map.isProjectedPermutation()) { - return false; - } - int64_t dimCounter = 0; - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - // Skip the project dim. - if (dimCounter == projectedDim) { - dimCounter++; - } - if (map.getDimPosition(i) != dimCounter++) { - return false; - } - } - return true; -} - -/// Helper to turn a potentially negative index to positive within the range -/// [0, ub) and indicate whether the transformed index is in bounds. 
-static bool makeValidPositiveIndex(int64_t &index, int64_t ub) { - int64_t positiveIndex = index >= 0 ? index : ub + index; - if (positiveIndex < 0 || ub < positiveIndex) { - LLVM_DEBUG(DBGSNL() << " index out of range"); - return false; - } - index = positiveIndex; - return true; -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(SmallVector &&positions, - IsProjected dim) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "operands "; - llvm::interleaveComma(positions, llvm::dbgs()); - llvm::dbgs() << " have a permutation maps with " << dim.value - << " projected"); - int64_t updatedDim = dim.value; - if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { - return false; - } - for (int64_t position : positions) { - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, - linalgOp.getNumDpsInputs())) { - return false; - } - OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); - if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), - updatedDim)) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all input operands have identity maps"); - // all_of with a lambda requires const-casting dance, so using a loop. 
- for (OpOperand *operand : linalgOp.getDpsInputOperands()) { - if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(SmallVector &&positions, - IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operands "; - llvm::interleaveComma(positions, llvm::dbgs()); - llvm::dbgs() << " have identity maps"); - // all_of with a lambda requires const-casting dance, so using a loop. - for (int64_t position : positions) { - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, - linalgOp.getNumDpsInputs())) { - return false; - } - OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); - if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - ElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " has elemental type with bit width " << width.value); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - CaptureElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, 
linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { - return false; - } - width.value = shapedType.getElementType().getIntOrFloatBitWidth(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - CaptureElementType elem) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " capture element type"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); - if (!shapedType) { - LLVM_DEBUG(DBGSNL() << " not a shaped type"); - return false; - } - elem.value = shapedType.getElementType(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(NumEqualsTo num) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "number of input operands == " << num.value); - return linalgOp.getNumDpsInputs() == num.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, - ConstantFloatMinOrMinusInf) { - return input(position, [](llvm::APFloat f) { - return (f.isLargest() || f.isInfinity()) && f.isNegative(); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) { - return input(position, [](llvm::APFloat f) { return f.isZero(); }); -} - -transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input( - int64_t position, std::function floatValueFn) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "input operand #" << position - << " is a 
special floating point constant"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) { - return false; - } - auto cstOp = linalgOp.getDpsInputOperand(updatedPosition) - ->get() - .getDefiningOp(); - if (!cstOp) { - return false; - } - return floatValueFn(cstOp.value()); - }); -} - -//===---------------------------------------------------------------------===// -// Constraints on output operands. -//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addOutputMatcher( - int64_t position, std::function matcher, - OptionalMatch optional) { - addPredicate([position, optional, matcher = std::move(matcher)]( - linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << (optional.value ? " (optional match) " - : " (mandatory match) ") - << "is produced by\n"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - Operation *definingOp = - linalgOp.getDpsInitOperand(updatedPosition)->get().getDefiningOp(); - if (!definingOp) { - return optional.value; - } - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. 
- LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (matcher(definingOp)) { - return true; - } - return optional.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, - IsProjectedPermutation) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isProjectedPermutation()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have a maps with projected"); - int64_t updatedDim = dim.value; - if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) { - return false; - } - // all_of with a lambda requires const-casting dance, so using a loop. 
- for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!isProjectedMap(linalgOp.getMatchingIndexingMap(&operand), - updatedDim)) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps"); - for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!linalgOp.getMatchingIndexingMap(&operand).isIdentity()) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - ElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " has elemental type with bit width " << width.value); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - return shapedType && shapedType.getElementType().isIntOrFloat() && - shapedType.getElementType().getIntOrFloatBitWidth() == width.value; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - CaptureElementTypeBitWidth width) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { - LLVM_DEBUG(DBGSNL() << " could not infer element type"); - return false; - } - width.value = 
shapedType.getElementType().getIntOrFloatBitWidth(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - CaptureElementType elem) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " capture element type"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - auto shapedType = dyn_cast( - linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); - if (!shapedType) { - LLVM_DEBUG(DBGSNL() << " not a shaped type"); - return false; - } - elem.value = shapedType.getElementType(); - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(int64_t position, - SingleCombinerReduction tag) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "output operand #" << position - << " is populated by a single-combiner reduction"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) { - return false; - } - SmallVector combinerOps; - return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, - combinerOps) && - llvm::hasSingleElement(combinerOps); - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::output(NumEqualsTo num) { - return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "number of output operands == " << num.value); - return linalgOp.getNumDpsInits() == num.value; - }); -} - -//===---------------------------------------------------------------------===// -// Constraints on results. 
-//===---------------------------------------------------------------------===// - -void transform_ext::StructuredOpMatcher::addResultMatcher( - int64_t position, HasAnyUse tag, std::function matcher, - OptionalMatch optional) { - addPredicate([matcher = std::move(matcher), optional, - position](linalg::LinalgOp linalgOp) -> bool { - LLVM_DEBUG(DBGS() << "result #" << position - << (optional.value ? " (optional match) " - : " (mandatory match) ") - << "has a use\n"); - int64_t updatedPosition = position; - if (!makeValidPositiveIndex(updatedPosition, linalgOp->getNumResults())) { - return false; - } - - // We MUST run the matcher at this point, even if the match is optional, - // to allow for capture. - LLVM_DEBUG(DBGS() << "start recursive match {\n"); - auto debugRAII = - llvm::scope_exit([] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); - if (llvm::any_of(linalgOp->getResult(updatedPosition).getUsers(), - [&matcher](Operation *op) { return matcher(op); })) { - return true; - } - return optional.value; - }); -} - -//===-------------------------------------------------------------------===// -// Constraints on op region. 
-//===-------------------------------------------------------------------===// - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs( - StringRef opcode, bool commutative) { - return addPredicate([=](linalg::LinalgOp linalgOp) { - if (linalgOp.getBlock()->getOperations().size() != 2) { - return false; - } - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (innerOp->getName().getStringRef() != opcode || - innerOp->getNumResults() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if (yieldOp->getNumOperands() != 1) { - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) { - return false; - } - if (commutative && innerOp->getNumOperands() == 2) { - auto arg0 = dyn_cast(innerOp->getOperand(0)); - auto arg1 = dyn_cast(innerOp->getOperand(1)); - if (!arg0 || !arg1) { - return false; - } - if (arg0.getParentBlock() != linalgOp.getBlock() || - arg1.getParentBlock() != linalgOp.getBlock()) { - return false; - } - if (!((arg0.getArgNumber() == 0 && arg1.getArgNumber() == 1) || - (arg1.getArgNumber() == 0 && arg0.getArgNumber() == 1))) { - return false; - } - } else { - for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) { - auto arg = dyn_cast(operand); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != index) { - return false; - } - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::isFloatReciprocal() { - return addPredicate([=](linalg::LinalgOp linalgOp) { - LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation"); - if (linalgOp.getBlock()->getOperations().size() != 2) { - return false; - } - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (!isa(innerOp) || innerOp->getNumResults() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if 
(yieldOp->getNumOperands() != 1) { - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) { - return false; - } - auto cst = innerOp->getOperand(0).getDefiningOp(); - if (!cst || cst.value().convertToDouble() != 1.0) { - return false; - } - auto arg = dyn_cast(innerOp->getOperand(1)); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != 0) { - return false; - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::passThroughOp() { - return addPredicate([=](linalg::LinalgOp linalgOp) { - if (linalgOp.getBlock()->getOperations().size() != 1) { - return false; - } - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - for (auto [index, operand] : llvm::enumerate(yieldOp->getOperands())) { - auto arg = dyn_cast(operand); - if (!arg || arg.getParentBlock() != linalgOp.getBlock() || - arg.getArgNumber() != index) { - return false; - } - } - return true; - }); -} - -transform_ext::StructuredOpMatcher & -transform_ext::StructuredOpMatcher::hasContractionBody( - function_ref isaElemOpTy, - function_ref isaReductionOpTy, StringRef elemOpName, - StringRef reductionOpName) { - return addPredicate([=](linalg::LinalgOp linalgOp) { - LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/" - << reductionOpName << " contraction ("); - auto scopeExitPrinter = - llvm::scope_exit([] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); }); - - Block *body = linalgOp.getBlock(); - if (!llvm::hasNItems(*body, 3)) { - LLVM_DEBUG(llvm::dbgs() << "three-operation body"); - return false; - } - if (body->getNumArguments() != 3) { - LLVM_DEBUG(llvm::dbgs() << "three-argument block"); - return false; - } - - Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin()); - Operation *reductionOp = elemOp->getNextNode(); - Operation *yieldOp = reductionOp->getNextNode(); - if (!isaElemOpTy(elemOp)) { - LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName); - return 
false; - } - if (!isaReductionOpTy(reductionOp)) { - LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName); - return false; - } - if (yieldOp->getNumOperands() != 1) { - LLVM_DEBUG(llvm::dbgs() << "one value yielded"); - return false; - } - if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) { - LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op"); - return false; - } - if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) { - LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result"); - return false; - } - if (reductionOp->getNumOperands() != 2 || - reductionOp->getNumResults() != 1) { - LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result"); - return false; - } - - SmallVector expectedReductionOperands = {body->getArgument(2), - elemOp->getResult(0)}; - if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) && - !llvm::equal(llvm::reverse(expectedReductionOperands), - reductionOp->getOperands())) { - LLVM_DEBUG(llvm::dbgs() << "operands of the second op"); - return false; - } - - ValueRange expectedElemOperands = body->getArguments().take_front(2); - if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) && - !llvm::equal(llvm::reverse(expectedElemOperands), - elemOp->getOperands())) { - LLVM_DEBUG(llvm::dbgs() << "operands of the first op"); - return false; - } - - scopeExitPrinter.release(); - LLVM_DEBUG(llvm::dbgs() << "success)"); - return true; - }); -} - -void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor( - StringRef name) { - LLVM_DEBUG(DBGS() << "op is a " << name << "'"); -} - -//===---------------------------------------------------------------------===// -// TensorPadOpMatcher -//===---------------------------------------------------------------------===// - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::low(ArrayRef sizes) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG({ - 
DBGS() << "low pad sizes are "; - llvm::interleaveComma(sizes, llvm::dbgs()); - }); - for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedLowPad(), sizes)) { - if (isConstantIntValue(ofr, sz)) { - return false; - } - } - return true; - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "all low pad sizes are " << size); - return llvm::all_of(tensorPad.getMixedLowPad(), [&](OpFoldResult ofr) { - return isConstantIntValue(ofr, size); - }); - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::high(ArrayRef sizes) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG({ - DBGS() << "high pad sizes are "; - llvm::interleaveComma(sizes, llvm::dbgs()); - }); - for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedHighPad(), sizes)) { - if (isConstantIntValue(ofr, sz)) { - return false; - } - } - return true; - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "all high pad sizes are " << size); - return llvm::all_of(tensorPad.getMixedHighPad(), [&](OpFoldResult ofr) { - return isConstantIntValue(ofr, size); - }); - }); -} - -transform_ext::TensorPadOpMatcher & -transform_ext::TensorPadOpMatcher::yieldsExternalValue() { - return addPredicate([=](tensor::PadOp tensorPad) { - LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value"); - Block *body = tensorPad.getBody(); - if (!llvm::hasSingleElement(*body)) { - return false; - } - return llvm::all_of(body->getTerminator()->getOperands(), - [body](Value operand) { - auto arg = dyn_cast(operand); - return !arg || arg.getOwner() != body; - }); - }); -} - -//===---------------------------------------------------------------------===// -// MatchCallbackResult. 
-//===---------------------------------------------------------------------===// - -ArrayRef -transform_ext::MatchCallbackResult::getPayloadGroup(int64_t position) const { - assert(position < payloadGroupLengths.size()); - int64_t start = 0; - for (int64_t i = 0; i < position; ++i) { - start += payloadGroupLengths[i]; - } - return llvm::ArrayRef(payloadOperations) - .slice(start, payloadGroupLengths[position]); -} - -//===---------------------------------------------------------------------===// -// Case-specific matcher builders. -//===---------------------------------------------------------------------===// - -static constexpr int64_t kCudaWarpSize = 32; - -void transform_ext::makeReductionMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&reductionCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&leadingCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - MatchedReductionCaptures &captures, bool mustMatchEntireFunc) { - // The core part of the matcher is anchored on a particular reduction op. - auto &reduction = - m_StructuredOp(matcherContext) - // Op has at least a parallel a reduction dimension and at - // most 3 parallel dimensions. - // TODO: relax once we have global collapse/expand_shape. - // - .rank(NumGreaterEqualTo(2)) - .rank(NumLowerEqualTo(4)) - .rank(CaptureRank(captures.reductionRank)) - // Op has a single most-minor reduction. - .dim(-1, utils::IteratorType::reduction) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.reductionOpSizes)) - // All other dimensions are parallel. - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - // Single input for now, can be arbitrary projected permutations. - // TODO: Multiple inputs, can be arbitrary projected permutations. - // TODO: Watch out for multiple inputs though as a reduction turns - // into a contraction when mixed with projected - // permutations. 
A reduction is often bandwidth bound but - // contraction is a different beast that is compute bound - // and has a very different schedule. - // - .input(NumEqualsTo(1)) - .input(AllOperands(), IsProjectedPermutation()) - // Single output supported atm. - // TODO: Multiple outputs. - // - .output(NumEqualsTo(1)) - // A reduction output must be a projected permutation, match it but we - // could also drop this technically. - .output(AllOperands(), IsProjectedPermutation()) - // Only single combiner for now due to reduction warp - // distribution. - // TODO: relax this once reduction distribution is more powerful. - // - .output(0, CaptureElementTypeBitWidth( - captures.reductionOutputElementalTypeBitWidth)) - .output(0, SingleCombinerReduction()); - reductionCapture = &reduction; - - // Mandatory FillOp must create the unique output of the reduction. - // TODO: Relax this, as any map, broadcast, transpose should also work. - // - auto &fill = m_StructuredOp(matcherContext); - reduction = reduction.output(NumEqualsTo(1)).output(0, fill); - fillCapture = &fill; - - // Optional leading or trailing op can be any map, transpose, broadcast but - // not reduce or windowing operation for now. - // It must create the unique input for the reduction. - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - transform_ext::StructuredOpMatcher commonLeadingOrTrailing = - m_StructuredOp(matcherContext) - // All parallel dimensions. - .dim(AllDims(), utils::IteratorType::parallel) - // All inputs are any projected permutation. - .input(AllOperands(), IsProjectedPermutation()) - .output(AllOperands(), IsPermutation()) - // leading and trailing may have 0, 1 or more input as long as they do - // not come from unmatched ops. This extra constraint is taken care of - // separately. This is also a noop but we document it. - // TODO: Base and derived classes, atm this does not compile. 
- // .input(NumGreaterEqualTo(0)) - // Single output supported atm. - // TODO: extend this. - // - .output(NumEqualsTo(1)); - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - auto &leading = - m_StructuredOp(matcherContext, commonLeadingOrTrailing) - .rank(CaptureRank(captures.maybeLeadingRank)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.leadingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeLeadingOutputElementalTypeBitWidth)); - reduction = reduction.input(0, leading, OptionalMatch()); - leadingCapture = &leading; - - // Optional trailing can be any map, transpose, broadcast but not reduce or - // windowing operation for now. - // It must be fed by the unique input for the reduction. - // TODO: match more optional leading ops, one per input of the reduction. - // TODO: careful about multi-output and turning into a contraction. - // - auto &trailing = - m_StructuredOp(matcherContext, commonLeadingOrTrailing) - .rank(CaptureRank(captures.maybeTrailingRank)) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) - // Capture output elemental type. 
- .output(0, CaptureElementTypeBitWidth( - captures.maybeTrailingOutputElementalTypeBitWidth)); - reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - reduction = reduction.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context, - StructuredOpMatcher *&reductionCapture, - MatchedReductionCaptures &captures, - bool mustMatchEntireFunc) { - StructuredOpMatcher *fill; - StructuredOpMatcher *leading; - StructuredOpMatcher *trailing; - makeReductionMatcher(context, reductionCapture, fill, leading, trailing, - captures, mustMatchEntireFunc); -} - -void transform_ext::makeMatmulMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&matmulCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { - auto &matmul = transform_ext::m_StructuredOp(matcherContext) - // Capture op sizes. - .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) - // Capture input/output element types. - .input(0, CaptureElementType(captures.lhsElementType)) - .input(1, CaptureElementType(captures.rhsElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - matmulCapture = &matmul; - // Mandatory FillOp must create the unique output of the reduction. 
- auto &fill = transform_ext::m_StructuredOp(matcherContext); - matmul = matmul.output(transform_ext::NumEqualsTo(1)).output(0, fill); - fillCapture = &fill; - - auto &trailing = m_StructuredOp(matcherContext); - matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - matmul = matmul.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeBatchMatmulMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&bmmCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { - auto &bmm = - transform_ext::m_StructuredOp( - matcherContext) - .hasContractionBody() - .rank(NumEqualsTo(4)) - .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .dim(-1, utils::IteratorType::reduction) - .contractionDims(CaptureContractionDims(captures.contractionDims)) - .input(NumEqualsTo(2)) - .input(0, CaptureElementType(captures.lhsElementType)) - .input(1, CaptureElementType(captures.rhsElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - bmmCapture = &bmm; - - auto &fill = transform_ext::m_StructuredOp(matcherContext); - bmm = bmm.output(0, fill); - fillCapture = &fill; - - if (mustMatchEntireFunc) { - bmm = bmm.allTilableOpsCaptured(); - } -} - -/// Match sum(%src, broadcast(%reduction)) -static void -matchSubBroadcast(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher &maxReduction, - transform_ext::CapturingValueMatcher &softmaxSourceOperand, - transform_ext::StructuredOpMatcher *&sub) { - using namespace transform_ext; - - auto &broadcast = - transform_ext::m_StructuredOp(matcherContext) - .passThroughOp() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(0, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - 
broadcast = broadcast.input(0, maxReduction); - - auto &subParallel = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - subParallel = subParallel.input(0, softmaxSourceOperand); - subParallel = subParallel.input(1, broadcast); - - auto &subBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - subBroadcast = subBroadcast.input(0, softmaxSourceOperand); - subBroadcast = subBroadcast.input(1, maxReduction); - auto &subOr = transform_ext::m_StructuredOp_Or(matcherContext, subBroadcast, - subParallel); - sub = &subOr; -} - -/// Match sum(%exp, broadcast(%sum)) -static void matchdivBroadcast(transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher &expOperand, - transform_ext::StructuredOpMatcher &sum, - transform_ext::StructuredOpMatcher *&div) { - using namespace transform_ext; - - auto &broadcast = - transform_ext::m_StructuredOp(matcherContext) - .passThroughOp() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(0, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - broadcast = broadcast.input(0, sum); - - auto &divNoBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - divNoBroadcast = divNoBroadcast.input(0, expOperand); - divNoBroadcast = divNoBroadcast.input(1, broadcast); - - auto 
&divBroadcast = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - divBroadcast = divBroadcast.input(0, expOperand); - divBroadcast = divBroadcast.input(1, sum); - - auto &divMerge = transform_ext::m_StructuredOp_Or( - matcherContext, divNoBroadcast, divBroadcast); - div = &divMerge; -} - -void transform_ext::makeSoftmaxMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&maxReductionCapture, - transform_ext::StructuredOpMatcher *&softmaxRootCapture) { - auto &softmaxSourceOperand = m_Value(matcherContext); - - auto &fillMinusInf = m_StructuredOp(matcherContext) - .input(0, ConstantFloatMinOrMinusInf()); - auto &maxReduction = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - // Only handle most inner reduction for now. 
- .dim(-1, utils::IteratorType::reduction) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsProjected(-1)); - maxReduction = maxReduction.input(0, softmaxSourceOperand); - maxReduction = maxReduction.output(0, fillMinusInf); - maxReductionCapture = &maxReduction; - - transform_ext::StructuredOpMatcher *subOperand; - matchSubBroadcast(matcherContext, maxReduction, softmaxSourceOperand, - subOperand); - - auto &expOperand = m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)); - expOperand = expOperand.input(0, *subOperand); - - auto &fillZero = m_StructuredOp(matcherContext) - .input(0, ConstantFloatZero()); - auto &sum = - m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - // Only handle most inner reduction for now. 
- .dim(-1, utils::IteratorType::reduction) - .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsProjected(-1)) - .output(NumEqualsTo(1)); - sum = sum.input(0, expOperand); - sum = sum.output(0, fillZero); - - auto &rcpOperand = m_StructuredOp(matcherContext) - .isFloatReciprocal() - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(1)) - .input(AllOperands(), IsIdentity()) - .output(AllOperands(), IsIdentity()) - .output(NumEqualsTo(1)); - rcpOperand = rcpOperand.input(0, sum); - - auto &mulOperand = - transform_ext::m_StructuredOp(matcherContext) - .singleOpWithCanonicaleArgs(/*commutative=*/true) - .dim(AllDims(), utils::IteratorType::parallel) - .input(NumEqualsTo(2)) - .input(0, IsIdentity()) - .input(1, IsProjected(-1)) - .output(NumEqualsTo(1)) - .output(AllOperands(), IsIdentity()); - - mulOperand = mulOperand.input(0, expOperand); - mulOperand = mulOperand.input(1, rcpOperand); - - transform_ext::StructuredOpMatcher *divOperand; - matchdivBroadcast(matcherContext, expOperand, sum, divOperand); - - auto &softmaxRoot = - transform_ext::m_StructuredOp_Or(matcherContext, mulOperand, *divOperand); - softmaxRootCapture = &softmaxRoot; -} - -/// Matcher for convolutions. -void transform_ext::makeConvolutionMatcher( - transform_ext::MatcherContext &matcherContext, - transform_ext::StructuredOpMatcher *&convolutionCapture, - transform_ext::StructuredOpMatcher *&fillCapture, - transform_ext::StructuredOpMatcher *&trailingCapture, - MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { - // The core part of the matcher is anchored on a particular convolution op. - auto &convolution = - m_StructuredOp( - matcherContext) - // Capture convolution dim classifications. - .convolutionDims(CaptureConvDims(captures.convolutionDims)) - // Capture op sizes. 
- .dim(AllDims(), CaptureDims(captures.convolutionOpSizes)) - // Capture convolution element types. - .input(0, CaptureElementType(captures.inputElementType)) - .input(1, CaptureElementType(captures.filterElementType)) - .output(0, CaptureElementType(captures.outputElementType)); - convolutionCapture = &convolution; - - // Optional FillOp to create the unique output of the convolution. - auto &fill = m_StructuredOp(matcherContext) - .output(0, CaptureElementTypeBitWidth( - captures.maybeFillElementalTypeBitWidth)); - convolution = - convolution.output(NumEqualsTo(1)).output(0, fill, OptionalMatch()); - fillCapture = &fill; - - // Optional trailing op can be any map, transpose, broadcast but - // not reduce or windowing operation for now. - // It must create the unique input for the reduction. - auto &trailing = - m_StructuredOp(matcherContext) - // All parallel dimensions. - .dim(AllDims(), utils::IteratorType::parallel) - // All inputs are any projected permutation. - .input(AllOperands(), IsProjectedPermutation()) - .output(AllOperands(), IsPermutation()) - .output(NumEqualsTo(1)) - .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) - // Capture output elemental type. - .output(0, CaptureElementTypeBitWidth( - captures.maybeTrailingOutputElementalTypeBitWidth)); - - // Optional trailing can be any map, transpose, broadcast but not reduce or - // windowing operation for now. 
- convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch()); - if (mustMatchEntireFunc) { - convolution = - convolution.allTilableOpsCaptured(); - } - trailingCapture = &trailing; -} - -void transform_ext::makeConvolutionMatcher( - transform_ext::MatcherContext &context, - StructuredOpMatcher *&convolutionCapture, - MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { - StructuredOpMatcher *fill; - StructuredOpMatcher *trailing; - makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures, - mustMatchEntireFunc); -} - -void transform_ext::makePadMatcher(MatcherContext &context, - CapturingOpMatcher *&padCapture, - MatchedPadCaptures &captures, - bool mustMatchEntireFunc) { - auto &value = transform_ext::m_ShapedValue(context); - value.rank(transform_ext::CaptureRank(captures.rank)) - .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims)) - .elementType(CaptureElementType(captures.elementType)); - auto &opMatcher = transform_ext::m_tensorPad(context) - .result(0, value) - .low(AllDims(), 0) - .yieldsExternalValue(); - if (mustMatchEntireFunc) { - opMatcher = opMatcher.allTilableOpsCaptured(); - } - padCapture = &opMatcher; -} diff --git a/llvm-external-projects/iree-dialects/python/CMakeLists.txt b/llvm-external-projects/iree-dialects/python/CMakeLists.txt index 0fc3506f5420..488e5129834a 100644 --- a/llvm-external-projects/iree-dialects/python/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/python/CMakeLists.txt @@ -15,16 +15,6 @@ declare_mlir_python_sources(IREEDialectsPythonSources.Dialects ADD_TO_PARENT IREEDialectsPythonSources ) -declare_mlir_dialect_extension_python_bindings( - ADD_TO_PARENT IREEDialectsPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler" - TD_FILE dialects/IreeStructuredTransformOps.td - SOURCES - dialects/transform/iree_structured.py - dialects/_iree_structured_transform_ops_ext.py - DIALECT_NAME transform - EXTENSION_NAME 
iree_structured_transform) - ################################################################################ # Extensions ################################################################################ @@ -36,7 +26,7 @@ declare_mlir_python_extension(IREEDialectsPythonExtensions.Main SOURCES IREEDialectsModule.cpp EMBED_CAPI_LINK_LIBS - IREEDialectsCAPI + MLIRCAPITransformDialect PRIVATE_LINK_LIBS LLVMSupport ) diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp index d6bab61dfe9f..10a50f8e43dc 100644 --- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp +++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp @@ -4,10 +4,10 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects-c/Dialects.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" +#include "mlir-c/Dialect/Transform.h" #include "mlir-c/RegisterEverything.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" @@ -30,7 +30,6 @@ NB_MODULE(_ireeDialects, m) { [](MlirContext context, bool load) { MlirDialectHandle handle = mlirGetDialectHandle__transform__(); mlirDialectHandleRegisterDialect(handle, context); - ireeRegisterTransformExtensions(context); if (load) { mlirDialectHandleLoadDialect(handle, context); } diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td deleted file mode 100644 index 0a10cd36d778..000000000000 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeStructuredTransformOps.td +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING -#define PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING - -include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td" - -#endif // PYTHON_BINDINGS_IREE_TRANSFORMEXT_BINDING diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py deleted file mode 100644 index f1f89550bd42..000000000000 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Disable PyType, it does not seem to like the specialization pattern used in -# MLIR. 
-# pytype: skip-file -try: - from ..ir import * - from ..dialects import pdl - from ._ods_common import ( - extend_opview_class as _ods_extend_opview_class, - segmented_accessor as _ods_segmented_accessor, - equally_sized_accessor as _ods_equally_sized_accessor, - get_default_loc_context as _ods_get_default_loc_context, - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) - from typing import Optional, overload, Sequence, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e -BoolArg = Optional[Union[bool, BoolAttr]] -IntListArg = Optional[Union[Sequence[int], ArrayAttr]] -StringArg = Optional[Union[str, StringAttr]] - - -def _defaulted_ensure(f): - def inner(value, default=None): - assert value is not None or default is not None - return f(default if value is None else value) - - return inner - - -@_defaulted_ensure -def _ensure_int_array_attr(value: IntListArg): - i64 = IntegerType.get_signless(64) - if isinstance(value, Sequence): - return ArrayAttr.get([IntegerAttr.get(i64, i) for i in value]) - return value - - -@_defaulted_ensure -def _ensure_bool_attr(value: BoolArg): - if isinstance(value, bool): - return BoolAttr.get(value) - return value - - -@_defaulted_ensure -def _ensure_string_attr(value: StringArg): - if isinstance(value, str): - return StringAttr.get(value) - return value - - -class LowerToLLVMOp: - """Specialization for the LowerToLLVMOp class.""" - - def __init__( - self, - *, - reassociate_fp_reductions: BoolArg = None, - enable_index_optimizations: BoolArg = None, - enable_arm_neon: BoolArg = None, - enable_arm_sve: BoolArg = None, - enable_amx: BoolArg = None, - enable_x86vector: BoolArg = None, - enable_async: BoolArg = None, - loc=None, - ip=None - ): - super().__init__( - _ensure_bool_attr(reassociate_fp_reductions, False), - _ensure_bool_attr(enable_index_optimizations, False), - _ensure_bool_attr(enable_arm_neon, False), - 
_ensure_bool_attr(enable_arm_sve, False), - _ensure_bool_attr(enable_amx, False), - _ensure_bool_attr(enable_x86vector, False), - _ensure_bool_attr(enable_async, False), - loc=loc, - ip=ip, - ) diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py deleted file mode 100644 index 563e20bc0b04..000000000000 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/transform/iree_structured.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .._iree_structured_transform_ops_gen import * diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt index 669a616d9e89..7d699b5d14eb 100644 --- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt @@ -1,6 +1,4 @@ set(LIBS - # Local dialects. - IREELinalgTransformDialect # Core dialects. MLIRAffineDialect MLIRArithDialect diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp index b4a100caa1b2..2824d307fa64 100644 --- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp +++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" @@ -71,7 +70,6 @@ int main(int argc, char **argv) { mlir::LLVM::registerInlinerInterface(registry); mlir::linalg::registerTilingInterfaceExternalModels(registry); - registry.addExtensions(); mlir::bufferization::registerTransformDialectExtension(registry); mlir::linalg::registerTransformDialectExtension(registry); mlir::scf::registerTransformDialectExtension(registry); From 668d5e19c2a6019ab501845c58bfbe029a427372 Mon Sep 17 00:00:00 2001 From: Alex-Wengg Date: Thu, 18 Jun 2026 10:33:35 -0400 Subject: [PATCH 7/8] [GlobalOpt][NFC] Trim contextual comment in softmax_maxnumf lit test Shorten the stabilizing-max comment to a single line per review feedback: contextual rationale about how linalg.softmax decomposes belongs in the commit history (git blame), not in persistent LIT test comments where it can drift from the implementation. Signed-off-by: Alex-Wengg --- .../compiler/GlobalOptimization/test/raise_special_ops.mlir | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index f2c9ebd525dd..6dfcc1fd6e1f 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -168,8 +168,7 @@ util.func public @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x // ----- -// The stabilizing max may use arith.maxnumf (NaN-ignoring) instead of -// arith.maximumf -- this is the form linalg.softmax itself decomposes to. +// The stabilizing max may use arith.maxnumf (NaN-ignoring). 
// CHECK-LABEL: @softmax_maxnumf // CHECK-SAME: %[[ARG:.+]]: tensor<2x4xf32> // CHECK: %[[S:.+]] = linalg.softmax dimension(1) ins(%[[ARG]] : tensor<2x4xf32>) From 01f9db1dab6503e408543b94028c92c26bc42e1f Mon Sep 17 00:00:00 2001 From: Alex-Wengg Date: Thu, 18 Jun 2026 15:00:39 -0400 Subject: [PATCH 8/8] [GlobalOpt][NFC] Trim negative-test comments in raise_special_ops Shorten the two verbose "Negative test: ... because ..." rationale comments to one-line labels stating what each test checks, consistent with the earlier softmax_maxnumf trim. Rationale belongs in commit history, not persistent LIT test comments where it can drift. Signed-off-by: Alex-Wengg --- .../GlobalOptimization/test/raise_special_ops.mlir | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index 6dfcc1fd6e1f..877b169c70c5 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -210,8 +210,7 @@ util.func public @softmax_maxnumf(%src : tensor<2x4xf32>) -> (tensor<2x4xf32>) { // ----- -// Negative test: the max reduction is initialized with 0.0 instead of -inf, so -// this is not a numerically-stabilized softmax and must not be raised. +// Negative test: max init is 0.0, not -inf/lowest. // CHECK-LABEL: @not_softmax_wrong_max_init // CHECK-NOT: linalg.softmax util.func public @not_softmax_wrong_max_init(%src : tensor) -> (tensor) { @@ -263,9 +262,7 @@ util.func public @not_softmax_wrong_max_init(%src : tensor) -> (tenso // ----- -// Negative test: the max reduction reduces %src but the subtraction reads a -// different tensor %other, so the captured source is inconsistent and the -// pattern must not be raised. +// Negative test: max and subtraction read different sources. 
// CHECK-LABEL: @not_softmax_mismatched_source // CHECK-NOT: linalg.softmax util.func public @not_softmax_mismatched_source(%src : tensor, %other : tensor) -> (tensor) {