@@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
22572257 return success ();
22582258}
22592259
2260- LogicalResult ConversionPatternRewriter::legalize (Region *r) {
2261- // Fast path: If the region is empty, there is nothing to legalize.
2262- if (r->empty ())
2263- return success ();
2264-
2265- // Gather a list of all operations to legalize. This is done before
2266- // converting the entry block signature because unrealized_conversion_cast
2267- // ops should not be included.
2268- SmallVector<Operation *> ops;
2269- for (Block &b : *r)
2270- for (Operation &op : b)
2271- ops.push_back (&op);
2272-
2273- // If the current pattern runs with a type converter, convert the entry block
2274- // signature.
2275- if (const TypeConverter *converter = impl->currentTypeConverter ) {
2276- std::optional<TypeConverter::SignatureConversion> conversion =
2277- converter->convertBlockSignature (&r->front ());
2278- if (!conversion)
2279- return failure ();
2280- applySignatureConversion (&r->front (), *conversion, converter);
2281- }
2282-
2283- // Legalize all operations in the region.
2284- for (Operation *op : ops)
2285- if (failed (legalize (op)))
2286- return failure ();
2287-
2288- return success ();
2289- }
2290-
22912260void ConversionPatternRewriter::inlineBlockBefore (Block *source, Block *dest,
22922261 Block::iterator before,
22932262 ValueRange argValues) {
@@ -3287,8 +3256,20 @@ struct OperationConverter {
32873256 : rewriter(ctx, config, *this ), opLegalizer(rewriter, target, patterns),
32883257 mode(mode) {}
32893258
3290- // / Converts the given operations to the conversion target.
3291- LogicalResult convertOperations (ArrayRef<Operation *> ops);
3259+ // / Applies the conversion to the given operations (and their nested
3260+ // / operations).
3261+ LogicalResult applyConversion (ArrayRef<Operation *> ops);
3262+
3263+ // / Legalizes the given operations (and their nested operations) to the
3264+ // / conversion target.
3265+ template <typename Fn>
3266+ LogicalResult legalizeOperations (ArrayRef<Operation *> ops, Fn onFailure,
3267+ bool isRecursiveLegalization = false );
3268+ LogicalResult legalizeOperations (ArrayRef<Operation *> ops,
3269+ bool isRecursiveLegalization = false ) {
3270+ return legalizeOperations (
3271+ ops, /* onFailure=*/ [&]() {}, isRecursiveLegalization);
3272+ }
32923273
32933274 // / Converts a single operation. If `isRecursiveLegalization` is "true", the
32943275 // / conversion is a recursive legalization request, triggered from within a
@@ -3297,6 +3278,8 @@ struct OperationConverter {
32973278 // / legalization mechanism).
32983279 LogicalResult convert (Operation *op, bool isRecursiveLegalization = false );
32993280
3281+ const ConversionTarget &getTarget () { return opLegalizer.getTarget (); }
3282+
33003283private:
33013284 // / The rewriter to use when converting operations.
33023285 ConversionPatternRewriter rewriter;
@@ -3309,10 +3292,6 @@ struct OperationConverter {
33093292};
33103293} // namespace mlir
33113294
3312- LogicalResult ConversionPatternRewriter::legalize (Operation *op) {
3313- return impl->opConverter .convert (op, /* isRecursiveLegalization=*/ true );
3314- }
3315-
33163295LogicalResult OperationConverter::convert (Operation *op,
33173296 bool isRecursiveLegalization) {
33183297 const ConversionConfig &config = rewriter.getConfig ();
@@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
33983377 return failure ();
33993378}
34003379
3401- LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
3380+ template <typename Fn>
3381+ LogicalResult
3382+ OperationConverter::legalizeOperations (ArrayRef<Operation *> ops, Fn onFailure,
3383+ bool isRecursiveLegalization) {
34023384 const ConversionTarget &target = opLegalizer.getTarget ();
34033385
34043386 // Compute the set of operations and blocks to convert.
34053387 SmallVector<Operation *> toConvert;
3406- for (auto *op : ops) {
3388+ for (Operation *op : ops) {
34073389 op->walk <WalkOrder::PreOrder, ForwardDominanceIterator<>>(
34083390 [&](Operation *op) {
34093391 toConvert.push_back (op);
@@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
34153397 return WalkResult::advance ();
34163398 });
34173399 }
3400+ for (Operation *op : toConvert) {
3401+ if (failed (convert (op, isRecursiveLegalization))) {
3402+ // Failed to convert an operation.
3403+ onFailure ();
3404+ return failure ();
3405+ }
3406+ }
3407+ return success ();
3408+ }
34183409
3419- // Convert each operation and discard rewrites on failure.
3420- ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
3410+ LogicalResult ConversionPatternRewriter::legalize (Operation *op) {
3411+ return impl->opConverter .legalizeOperations (op,
3412+ /* isRecursiveLegalization=*/ true );
3413+ }
34213414
3422- for (auto *op : toConvert) {
3423- if (failed (convert (op))) {
3424- // Dialect conversion failed.
3425- if (rewriterImpl.config .allowPatternRollback ) {
3426- // Rollback is allowed: restore the original IR.
3427- rewriterImpl.undoRewrites ();
3428- } else {
3429- // Rollback is not allowed: apply all modifications that have been
3430- // performed so far.
3431- rewriterImpl.applyRewrites ();
3432- }
3415+ LogicalResult ConversionPatternRewriter::legalize (Region *r) {
3416+ // Fast path: If the region is empty, there is nothing to legalize.
3417+ if (r->empty ())
3418+ return success ();
3419+
3420+ // Gather a list of all operations to legalize. This is done before
3421+ // converting the entry block signature because unrealized_conversion_cast
3422+ // ops should not be included.
3423+ SmallVector<Operation *> ops;
3424+ for (Block &b : *r)
3425+ for (Operation &op : b)
3426+ ops.push_back (&op);
3427+
3428+ // If the current pattern runs with a type converter, convert the entry block
3429+ // signature.
3430+ if (const TypeConverter *converter = impl->currentTypeConverter ) {
3431+ std::optional<TypeConverter::SignatureConversion> conversion =
3432+ converter->convertBlockSignature (&r->front ());
3433+ if (!conversion)
34333434 return failure ();
3434- }
3435+ applySignatureConversion (&r-> front (), *conversion, converter);
34353436 }
34363437
3438+ // Legalize all operations in the region. This includes all nested
3439+ // operations.
3440+ return impl->opConverter .legalizeOperations (ops,
3441+ /* isRecursiveLegalization=*/ true );
3442+ }
3443+
3444+ LogicalResult OperationConverter::applyConversion (ArrayRef<Operation *> ops) {
3445+ // Convert each operation and discard rewrites on failure.
3446+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
3447+ LogicalResult status = legalizeOperations (ops, /* onFailure=*/ [&]() {
3448+ // Dialect conversion failed.
3449+ if (rewriterImpl.config .allowPatternRollback ) {
3450+ // Rollback is allowed: restore the original IR.
3451+ rewriterImpl.undoRewrites ();
3452+ } else {
3453+ // Rollback is not allowed: apply all modifications that have been
3454+ // performed so far.
3455+ rewriterImpl.applyRewrites ();
3456+ }
3457+ });
3458+ if (failed (status))
3459+ return failure ();
3460+
34373461 // After a successful conversion, apply rewrites.
34383462 rewriterImpl.applyRewrites ();
34393463
@@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
41434167 [&] {
41444168 OperationConverter opConverter (ops.front ()->getContext (), target,
41454169 patterns, config, mode);
4146- status = opConverter.convertOperations (ops);
4170+ status = opConverter.applyConversion (ops);
41474171 },
41484172 irUnits);
41494173 return status;
0 commit comments