Skip to content

Commit 78711b6

Browse files
[mlir][Transforms] Legalize nested operations (#172158)
This commit align the implementation of `ConversionPatternRewriter::legalize` with its documentation: ``` /// Attempt to legalize the given region. This can be used within ... LogicalResult legalize(Region *r); ``` This function now legalizes the entire region, including nested ops. The implementation follows the same logic as the "main" traversal: pre-order, forward-dominance.
1 parent cd806d7 commit 78711b6

File tree

2 files changed

+86
-54
lines changed

2 files changed

+86
-54
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 78 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
22572257
return success();
22582258
}
22592259

2260-
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2261-
// Fast path: If the region is empty, there is nothing to legalize.
2262-
if (r->empty())
2263-
return success();
2264-
2265-
// Gather a list of all operations to legalize. This is done before
2266-
// converting the entry block signature because unrealized_conversion_cast
2267-
// ops should not be included.
2268-
SmallVector<Operation *> ops;
2269-
for (Block &b : *r)
2270-
for (Operation &op : b)
2271-
ops.push_back(&op);
2272-
2273-
// If the current pattern runs with a type converter, convert the entry block
2274-
// signature.
2275-
if (const TypeConverter *converter = impl->currentTypeConverter) {
2276-
std::optional<TypeConverter::SignatureConversion> conversion =
2277-
converter->convertBlockSignature(&r->front());
2278-
if (!conversion)
2279-
return failure();
2280-
applySignatureConversion(&r->front(), *conversion, converter);
2281-
}
2282-
2283-
// Legalize all operations in the region.
2284-
for (Operation *op : ops)
2285-
if (failed(legalize(op)))
2286-
return failure();
2287-
2288-
return success();
2289-
}
2290-
22912260
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
22922261
Block::iterator before,
22932262
ValueRange argValues) {
@@ -3287,8 +3256,20 @@ struct OperationConverter {
32873256
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
32883257
mode(mode) {}
32893258

3290-
/// Converts the given operations to the conversion target.
3291-
LogicalResult convertOperations(ArrayRef<Operation *> ops);
3259+
/// Applies the conversion to the given operations (and their nested
3260+
/// operations).
3261+
LogicalResult applyConversion(ArrayRef<Operation *> ops);
3262+
3263+
/// Legalizes the given operations (and their nested operations) to the
3264+
/// conversion target.
3265+
template <typename Fn>
3266+
LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3267+
bool isRecursiveLegalization = false);
3268+
LogicalResult legalizeOperations(ArrayRef<Operation *> ops,
3269+
bool isRecursiveLegalization = false) {
3270+
return legalizeOperations(
3271+
ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3272+
}
32923273

32933274
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
32943275
/// conversion is a recursive legalization request, triggered from within a
@@ -3297,6 +3278,8 @@ struct OperationConverter {
32973278
/// legalization mechanism).
32983279
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
32993280

3281+
const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3282+
33003283
private:
33013284
/// The rewriter to use when converting operations.
33023285
ConversionPatternRewriter rewriter;
@@ -3309,10 +3292,6 @@ struct OperationConverter {
33093292
};
33103293
} // namespace mlir
33113294

3312-
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3313-
return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
3314-
}
3315-
33163295
LogicalResult OperationConverter::convert(Operation *op,
33173296
bool isRecursiveLegalization) {
33183297
const ConversionConfig &config = rewriter.getConfig();
@@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
33983377
return failure();
33993378
}
34003379

3401-
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3380+
template <typename Fn>
3381+
LogicalResult
3382+
OperationConverter::legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3383+
bool isRecursiveLegalization) {
34023384
const ConversionTarget &target = opLegalizer.getTarget();
34033385

34043386
// Compute the set of operations and blocks to convert.
34053387
SmallVector<Operation *> toConvert;
3406-
for (auto *op : ops) {
3388+
for (Operation *op : ops) {
34073389
op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
34083390
[&](Operation *op) {
34093391
toConvert.push_back(op);
@@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
34153397
return WalkResult::advance();
34163398
});
34173399
}
3400+
for (Operation *op : toConvert) {
3401+
if (failed(convert(op, isRecursiveLegalization))) {
3402+
// Failed to convert an operation.
3403+
onFailure();
3404+
return failure();
3405+
}
3406+
}
3407+
return success();
3408+
}
34183409

