From de6fc78f9a92ce296f5eeea860c2cdb5c24009f6 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:05:06 +0800 Subject: [PATCH] Support arith select in VPTO LLVM lowering --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 36 +++++++++++++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 36 +++++++++++++ test/lit/vpto/arith_select_vpto_llvm.pto | 54 +++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 test/lit/vpto/arith_select_vpto_llvm.pto diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 8362aea64b..4d4b82f5a8 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -9356,6 +9356,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10231,6 +10266,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 35f8cc51a3..bee22fed58 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -9300,6 +9300,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10177,6 +10212,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/test/lit/vpto/arith_select_vpto_llvm.pto b/test/lit/vpto/arith_select_vpto_llvm.pto new file mode 100644 index 0000000000..b32a7fe0de --- /dev/null +++ b/test/lit/vpto/arith_select_vpto_llvm.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @arith_select_vreg(%cond: i1, %lhs_scalar: f32, %rhs_scalar: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vdup %lhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vdup %rhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %chosen = arith.select %cond, %lhs, %rhs : !pto.vreg<64xf32> + pto.vsts %chosen, %dst[%c0], %mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } + + func.func @arith_select_mask(%cond: i1, %value: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %tail = pto.pge_b32 "PAT_VL4" : !pto.mask + %chosen_mask = arith.select %cond, %all, %tail : !pto.mask + %vec = pto.vdup %value, %all + : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %dst[%c0], %chosen_mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @arith_select_vreg_mix_aiv +// CHECK: %[[LHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[RHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[CHOSEN:.*]] = llvm.select %arg0, %[[LHS]], %[[RHS]] : i1, vector<64xf32> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32(%[[CHOSEN]] + +// CHECK-LABEL: llvm.func @arith_select_mask_mix_aiv +// CHECK: %[[ALL:.*]] = llvm.call @llvm.hivm.pset.b32 +// CHECK: %[[TAIL:.*]] = llvm.call @llvm.hivm.pge.b32 +// CHECK: %[[CHOSEN_MASK:.*]] = llvm.select %arg0, %[[ALL]], %[[TAIL]] : i1, vector<256xi1> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CHOSEN_MASK]])