Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9356,6 +9356,41 @@ class ConvertVPTOUnrealizedCastOp final
}
};

class ConvertArithSelectOp final : public OpConversionPattern<arith::SelectOp> {
public:
ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<arith::SelectOp>(typeConverter, context,
PatternBenefit(2)) {}

LogicalResult
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!hasVPTOConvertibleType(op->getOperandTypes()) &&
!hasVPTOConvertibleType(op->getResultTypes()))
return failure();
if (!op.getCondition().getType().isInteger(1))
return rewriter.notifyMatchFailure(
op, "only scalar i1 conditions supported for VPTO arith.select");

Type convertedResultType =
getTypeConverter()->convertType(op.getResult().getType());
if (!convertedResultType)
return rewriter.notifyMatchFailure(op, "failed to convert result type");

Value trueValue = adaptor.getTrueValue();
Value falseValue = adaptor.getFalseValue();
if (trueValue.getType() != convertedResultType ||
falseValue.getType() != convertedResultType)
return rewriter.notifyMatchFailure(
op, "converted true/false values must match result type");

rewriter.replaceOpWithNewOp<arith::SelectOp>(
op, convertedResultType, adaptor.getCondition(), trueValue,
falseValue);
return success();
}
};

class ConvertPtoAddPtrOp final : public OpConversionPattern<pto::AddPtrOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -10231,6 +10266,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS)
patterns.add<ConvertPtoLoadOp, ConvertPtoStoreOp, ConvertPtoLdgOp,
ConvertPtoStgOp>(
typeConverter, context, state);
patterns.add<ConvertArithSelectOp>(typeConverter, context);
patterns.add<ConvertVPTOUnrealizedCastOp>(typeConverter, context);
patterns.add<ConvertVPTOTypedCarrierOp>(typeConverter, context);

Expand Down
36 changes: 36 additions & 0 deletions lib/PTO/Transforms/VPTOLLVMEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9300,6 +9300,41 @@ class ConvertVPTOUnrealizedCastOp final
}
};

class ConvertArithSelectOp final : public OpConversionPattern<arith::SelectOp> {
public:
ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<arith::SelectOp>(typeConverter, context,
PatternBenefit(2)) {}

LogicalResult
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!hasVPTOConvertibleType(op->getOperandTypes()) &&
!hasVPTOConvertibleType(op->getResultTypes()))
return failure();
if (!op.getCondition().getType().isInteger(1))
return rewriter.notifyMatchFailure(
op, "only scalar i1 conditions supported for VPTO arith.select");

Type convertedResultType =
getTypeConverter()->convertType(op.getResult().getType());
if (!convertedResultType)
return rewriter.notifyMatchFailure(op, "failed to convert result type");

Value trueValue = adaptor.getTrueValue();
Value falseValue = adaptor.getFalseValue();
if (trueValue.getType() != convertedResultType ||
falseValue.getType() != convertedResultType)
return rewriter.notifyMatchFailure(
op, "converted true/false values must match result type");

rewriter.replaceOpWithNewOp<arith::SelectOp>(
op, convertedResultType, adaptor.getCondition(), trueValue,
falseValue);
return success();
}
};

class ConvertPtoAddPtrOp final : public OpConversionPattern<pto::AddPtrOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -10177,6 +10212,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS)
patterns.add<ConvertPtoLoadOp, ConvertPtoStoreOp, ConvertPtoLdgOp,
ConvertPtoStgOp>(
typeConverter, context, state);
patterns.add<ConvertArithSelectOp>(typeConverter, context);
patterns.add<ConvertVPTOUnrealizedCastOp>(typeConverter, context);
patterns.add<ConvertVPTOTypedCarrierOp>(typeConverter, context);

Expand Down
54 changes: 54 additions & 0 deletions test/lit/vpto/arith_select_vpto_llvm.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of %T is deprecated in LLVM lit because it is a shared directory across tests and can lead to race conditions or flakiness. Since lit automatically creates the directory containing %t before running the test, mkdir -p %T is redundant and can be safely removed.

// RUN: ( ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s


module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind<vector>} {
func.func @arith_select_vreg(%cond: i1, %lhs_scalar: f32, %rhs_scalar: f32,
%dst: !pto.ptr<f32, ub>) attributes {pto.kernel} {
%c0 = arith.constant 0 : index
pto.vecscope {
%mask = pto.pset_b32 "PAT_ALL" : !pto.mask<b32>
%lhs = pto.vdup %lhs_scalar, %mask
: f32, !pto.mask<b32> -> !pto.vreg<64xf32>
%rhs = pto.vdup %rhs_scalar, %mask
: f32, !pto.mask<b32> -> !pto.vreg<64xf32>
%chosen = arith.select %cond, %lhs, %rhs : !pto.vreg<64xf32>
pto.vsts %chosen, %dst[%c0], %mask
: !pto.vreg<64xf32>, !pto.ptr<f32, ub>, !pto.mask<b32>
}
return
}

func.func @arith_select_mask(%cond: i1, %value: f32,
%dst: !pto.ptr<f32, ub>) attributes {pto.kernel} {
%c0 = arith.constant 0 : index
pto.vecscope {
%all = pto.pset_b32 "PAT_ALL" : !pto.mask<b32>
%tail = pto.pge_b32 "PAT_VL4" : !pto.mask<b32>
%chosen_mask = arith.select %cond, %all, %tail : !pto.mask<b32>
%vec = pto.vdup %value, %all
: f32, !pto.mask<b32> -> !pto.vreg<64xf32>
pto.vsts %vec, %dst[%c0], %chosen_mask
: !pto.vreg<64xf32>, !pto.ptr<f32, ub>, !pto.mask<b32>
}
return
}
}

// CHECK-LABEL: llvm.func @arith_select_vreg_mix_aiv
// CHECK: %[[LHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}}
// CHECK: %[[RHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}}
// CHECK: %[[CHOSEN:.*]] = llvm.select %arg0, %[[LHS]], %[[RHS]] : i1, vector<64xf32>
// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32(%[[CHOSEN]]

// CHECK-LABEL: llvm.func @arith_select_mask_mix_aiv
// CHECK: %[[ALL:.*]] = llvm.call @llvm.hivm.pset.b32
// CHECK: %[[TAIL:.*]] = llvm.call @llvm.hivm.pge.b32
// CHECK: %[[CHOSEN_MASK:.*]] = llvm.select %arg0, %[[ALL]], %[[TAIL]] : i1, vector<256xi1>
// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CHOSEN_MASK]])
Loading