3419-
// Convert each operation and discard rewrites on failure.
3420-
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3410+
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3411+
return impl->opConverter.legalizeOperations(op,
3412+
/*isRecursiveLegalization=*/true);
3413+
}
34213414

3422-
for (auto *op : toConvert) {
3423-
if (failed(convert(op))) {
3424-
// Dialect conversion failed.
3425-
if (rewriterImpl.config.allowPatternRollback) {
3426-
// Rollback is allowed: restore the original IR.
3427-
rewriterImpl.undoRewrites();
3428-
} else {
3429-
// Rollback is not allowed: apply all modifications that have been
3430-
// performed so far.
3431-
rewriterImpl.applyRewrites();
3432-
}
3415+
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3416+
// Fast path: If the region is empty, there is nothing to legalize.
3417+
if (r->empty())
3418+
return success();
3419+
3420+
// Gather a list of all operations to legalize. This is done before
3421+
// converting the entry block signature because unrealized_conversion_cast
3422+
// ops should not be included.
3423+
SmallVector<Operation *> ops;
3424+
for (Block &b : *r)
3425+
for (Operation &op : b)
3426+
ops.push_back(&op);
3427+
3428+
// If the current pattern runs with a type converter, convert the entry block
3429+
// signature.
3430+
if (const TypeConverter *converter = impl->currentTypeConverter) {
3431+
std::optional<TypeConverter::SignatureConversion> conversion =
3432+
converter->convertBlockSignature(&r->front());
3433+
if (!conversion)
34333434
return failure();
3434-
}
3435+
applySignatureConversion(&r->front(), *conversion, converter);
34353436
}
34363437

3438+
// Legalize all operations in the region. This includes all nested
3439+
// operations.
3440+
return impl->opConverter.legalizeOperations(ops,
3441+
/*isRecursiveLegalization=*/true);
3442+
}
3443+
3444+
LogicalResult OperationConverter::applyConversion(ArrayRef<Operation *> ops) {
3445+
// Convert each operation and discard rewrites on failure.
3446+
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3447+
LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3448+
// Dialect conversion failed.
3449+
if (rewriterImpl.config.allowPatternRollback) {
3450+
// Rollback is allowed: restore the original IR.
3451+
rewriterImpl.undoRewrites();
3452+
} else {
3453+
// Rollback is not allowed: apply all modifications that have been
3454+
// performed so far.
3455+
rewriterImpl.applyRewrites();
3456+
}
3457+
});
3458+
if (failed(status))
3459+
return failure();
3460+
34373461
// After a successful conversion, apply rewrites.
34383462
rewriterImpl.applyRewrites();
34393463

@@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
41434167
[&] {
41444168
OperationConverter opConverter(ops.front()->getContext(), target,
41454169
patterns, config, mode);
4146-
status = opConverter.convertOperations(ops);
4170+
status = opConverter.applyConversion(ops);
41474171
},
41484172
irUnits);
41494173
return status;

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,16 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
454454
// The region of "test.post_order_legalization" is converted before the op.
455455

456456
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
457+
// CHECK: notifyOperationInserted: test.remaining_consumer
458+
// CHECK: notifyOperationInserted: test.legal_op
457459
// CHECK: notifyOperationInserted: test.invalid
458460
// CHECK: notifyBlockErased
459461
// CHECK: notifyOperationInserted: test.valid, was unlinked
460462
// CHECK: notifyOperationReplaced: test.invalid
461463
// CHECK: notifyOperationErased: test.invalid
464+
// CHECK: notifyOperationInserted: test.valid, was unlinked
465+
// CHECK: notifyOperationReplaced: test.invalid
466+
// CHECK: notifyOperationErased: test.invalid
462467
// CHECK: notifyOperationModified: test.post_order_legalization
463468

464469
// CHECK-LABEL: func @test_preorder_legalization
@@ -475,6 +480,9 @@ func.func @test_preorder_legalization() {
475480
^bb0(%arg0: i64):
476481
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
477482
"test.remaining_consumer"(%arg0) : (i64) -> ()
483+
"test.legal_op"() ({
484+
"test.invalid"(%arg0) : (i64) -> ()
485+
}) : () -> ()
478486
"test.invalid"(%arg0) : (i64) -> ()
479487
}) : () -> ()
480488
// expected-remark @+1 {{'func.return' is not legalizable}}

0 commit comments

Comments
 (0)