From 35e0874542eeeab87816676b357ae5a607d95b41 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sat, 27 Jun 2026 09:17:56 +0800 Subject: [PATCH 1/4] Add VPTO vector address memory ops --- docs/designs/vpto-vector-address-ops.md | 1052 +++++++++++++++++ docs/isa/micro-isa/18-vaddr-loop-memory.md | 372 ++++++ include/PTO/IR/VPTOOps.td | 201 ++++ include/PTO/IR/VPTOTypeDefs.td | 20 + lib/PTO/IR/VPTO.cpp | 319 ++++- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 604 +++++++++- .../lit/vpto/vector_address_ops_vpto_llvm.pto | 68 ++ ...tor_address_update_seed_verify_invalid.pto | 35 + .../vector_address_vag_verify_invalid.pto | 21 + ...tor_address_vag_verify_invalid_counter.pto | 26 + .../vector-address/vald-vast/compare.py | 50 + .../vector-address/vald-vast/golden.py | 42 + .../vector-address/vald-vast/kernel.pto | 50 + .../vector-address/vald-vast/launch.cpp | 47 + .../vector-address/vald-vast/main.cpp | 99 ++ .../x2-predicate-unaligned/compare.py | 53 + .../x2-predicate-unaligned/golden.py | 49 + .../x2-predicate-unaligned/kernel.pto | 92 ++ .../x2-predicate-unaligned/launch.cpp | 35 + .../x2-predicate-unaligned/main.cpp | 100 ++ 20 files changed, 3319 insertions(+), 16 deletions(-) create mode 100644 docs/designs/vpto-vector-address-ops.md create mode 100644 docs/isa/micro-isa/18-vaddr-loop-memory.md create mode 100644 test/lit/vpto/vector_address_ops_vpto_llvm.pto create mode 100644 test/lit/vpto/vector_address_update_seed_verify_invalid.pto create mode 100644 test/lit/vpto/vector_address_vag_verify_invalid.pto create mode 100644 test/lit/vpto/vector_address_vag_verify_invalid_counter.pto create mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/compare.py create mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/golden.py create mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp create mode 100644 
test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp create mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/compare.py create mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py create mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto create mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp create mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp diff --git a/docs/designs/vpto-vector-address-ops.md b/docs/designs/vpto-vector-address-ops.md new file mode 100644 index 0000000000..17a7167973 --- /dev/null +++ b/docs/designs/vpto-vector-address-ops.md @@ -0,0 +1,1052 @@ +# VPTO vector-address op support design + +## 1. Goal + +This document defines the VPTO IR surface for VISA `VAG` and vector-address +load/store forms. The goal is to support the LLVM-path `vector_address` +intrinsic families without reusing overly generic op names such as `vld`, `vst`, +`pld`, or `pst`. + +Naming rule: + +- `va` means vector-addressed form, i.e. the memory offset operand is a + `vector_address`/address-register offset. +- `vald`, `vast`, `pald`, and `past` are the vector-addressed counterparts of + `vld`, `vst`, `pld`, and `pst`. +- Stateful unaligned forms return updated SSA values instead of relying on + source-level argument mutation. + +This document describes the VPTO vector-address support contract for this +change. + +## 2. Type model + +### 2.1 `!pto.vaddr` + +```mlir +!pto.vaddr<b8> +!pto.vaddr<b16> +!pto.vaddr<b32> +``` + +`!pto.vaddr` represents a CCE `vector_address` value. It is an offset token, +not a complete pointer. A consumer computes the effective UB address as: + +```text +effective_address = base + vaddr +``` + +The granularity parameter records the element-width family used to create the +address. It must be one of `b8`, `b16`, or `b32`. + +The value must not be treated as `!pto.ptr`. 
Pointer provenance and memory space +come from the explicit base operand of each memory op. + +The compiler ABI represents `vector_address` as `uint32_t +__attribute__((ext_vector_type(1)))`, so the VPTO LLVM conversion type is +`<1 x i32>`. + +### 2.2 Vector LLVM ABI suffixes + +Vector-address memory intrinsics are typed LLVM intrinsic families. The textual +LLVM IR name carries the selected vector ABI suffix: + +| VPTO element family | LLVM vector ABI | Intrinsic suffix | +| --- | --- | --- | +| `i8`/`ui8` | `<256 x i8>` | `v256i8` | +| `i16`/`ui16` | `<128 x i16>` | `v128i16` | +| `f16` | `<128 x half>` | `v128f16` | +| `bf16` | `<128 x bfloat>` | `v128bf16` | +| `i32`/`ui32` | `<64 x i32>` | `v64i32` | +| `f32` | `<64 x float>` | `v64f32` | +| `i64`/`ui64` | `<32 x i64>` | `v32i64` | + +For example, a `!pto.vreg<64xf32>` vector-address load lowers to +`@llvm.hivm.vldx1.v64f32`, not to a suffix-less `@llvm.hivm.vldx1`. + +### 2.3 Offset operands + +`pto.vag` operands are 32-bit unsigned byte strides. CCE source wrappers such as +`vag_b16` and `vag_b32` accept element strides and multiply them by the element +byte width before calling the compiler builtin. VPTO lowers directly to LLVM IR, +so its op boundary uses byte strides and does not repeat that source-level +scaling. + +### 2.4 Evidence model for LLVM lowering + +Do not infer an LLVM intrinsic signature from `strings` output. `strings +$ASCEND_HOME_PATH/bin/bisheng` only proves that an intrinsic name is present in +the installed compiler binary; it does not provide argument types, result type, +or operand order. + +This document uses three evidence levels: + +- Installed CANN Clang headers define source wrapper signatures, builtin names, + operand order, units, and fixed control constants such as `0 /* #loop */`. +- Generated LLVM IR or recovered AICore bitcode defines the final LLVM function + type for a concrete installed compiler. 
+- `strings bisheng` is only a name inventory used to check that the intrinsic + spelling exists in the installed compiler. + +The repo lit coverage prints VPTO-generated LLVM IR and checks lowering +signatures for the implemented forms. VAG is special: the VPTO direct LLVM input +uses the same no-IV builtin shape exposed by the CCE source wrapper, and the +call must remain inside a loop so later CCE middle-end passes can bind it to +loop state. + +### 2.5 Canonical signatures + +The table below is the implementation contract for the A5 VPTO LLVM path. `V` +means the LLVM vector ABI selected from Section 2.2, `S` means the matching +typed intrinsic suffix, and all pointers are UB pointers lowered to +`ptr addrspace(6)`. + +The VPTO column is intentionally written in the same assembly order as the ODS +`assemblyFormat`. Attributes shown inside `{...}` are named attributes. +Quoted operands such as `"D"` and `"POST_UPDATE"` are required string +attributes printed positionally by that op. + +| VPTO op signature | LLVM lowering signature | +| --- | --- | +| `%addr = pto.vag %s... 
: i32 -> !pto.vaddr` | `declare <1 x i32> @llvm.hivm.vag.32(i32, i32, i32, i32)` | +| `%v = pto.vald %base[%addr] {dist = "D"} : !pto.ptr, !pto.vaddr -> !pto.vreg` | `declare V @llvm.hivm.vldx1.S(ptr addrspace(6) nocapture readonly, <1 x i32>, i32, i32)` | +| `%lo, %hi = pto.valdx2 %base[%addr], "D" : !pto.ptr, !pto.vaddr -> !pto.vreg, !pto.vreg` | `declare { V, V } @llvm.hivm.vldx2.S(ptr addrspace(6) nocapture readonly, <1 x i32>, i32, i32)` | +| `pto.vast %v, %base[%addr], %mask {dist = "D"} : !pto.vreg, !pto.ptr, !pto.vaddr, !pto.mask` | `declare void @llvm.hivm.vstx1.S(V, ptr addrspace(6) nocapture writeonly, <1 x i32>, i32, i32, <256 x i1>)` | +| `pto.vastx2 %lo, %hi, %base[%addr], "D", %mask : !pto.vreg, !pto.vreg, !pto.ptr, !pto.vaddr, !pto.mask` | `declare void @llvm.hivm.vstx2.S(V, V, ptr addrspace(6) nocapture writeonly, <1 x i32>, i32, i32, <256 x i1>)` | +| `%mask = pto.pald %base[%addr], "D" : !pto.ptr, !pto.vaddr -> !pto.mask` | `declare <256 x i1> @llvm.hivm.pld.b8(ptr addrspace(6) nocapture readonly, <1 x i32>, i32, i32)` | +| `pto.past %mask, %base[%addr], "D" : !pto.mask, !pto.ptr, !pto.vaddr` | `declare void @llvm.hivm.pst.b8(<256 x i1>, ptr addrspace(6) nocapture writeonly, <1 x i32>, i32, i32)` | +| `%align = pto.valda %base[%addr] : !pto.ptr, !pto.vaddr -> !pto.align` | `declare <32 x i8> @llvm.hivm.vlda(ptr addrspace(6), <1 x i32>, i32)` | +| `%v, %align1, %addr1 = pto.valdu %base[%addr0], %align0, %inc : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg, !pto.align, !pto.vaddr` | `declare { V, <32 x i8>, <1 x i32> } @llvm.hivm.vldu.v300.S(ptr addrspace(6) nocapture readonly, <1 x i32>, <32 x i8>, i32, i32)` | +| `pto.vasta %align, %base[%addr] : !pto.align, !pto.ptr, !pto.vaddr` | `declare void @llvm.hivm.vsta(<32 x i8>, ptr addrspace(6) nocapture writeonly, <1 x i32>, i32)` | +| `%align1, %addr1 = pto.vastu %align0, %addr0, %v, %base, "POST_UPDATE" : !pto.align, !pto.vaddr, !pto.vreg, !pto.ptr -> !pto.align, !pto.vaddr` | `declare { <32 x 
i8>, <1 x i32> } @llvm.hivm.vstu.S(V, ptr addrspace(6) nocapture writeonly, <1 x i32>, <32 x i8>, i32, i32)` | + +All LLVM lowering definitions below use the same declaration shapes. Aggregate +returns such as `{ V, V }` may appear as `!llvm.struct<(...)>` in MLIR LLVM +dialect dumps; that is the same final LLVM IR aggregate-return declaration. The +final control operand named loop mode is currently emitted as `i32 0`. + +### 2.6 Immediate tokens used by these signatures + +The current A5 lowering accepts the following string tokens and emits these +integer immediates. These are part of the lowering contract for this change. + +| VPTO attribute | Legal tokens and emitted `i32` values | +| --- | --- | +| `pto.vald` `dist` | omitted/`"NORM"` -> `0`; `"BRC_B8"` -> `1`; `"BRC_B16"` -> `2`; `"BRC_B32"` -> `3`; `"US_B8"` -> `6`; `"US_B16"` -> `7`; `"DS_B8"` -> `8`; `"DS_B16"` -> `9`; `"UNPK_B8"` -> `13`; `"UNPK_B16"` -> `14`; `"UNPK_B32"` -> `18`; `"BRC_BLK"` -> `15`; `"E2B_B16"` -> `16`; `"E2B_B32"` -> `17`; `"UNPK4"` -> `20` for b8 element vectors only; `"SPLT4CHN"` -> `21` for b8 element vectors only; `"SPLT2CHN_B8"` -> `22`; `"SPLT2CHN_B16"` -> `23` | +| `pto.valdx2` `"D"` | `"BDINTLV"` -> `10`; `"DINTLV_B8"` -> `11`; `"DINTLV_B16"` -> `12`; `"DINTLV_B32"` -> `19` | +| `pto.vast` `dist` | omitted -> element-width default (`b8` -> `0`, `b16` -> `1`, `b32` -> `2`); `"NORM_B8"` -> `0`; `"NORM_B16"` -> `1`; `"NORM_B32"` -> `2`; `"1PT_B8"` -> `3`; `"1PT_B16"` -> `4`; `"1PT_B32"` -> `5`; `"PK_B16"` -> `6`; `"PK_B32"` -> `7`; `"PK_B64"` -> `10`; `"PK4_B32"` -> `12`; `"MRG4CHN_B8"` -> `13`; `"MRG2CHN_B8"` -> `14`; `"MRG2CHN_B16"` -> `15` | +| `pto.vastx2` `"D"` | `"INTLV_B8"` -> `8`; `"INTLV_B16"` -> `9`; `"INTLV_B32"` -> `11` | +| `pto.pald` `"D"` | `"NORM"` -> `0`; `"US"` -> `1`; `"DS"` -> `2` | +| `pto.past` `"D"` | `"NORM"` -> `0`; `"PK"` -> `1` | +| `pto.vastu` `"MODE"` | `"POST_UPDATE"` -> `1` | + +### 2.7 Implementation signature ledger + +This section is the 
implementation checklist for the current VPTO lowering. +Every operand is listed in VPTO assembly order first, then in the LLVM call +order emitted by `VPTOCANN900LLVMEmitter`. The LLVM snippets in this section +use aliases for compactness instead of literal copy-paste LLVM IR. `V` means +the payload LLVM vector ABI from Section 2.2, `A` means `<1 x i32>` +vector-address ABI, `P` means `ptr addrspace(6)`, `M` means `<256 x i1>`, and +`L` means `<32 x i8>`. + +#### 2.7.1 `pto.vag` + +VPTO: + +```text +%addr = pto.vag %s0 : i32 -> !pto.vaddr +``` + +Current A5 LLVM lowering: + +```text +%addr = call A @llvm.hivm.vag.32(i32 %s0, i32 0, i32 0, i32 0) +``` + +Declaration: + +```llvm +declare <1 x i32> @llvm.hivm.vag.32(i32, i32, i32, i32) +``` + +Inactive dimensions are emitted as `i32 0`. `pto.vag` must be nested under an +`i16` `scf.for`; VPTO does not synthesize that loop. + +#### 2.7.2 `pto.vald` + +VPTO: + +```text +%value = pto.vald %base[%addr] {dist = "DIST"} + : !pto.ptr, !pto.vaddr -> !pto.vreg +``` + +LLVM: + +```text +declare V @llvm.hivm.vldx1.S(P readonly, A, i32 /*dist*/, i32 /*loop*/) +``` + +Emitted call operands are `%base, %addr, dist_code, i32 0`. + +#### 2.7.3 `pto.valdx2` + +VPTO: + +```text +%lo, %hi = pto.valdx2 %base[%addr], "DIST" + : !pto.ptr, !pto.vaddr -> !pto.vreg, !pto.vreg +``` + +LLVM: + +```text +declare { V, V } @llvm.hivm.vldx2.S( + P readonly, A, i32 /*dist*/, i32 /*loop*/) +``` + +Emitted call operands are `%base, %addr, dist_code, i32 0`. Aggregate result +index 0 maps to `%lo`; index 1 maps to `%hi`. + +#### 2.7.4 `pto.vast` + +VPTO: + +```text +pto.vast %value, %base[%addr], %mask {dist = "DIST"} + : !pto.vreg, !pto.ptr, !pto.vaddr, !pto.mask +``` + +LLVM: + +```text +declare void @llvm.hivm.vstx1.S( + V, P writeonly, A, i32 /*dist*/, i32 /*loop*/, M) +``` + +Emitted call operands are `%value, %base, %addr, dist_code, i32 0, %mask`. 
+ +#### 2.7.5 `pto.vastx2` + +VPTO: + +```text +pto.vastx2 %lo, %hi, %base[%addr], "DIST", %mask + : !pto.vreg, !pto.vreg, !pto.ptr, + !pto.vaddr, !pto.mask +``` + +LLVM: + +```text +declare void @llvm.hivm.vstx2.S( + V, V, P writeonly, A, i32 /*dist*/, i32 /*loop*/, M) +``` + +Emitted call operands are `%lo, %hi, %base, %addr, dist_code, i32 0, %mask`. + +#### 2.7.6 `pto.pald` + +VPTO: + +```text +%mask = pto.pald %base[%addr], "DIST" + : !pto.ptr, !pto.vaddr -> !pto.mask +``` + +LLVM: + +```text +declare M @llvm.hivm.pld.b8(P readonly, A, i32 /*dist*/, i32 /*loop*/) +``` + +Emitted call operands are `%base, %addr, dist_code, i32 0`. + +#### 2.7.7 `pto.past` + +VPTO: + +```text +pto.past %mask, %base[%addr], "DIST" + : !pto.mask, !pto.ptr, !pto.vaddr +``` + +LLVM: + +```text +declare void @llvm.hivm.pst.b8(M, P writeonly, A, i32 /*dist*/, i32 /*loop*/) +``` + +Emitted call operands are `%mask, %base, %addr, dist_code, i32 0`. + +#### 2.7.8 `pto.valda` + +VPTO: + +```text +%align = pto.valda %base[%addr] + : !pto.ptr, !pto.vaddr -> !pto.align +``` + +LLVM: + +```text +declare L @llvm.hivm.vlda(P, A, i32 /*loop*/) +``` + +Emitted call operands are `%base, %addr, i32 0`. + +#### 2.7.9 `pto.valdu` + +VPTO: + +```text +%value, %align_out, %addr_out = + pto.valdu %base[%addr_in], %align_in, %inc + : !pto.ptr, !pto.vaddr, !pto.align, i32 + -> !pto.vreg, !pto.align, !pto.vaddr +``` + +LLVM: + +```text +declare { V, L, A } @llvm.hivm.vldu.v300.S( + P readonly, A, L, i32 /*inc*/, i32 /*loop*/) +``` + +Emitted call operands are `%base, %addr_in, %align_in, %inc, i32 0`. +Aggregate result indexes map to `%value`, `%align_out`, `%addr_out` in that +order. + +#### 2.7.10 `pto.vasta` + +VPTO: + +```text +pto.vasta %align, %base[%addr] + : !pto.align, !pto.ptr, !pto.vaddr +``` + +LLVM: + +```text +declare void @llvm.hivm.vsta(L, P writeonly, A, i32 /*loop*/) +``` + +Emitted call operands are `%align, %base, %addr, i32 0`. 
+ +#### 2.7.11 `pto.vastu` + +VPTO: + +```text +%align_out, %addr_out = + pto.vastu %align_in, %addr_in, %value, %base, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg, !pto.ptr + -> !pto.align, !pto.vaddr +``` + +LLVM: + +```text +declare { L, A } @llvm.hivm.vstu.S( + V, P writeonly, A, L, i32 /*mode*/, i32 /*loop*/) +``` + +Emitted call operands are `%value, %base, %addr_in, %align_in, i32 1, i32 0`. +Aggregate result index 0 maps to `%align_out`; index 1 maps to `%addr_out`. + +## 3. Address generation + +### 3.1 `pto.vag` + +```mlir +%addr = pto.vag %s1 : i32 -> !pto.vaddr +``` + +Semantics: + +```text +addr = s1 * i1 +``` + +The stride is in bytes. VISA supports up to four VAG stride registers for +nested loop layers, but this VPTO implementation currently exposes only the +one-stride form until nested vector-loop IVs are represented in lowering state. + +Verifier constraints: + +- Operand count must be exactly 1. +- The operand must be `i32`. +- Result type must be `!pto.vaddr<b8>`, `!pto.vaddr<b16>`, or + `!pto.vaddr<b32>`. +- VPTO does not currently enforce a source-position restriction. The intended + use is to define the address pattern once for the vector scope and not mutate + it differently across dynamic vector-loop iterations. + +LLVM lowering: + +| Result type | LLVM intrinsic family | +| --- | --- | +| `!pto.vaddr<b8>` | `llvm.hivm.vag.32` on A5 32-bit VAG ABI targets | +| `!pto.vaddr<b16>` | `llvm.hivm.vag.32` on A5 32-bit VAG ABI targets | +| `!pto.vaddr<b32>` | `llvm.hivm.vag.32` on A5 32-bit VAG ABI targets | + +The installed CANN 9.0.0 source wrapper accepts element strides and calls the +compiler builtin in reverse byte-stride order: + +```llvm +%addr = call <1 x i32> @llvm.hivm.vag.32( + i32 %s4, i32 %s3, i32 %s2, i32 %s1) + +declare <1 x i32> @llvm.hivm.vag.32(i32, i32, i32, i32) +``` + +The VPTO LLVM path emits this no-IV form directly. 
The call must be nested under +an `i16` `scf.for` in VPTO IR so CCE middle-end VAG lowering can associate it +with loop state before object generation. In MLIR assembly, non-`index` +`scf.for` loops require an explicit loop type marker, for example: + +```mlir +%c0_i16 = arith.constant 0 : i16 +%c1_i16 = arith.constant 1 : i16 +%c2_i16 = arith.constant 2 : i16 +scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %stride : i32 -> !pto.vaddr +} +``` + +Normal vector-address load/store operations may share one `!pto.vaddr` value. +Update forms that return a next vector address, currently `pto.valdu` and +`pto.vastu`, are different: one `!pto.vaddr` value must not seed multiple +update chains. Generate a separate `pto.vag` for independent update chains. + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%s1` | byte stride for the innermost active loop layer | +| arg1 | `%s2` | byte stride for the next loop layer, or `i32 0` | +| arg2 | `%s3` | byte stride for the next loop layer, or `i32 0` | +| arg3 | `%s4` | byte stride for the next loop layer, or `i32 0` | + +Examples: + +```llvm +; pto.vag %s1 +call <1 x i32> @llvm.hivm.vag.32(i32 %s1, i32 0, i32 0, i32 0) + +``` + +The installed CANN 9.0.0 `bisheng` name inventory contains these VAG intrinsic +spellings: + +```text +llvm.hivm.vag.16 +llvm.hivm.vag.32 +llvm.hivm.vag.iv.16 +llvm.hivm.vag.iv.16.se +llvm.hivm.vag.iv.32 +llvm.hivm.vag.iv.32.se +llvm.hivm.vag.v210 +``` + +For non-A5 targets that still use 16-bit or v210 VAG ABI, the target lowering +may select a different builtin family, but the VPTO op contract remains +byte-stride based. Do not lower `pto.vag` through source-level wrapper emission. + +## 4. 
Aligned vector load/store + +### 4.1 `pto.vald` + +```mlir +%result = pto.vald %base[%addr] {dist = "DIST"} + : !pto.ptr, !pto.vaddr -> !pto.vreg +``` + +Semantics: + +```text +result = load_vector(base + addr, dist) +``` + +Verifier constraints: + +- `%base` must be a UB pointer or UB memref lowered to a VPTO pointer. +- `%addr` must be `!pto.vaddr<G>`. +- `G` must match the element or distribution granularity required by `DIST`. +- `DIST` follows the existing `pto.vlds` distribution vocabulary. + +LLVM lowering: + +```llvm +%result = call <64 x float> @llvm.hivm.vldx1.v64f32( + ptr addrspace(6) %base, <1 x i32> %addr, i32 %dist, i32 0) + +declare <64 x float> @llvm.hivm.vldx1.v64f32( + ptr addrspace(6) nocapture readonly, <1 x i32>, i32, i32) +``` + +Select the intrinsic suffix from the result vector ABI in Section 2.2. The +installed CANN 9.0.0 LLVM evidence also includes `v32i64`, `v64i32`, +`v128bf16`, `v128f16`, `v128i16`, and `v256i8` overloads with the same operand +list. + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg1 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg2 | `DIST` | existing `vlds` distribution enum code, passed as `i32` | +| arg3 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +### 4.2 `pto.valdx2` + +```mlir +%low, %high = pto.valdx2 %base[%addr], "DIST" + : !pto.ptr, !pto.vaddr -> !pto.vreg, !pto.vreg +``` + +Semantics: + +```text +(low, high) = load_vector_pair(base + addr, dist) +``` + +Verifier constraints: + +- Same base/address constraints as `pto.vald`. +- `DIST` must be a distribution that produces two destination vector registers, + such as de-interleave/block-deinterleave forms already accepted by + `pto.vldsx2`. 
+ +LLVM lowering: + +```llvm +%pair = call { <64 x float>, <64 x float> } @llvm.hivm.vldx2.v64f32( + ptr addrspace(6) %base, <1 x i32> %addr, i32 %dist, i32 0) + +declare { <64 x float>, <64 x float> } @llvm.hivm.vldx2.v64f32( + ptr addrspace(6) nocapture readonly, <1 x i32>, i32, i32) +``` + +Select the intrinsic suffix from the result vector ABI in Section 2.2. The +`v64f32` declaration above matches recovered Bisheng LLVM IR as +`@llvm.hivm.vldx2.v64f32`. + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg1 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg2 | `DIST` | existing pair-load distribution enum code, passed as `i32` | +| arg3 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +Return extraction: + +| Aggregate index | VPTO result | Value | +| --- | --- | --- | +| 0 | `%low` | first loaded vector | +| 1 | `%high` | second loaded vector | + +The installed CANN wrapper stores the builtin result in `vector_x2_t` and +then assigns `dst0 = ret.val[0]`, `dst1 = ret.val[1]`; the VPTO lowering +preserves that order. + +### 4.3 `pto.vast` + +```mlir +pto.vast %value, %base[%addr], %mask {dist = "DIST"} + : !pto.vreg, !pto.ptr, !pto.vaddr, !pto.mask +``` + +Semantics: + +```text +store_vector(base + addr, value, mask, dist) +``` + +Verifier constraints: + +- `%base` must be a UB pointer or UB memref lowered to a VPTO pointer. +- `%addr` must be `!pto.vaddr`. +- `%mask` granularity must be legal for `%value` and `DIST`. +- `DIST` follows the existing `pto.vsts` distribution vocabulary. 
+ +LLVM lowering: + +```llvm +call void @llvm.hivm.vstx1.v64f32( + <64 x float> %value, ptr addrspace(6) %base, <1 x i32> %addr, + i32 %dist, i32 0, <256 x i1> %mask) + +declare void @llvm.hivm.vstx1.v64f32( + <64 x float>, ptr addrspace(6) nocapture writeonly, + <1 x i32>, i32, i32, <256 x i1>) +``` + +Select the intrinsic suffix from the source vector ABI in Section 2.2. The +installed CANN 9.0.0 LLVM evidence also includes `v32i64`, `v64i32`, +`v128bf16`, `v128f16`, `v128i16`, and `v256i8` overloads with the same operand +list. + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%value` | vector value converted to the selected `V` LLVM vector ABI | +| arg1 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg2 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg3 | `DIST` | existing `vsts` distribution enum code, passed as `i32` | +| arg4 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | +| arg5 | `%mask` | predicate converted to `<256 x i1>` | + +### 4.4 `pto.vastx2` + +```mlir +pto.vastx2 %low, %high, %base[%addr], "DIST", %mask + : !pto.vreg, !pto.vreg, !pto.ptr, + !pto.vaddr, !pto.mask +``` + +Semantics: + +```text +store_vector_pair(base + addr, low, high, mask, dist) +``` + +Verifier constraints: + +- Same base/address/mask constraints as `pto.vast`. +- `DIST` must be a distribution that consumes two source vector registers, + matching the existing `pto.vstsx2` distribution set. + +LLVM lowering: + +```llvm +call void @llvm.hivm.vstx2.v64i32( + <64 x i32> %low, <64 x i32> %high, ptr addrspace(6) %base, + <1 x i32> %addr, i32 %dist, i32 0, <256 x i1> %mask) + +declare void @llvm.hivm.vstx2.v64i32( + <64 x i32>, <64 x i32>, ptr addrspace(6) nocapture writeonly, + <1 x i32>, i32, i32, <256 x i1>) +``` + +Select the intrinsic suffix from the source vector ABI in Section 2.2, but only +for overloads accepted by the installed compiler. 
CANN 9.0.0 LLVM evidence +captures `v64i32`, `v128bf16`, `v128f16`, `v128i16`, and `v256i8` `vstx2` +overloads; it does not expose a `v64f32` store-pair overload even though the +source wrapper inventory contains a `vldx2_v64f32` load-pair builtin. + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%low` | first source vector converted to the selected `V` LLVM vector ABI | +| arg1 | `%high` | second source vector converted to the selected `V` LLVM vector ABI | +| arg2 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg3 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg4 | `DIST` | existing pair-store distribution enum code, passed as `i32` | +| arg5 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | +| arg6 | `%mask` | predicate converted to `<256 x i1>` | + +## 5. Predicate load/store + +### 5.1 `pto.pald` + +```mlir +%mask = pto.pald %base[%addr], "DIST" + : !pto.ptr, !pto.vaddr -> !pto.mask +``` + +Semantics: + +```text +mask = load_predicate(base + addr, dist) +``` + +Verifier constraints: + +- `%base` must be a UB pointer. The CCE wrapper uses `__ubuf__ uint32_t *`. +- `%addr` must be `!pto.vaddr`. +- `DIST` must be one of the predicate load distributions supported by the + existing predicate load surface. 
+ +LLVM lowering: + +```llvm +%mask = call <256 x i1> @llvm.hivm.pld.b8( + ptr addrspace(6) %base, <1 x i32> %addr, i32 %dist, i32 0) + +declare <256 x i1> @llvm.hivm.pld.b8( + ptr addrspace(6) nocapture readonly, <1 x i32>, i32, i32) +``` + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%base` | UB `uint32_t` pointer converted to `ptr addrspace(6)` | +| arg1 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg2 | `DIST` | predicate-load distribution enum code, passed as `i32` | +| arg3 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +The declaration above is captured in the local CANN 9.0.0 LLVM evidence. + +### 5.2 `pto.past` + +```mlir +pto.past %mask, %base[%addr], "DIST" + : !pto.mask, !pto.ptr, !pto.vaddr +``` + +Semantics: + +```text +store_predicate(base + addr, mask, dist) +``` + +Verifier constraints: + +- `%base` must be a UB pointer. The CCE wrapper uses `__ubuf__ uint32_t *`. +- `%addr` must be `!pto.vaddr`. +- `DIST` must be one of the predicate store distributions supported by the + existing predicate store surface. + +LLVM lowering: + +```llvm +call void @llvm.hivm.pst.b8( + <256 x i1> %mask, ptr addrspace(6) %base, <1 x i32> %addr, + i32 %dist, i32 0) + +declare void @llvm.hivm.pst.b8( + <256 x i1>, ptr addrspace(6) nocapture writeonly, + <1 x i32>, i32, i32) +``` + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%mask` | predicate converted to `<256 x i1>` | +| arg1 | `%base` | UB `uint32_t` pointer converted to `ptr addrspace(6)` | +| arg2 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg3 | `DIST` | predicate-store distribution enum code, passed as `i32` | +| arg4 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +The installed CANN 9.0.0 Clang header maps the source-level vector-address form +to `__builtin_cce_pst_b8(src, base, offset, dist, 0 /* #loop */)`. 
The +declaration above follows the same argument order and UB base memory effect. + +## 6. Unaligned vector load/store + +### 6.1 `pto.valda` + +```mlir +%align = pto.valda %base[%addr] + : !pto.ptr, !pto.vaddr -> !pto.align +``` + +Semantics: + +```text +align = init_load_alignment(base + addr) +``` + +LLVM lowering: + +```llvm +%align = call <32 x i8> @llvm.hivm.vlda( + ptr addrspace(6) %base, <1 x i32> %addr, i32 0) + +declare <32 x i8> @llvm.hivm.vlda(ptr addrspace(6), <1 x i32>, i32) +``` + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg1 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg2 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +Observed LLVM intrinsic family: + +```text +llvm.hivm.vlda +``` + +The declaration above is captured in the local CANN 9.0.0 LLVM evidence. + +### 6.2 `pto.valdu` + +```mlir +%value, %align_out, %addr_out = pto.valdu %base[%addr_in], + %align_in, %inc + : !pto.ptr, !pto.vaddr, !pto.align, i32 + -> !pto.vreg, !pto.align, !pto.vaddr +``` + +Semantics: + +```text +(value, align_out, addr_out) = + unaligned_load(base, addr_in, align_in, inc) +``` + +`inc` is a byte increment. `addr_out` represents the post-updated +`vector_address` value. + +Verifier constraints: + +- `%base` must be a UB pointer. +- `%addr_in` and `%addr_out` must have the same `!pto.vaddr` type. +- `%align_in` must be `!pto.align`. +- `%inc` must be `i32`. + +LLVM lowering: + +```llvm +%triple = call { <64 x float>, <32 x i8>, <1 x i32> } + @llvm.hivm.vldu.v300.v64f32( + ptr addrspace(6) %base, <1 x i32> %addr_in, + <32 x i8> %align_in, i32 %inc, i32 0) + +declare { <64 x float>, <32 x i8>, <1 x i32> } + @llvm.hivm.vldu.v300.v64f32( + ptr addrspace(6) nocapture readonly, <1 x i32>, + <32 x i8>, i32, i32) +``` + +Select the intrinsic suffix from the loaded vector ABI in Section 2.2. 
The +installed CANN 9.0.0 LLVM evidence also includes `v32i64`, `v64i32`, +`v128bf16`, `v128f16`, `v128i16`, and `v256i8` overloads with the same operand +list and aggregate return layout. + +The lowering extracts the loaded vector, updated alignment value, and updated +vector-address value from the returned aggregate. + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg1 | `%addr_in` | input `!pto.vaddr` converted to `<1 x i32>` | +| arg2 | `%align_in` | input `!pto.align` converted to `<32 x i8>` | +| arg3 | `%inc` | byte increment, passed as `i32` | +| arg4 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +Return extraction: + +| Aggregate index | VPTO result | Value | +| --- | --- | --- | +| 0 | `%value` | loaded vector | +| 1 | `%align_out` | updated load-alignment state | +| 2 | `%addr_out` | post-updated vector-address offset token | + +### 6.3 `pto.vasta` + +```mlir +pto.vasta %align, %base[%addr] + : !pto.align, !pto.ptr, !pto.vaddr +``` + +Semantics: + +```text +flush_store_alignment(base + addr, align) +``` + +LLVM lowering: + +```llvm +call void @llvm.hivm.vsta( + <32 x i8> %align, ptr addrspace(6) %base, + <1 x i32> %addr, i32 0) + +declare void @llvm.hivm.vsta( + <32 x i8>, ptr addrspace(6) nocapture writeonly, <1 x i32>, i32) +``` + +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%align` | store-alignment state converted to `<32 x i8>` | +| arg1 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg2 | `%addr` | `!pto.vaddr` converted to `<1 x i32>` | +| arg3 | implicit loop mode | `i32 0`, matching `0 /* #loop */` in the CANN wrapper | + +### 6.4 `pto.vastu` + +```mlir +%align_out, %addr_out = pto.vastu %align_in, %addr_in, %value, + %base, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg, !pto.ptr + -> !pto.align, !pto.vaddr +``` + +Semantics: + +```text 
+(align_out, addr_out) = + unaligned_store_post_update(base, addr_in, align_in, value) +``` + +`pto.vastu` is the vector-address stateful unaligned store. Its address state +is the `!pto.vaddr` offset token, not a post-updated base pointer. The A5 +CANN 9.0.0 wrapper accepts only `POST_UPDATE` for this vector-address form, so +the VPTO op exposes only `"POST_UPDATE"` rather than an arbitrary integer post +amount. A no-post scalar-offset store uses the scalar-offset stateful store +family instead. + +Verifier constraints: + +- `%base` must be a UB pointer. +- `%addr_in` and `%addr_out` must have the same `!pto.vaddr` type. +- `%align_in` and `%align_out` must be `!pto.align`. +- The mode attribute must be `"POST_UPDATE"` on A5. + +LLVM lowering: + +```llvm +%pair = call { <32 x i8>, <1 x i32> } @llvm.hivm.vstu.v64f32( + <64 x float> %value, ptr addrspace(6) %base, <1 x i32> %addr_in, + <32 x i8> %align_in, i32 1, i32 0) +%align_out = extractvalue { <32 x i8>, <1 x i32> } %pair, 0 +%addr_out = extractvalue { <32 x i8>, <1 x i32> } %pair, 1 + +declare { <32 x i8>, <1 x i32> } @llvm.hivm.vstu.v64f32( + <64 x float>, ptr addrspace(6) nocapture writeonly, + <1 x i32>, <32 x i8>, i32, i32) +``` + +Select the intrinsic suffix from the source vector ABI in Section 2.2. The +installed compiler canonicalizes a handwritten suffix-less `@llvm.hivm.vstu` +declaration to the typed textual IR name `@llvm.hivm.vstu.v64f32` for the f32 +case. 
+ +Lowering operands: + +| LLVM operand | VPTO source | Value | +| --- | --- | --- | +| arg0 | `%value` | source vector converted to the selected LLVM vector ABI (for example `<64 x float>` for f32) | +| arg1 | `%base` | UB pointer converted to `ptr addrspace(6)` | +| arg2 | `%addr_in` | input `!pto.vaddr` converted to `<1 x i32>` | +| arg3 | `%align_in` | input `!pto.align` converted to `<32 x i8>` | +| arg4 | `"POST_UPDATE"` | `i32 1`, matching `1 /*post update mode*/` in the CANN wrapper | +| arg5 | implicit loop mode | `i32 0`, matching `0 /*loop*/` in the CANN wrapper | + +Return extraction: + +| Aggregate index | VPTO result | Value | +| --- | --- | --- | +| 0 | `%align_out` | updated store-alignment state | +| 1 | `%addr_out` | post-updated vector-address offset token | + +The operand order and result aggregate are taken from the installed CANN 9.0.0 +Clang wrapper for `vstu`: it creates a return object with `alignData` followed +by `offset`, then calls `__builtin_cce_vstu_(&ret, src, base, offset, +alignData, 1 /*post update mode*/, 0 /*loop*/)`. The declaration above follows +the same aggregate extraction order. + +The local `pto-isa` copy under +`/home/mouliangyu/projects/github.com/hw-native-sys/pypto/build_output/_deps/pto-isa` +documents the public `pto.vstu` surface as an align-plus-offset state update: +`%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, +"MODE"`. This matches the VPTO `!pto.vaddr` result contract and confirms +that `vastu` threads offset state, not a full pointer. + +## 7. Excluded scalar-offset store forms + +### 7.1 `VSSTB` + +`VSSTB` is not included in this vector-address support set. 
VISA lists it as a +deprecated scalar-addressing instruction: + +```text +VSSTB.type Vd, [Sn], Sm, Pg, #p +``` + +The installed CANN 9.0.0 Clang wrapper also takes a scalar offset, not a +`vector_address` value: + +```cpp +void vsstb(vector_ST data, __ubuf__ LT *base, int32_t offset, + vector_bool mask) +``` + +The compiler-facing LLVM family follows the same scalar-offset contract: + +```llvm +declare void @llvm.hivm.vsstb.S(V, ptr addrspace(6), i32, i32, <256 x i1>) +``` + +An attempted `pto.vasst %base[%vaddr]` wrapper would have to extract a scalar +lane from the `<1 x i32>` vector-address ABI value inside the SIMD-VF loop. +Bisheng rejects that shape during AIV object generation with "Unsupported +scalar instruction in AIV loop". Use the existing scalar-offset `pto.vsstb` +operation for this family instead of adding a vector-address alias. + +No matching A5 `vsld` vector-address wrapper was found in the local CANN 9.0.0 +inventory, so a `pto.vasld` op is intentionally not included either. + +## 8. Lowering ownership + +The VPTO implementation lowers these ops in the vector LLVM emitter. The +normative target is direct LLVM IR, not source-level wrapper emission: + +- Use the installed CANN headers as the source of truth for wrapper semantics, + operand order, units, and fixed constants. +- Use generated LLVM IR from the current toolchain as the source of truth for + exact `llvm.hivm.*` function types. Treat `strings bisheng` only as an + intrinsic-name inventory. +- Reuse the existing VPTO pointer, vector, mask, and align type conversion + helpers where possible. +- Add a dedicated conversion for `!pto.vaddr` to the compiler ABI used for + `vector_address`; local probes show the vector-address load/store families + accepting a `<1 x i32>` operand. +- Do not lower `!pto.vaddr` through pointer arithmetic or `pto.addptr`. 
+ +Current implementation status: + +- `vald`, `valdx2`, `vast`, `vastx2`, `pald`, `past`, `valda`, `valdu`, + `vasta`, and `vastu` lower directly to the typed LLVM intrinsic signatures + listed in Section 2.5. +- `vag` lowers to the no-IV `@llvm.hivm.vag.32(i32, i32, i32, i32)` builtin + form. It must be written under an `i16` `scf.for` so CCE middle-end VAG + lowering can associate the builtin with loop state. +- A `!pto.vaddr` value may be shared by normal non-update vector-address + operations, but must not be used as the `addr_in` seed of multiple update + chains such as `valdu` and `vastu`. +- The current VPTO lowering supports one to four byte-stride operands and pads + inactive dimensions with `i32 0`. + +## 9. Initial implementation checklist + +1. Add `VAddrType` with `b8`/`b16`/`b32` granularity. +2. Add ODS definitions for: + - `pto.vag` + - `pto.vald` + - `pto.valdx2` + - `pto.vast` + - `pto.vastx2` + - `pto.pald` + - `pto.past` + - `pto.valda` + - `pto.valdu` + - `pto.vasta` + - `pto.vastu` +3. Add verifiers for vaddr granularity, UB pointer spaces, distribution + legality, mask compatibility, and stateful result type equality. +4. Add lowering patterns and targeted lit tests that check emitted intrinsic + calls or wrapper-equivalent LLVM IR. +5. Add end-to-end micro-op coverage as runtime-safe cases become available. + The initial simulator cases cover `vag`, aligned `vald`/`vast`, + pair `valdx2`/`vastx2`, predicate `pald`/`past`, and stateful unaligned + `valda`/`valdu`/`vastu`/`vasta`. diff --git a/docs/isa/micro-isa/18-vaddr-loop-memory.md b/docs/isa/micro-isa/18-vaddr-loop-memory.md new file mode 100644 index 0000000000..d7b2c292c0 --- /dev/null +++ b/docs/isa/micro-isa/18-vaddr-loop-memory.md @@ -0,0 +1,372 @@ +# 18. 
Vector-Address Loop Memory + +> **Category:** loop-derived UB ↔ Vector/Predicate data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector-address memory ops use a loop-derived address token together with an +explicit UB base pointer. They are useful when a vector loop needs address +progression that is tied to the loop counter rather than to scalar pointer +arithmetic. + +This chapter documents the PTO surface contract for `!pto.vaddr` and the +`pto.vag` / `pto.vald` / `pto.vast` / `pto.pald` / `pto.past` families. + +--- + +## Common Operand Model + +- `%base` is the explicit UB base pointer. It MUST have type `!pto.ptr`. +- `%addr` is a `!pto.vaddr` vector-address offset token. It is not a pointer. +- The effective UB address is logically `base + addr`. +- `%mask` is a predicate mask used by predicated vector stores. +- `!pto.align` is the explicit SSA carrier for unaligned load/store state. + +`!pto.vaddr` supports these granularities: + +```mlir +!pto.vaddr<b8> +!pto.vaddr<b16> +!pto.vaddr<b32> +``` + +The granularity records the address family used by the vector-address producer. +It does not replace the element type of the explicit `%base` pointer. + +--- + +## Loop Form + +`pto.vag` MUST be nested under an `i16` `scf.for` loop. The loop type is part of +the valid vector-address form. + +Use the explicit non-`index` loop type marker: + +```mlir +%c0_i16 = arith.constant 0 : i16 +%c1_i16 = arith.constant 1 : i16 +%c2_i16 = arith.constant 2 : i16 + +pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %stride : i32 -> !pto.vaddr<b32> + // vector-address memory ops using %addr + } +} +``` + +An `index` loop is not valid for `pto.vag`: + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%c2 = arith.constant 2 : index + +// Invalid for pto.vag. +scf.for %i = %c0 to %c2 step %c1 { + %addr = pto.vag %stride : i32 -> !pto.vaddr<b32> +} +``` + +`pto.vag` does not create or outline a loop. 
The surrounding loop must already +exist in the PTO IR. + +--- + +## Address Generation + +### `pto.vag` + +- **syntax:** `%addr = pto.vag %s0 : i32 -> !pto.vaddr` +- **syntax:** `%addr = pto.vag %s0, %s1 : i32, i32 -> !pto.vaddr` +- **syntax:** `%addr = pto.vag %s0, %s1, %s2 : i32, i32, i32 -> !pto.vaddr` +- **syntax:** `%addr = pto.vag %s0, %s1, %s2, %s3 : i32, i32, i32, i32 -> !pto.vaddr` +- **semantics:** Create a vector-address offset value for the surrounding loop. +- **inputs:** + `%s0` ... `%s3` are byte strides for active loop dimensions. +- **outputs:** + `%addr` is a `!pto.vaddr` offset token. +- **constraints and limitations:** + `pto.vag` takes one to four `i32` byte-stride operands. It MUST be nested in + an `i16` `scf.for`. The result granularity MUST be `b8`, `b16`, or `b32`. + +**Example:** + +```mlir +%stride = arith.constant 4 : i32 +%addr = pto.vag %stride : i32 -> !pto.vaddr +``` + +--- + +## Vector-Address Vector Loads + +### `pto.vald` + +- **syntax:** `%result = pto.vald %source[%addr] {dist = "DIST"} : !pto.ptr, !pto.vaddr -> !pto.vreg` +- **semantics:** Load a vector register from UB using a vector-address offset. +- **inputs:** + `%source` is the UB base pointer, `%addr` is the vector-address offset, and + `DIST` selects the load distribution mode. +- **outputs:** + `%result` is the loaded vector register. +- **constraints and limitations:** + `%source` MUST be UB-backed. `%addr` MUST be `!pto.vaddr<...>`. The result + element type MUST match the source pointer element type. 
+ +Supported `DIST` values: + +| Family | Notes | +|--------|-------| +| `NORM` | normal vector-address load | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | broadcast family | +| `US_B8` / `US_B16` | upsample family | +| `DS_B8` / `DS_B16` | downsample family | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | unpack family | +| `BRC_BLK` | block broadcast path | +| `E2B_B16` / `E2B_B32` | element-to-byte expansion family | +| `UNPK4` | `b8` unpack-4 family | +| `SPLT4CHN` | `b8` split-channel family | +| `SPLT2CHN_B8` / `SPLT2CHN_B16` | 2-channel split family | + +**Example:** + +```mlir +%v = pto.vald %ub[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> +``` + +### `pto.valdx2` + +- **syntax:** `%lo, %hi = pto.valdx2 %source[%addr], "DIST" : !pto.ptr, !pto.vaddr -> !pto.vreg, !pto.vreg` +- **semantics:** Load two vector registers from UB using a vector-address offset + and a deinterleave-style distribution. +- **inputs:** `%source` is the UB base pointer, `%addr` is the vector-address + offset, and `DIST` selects the x2 load distribution mode. +- **outputs:** `%lo` and `%hi` are the two loaded vector registers. +- **constraints and limitations:** + `%lo` and `%hi` MUST have the same vector type. `%source` MUST be UB-backed. + +Supported `DIST` values: + +```text +BDINTLV +DINTLV_B8, DINTLV_B16, DINTLV_B32 +``` + +--- + +## Vector-Address Vector Stores + +### `pto.vast` + +- **syntax:** `pto.vast %value, %dest[%addr], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.vaddr, !pto.mask` +- **semantics:** Store a vector register to UB using a vector-address offset and + a predicate mask. +- **inputs:** + `%value` is the vector to store, `%dest` is the UB base pointer, `%addr` is + the vector-address offset, `%mask` controls active lanes, and `DIST` selects + the store distribution mode. +- **constraints and limitations:** + `%dest` MUST be UB-backed. `%addr` MUST be `!pto.vaddr<...>`. 
+ +Supported `DIST` values: + +```text +NORM_B8, NORM_B16, NORM_B32 +1PT_B8, 1PT_B16, 1PT_B32 +PK_B16, PK_B32, PK_B64 +PK4_B32 +MRG4CHN_B8 +MRG2CHN_B8, MRG2CHN_B16 +INTLV_B8, INTLV_B16, INTLV_B32 +``` + +### `pto.vastx2` + +- **syntax:** `pto.vastx2 %lo, %hi, %dest[%addr], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, !pto.vaddr, !pto.mask` +- **semantics:** Store two vector registers to UB using a vector-address offset + and an x2 store distribution. +- **inputs:** `%lo` and `%hi` are the vectors to store, `%dest` is the UB base, + `%addr` is the vector-address offset, and `%mask` controls active lanes. +- **constraints and limitations:** + `%lo` and `%hi` MUST have the same vector type. `%dest` MUST be UB-backed. + +Supported `DIST` values: + +```text +INTLV_B8, INTLV_B16, INTLV_B32 +``` + +--- + +## Vector-Address Predicate Loads And Stores + +### `pto.pald` + +- **syntax:** `%mask = pto.pald %source[%addr], "DIST" : !pto.ptr, !pto.vaddr -> !pto.mask` +- **semantics:** Load a predicate mask from UB using a vector-address offset. +- **DIST:** mandatory token, one of `NORM`, `US`, or `DS`. +- **constraints and limitations:** + `%source` MUST be UB-backed. `%addr` MUST be `!pto.vaddr<...>`. + +### `pto.past` + +- **syntax:** `pto.past %mask, %dest[%addr], "DIST" : !pto.mask, !pto.ptr, !pto.vaddr` +- **semantics:** Store a predicate mask to UB using a vector-address offset. +- **DIST:** mandatory token, one of `NORM` or `PK`. +- **constraints and limitations:** + `%dest` MUST be UB-backed. `%addr` MUST be `!pto.vaddr<...>`. + +--- + +## Vector-Address Unaligned Update Chains + +Unaligned vector-address ops carry explicit alignment state. Some forms return +an updated vector address, which makes the address value part of an update +chain. + +The update-chain seed rule is: + +> One `!pto.vaddr` SSA value MUST NOT be used as the `addr_in` seed of multiple +> update chains. + +Normal non-update vector-address operations may share one `!pto.vaddr`. 
The +restriction applies to `pto.valdu` and `pto.vastu` `addr_in` operands. + +### `pto.valda` + +- **syntax:** `%align = pto.valda %source[%addr] : !pto.ptr, !pto.vaddr -> !pto.align` +- **semantics:** Prime load-side alignment state for a later vector-address + unaligned load. +- **constraints and limitations:** + `%source` MUST be UB-backed. `%addr` MUST be `!pto.vaddr<...>`. + +### `pto.valdu` + +- **syntax:** `%value, %align_out, %addr_out = pto.valdu %source[%addr_in], %align_in, %inc : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg, !pto.align, !pto.vaddr` +- **semantics:** Unaligned vector-address load using incoming alignment state. +- **inputs:** + `%source` is the UB base pointer, `%addr_in` is the input vector address, + `%align_in` is the input alignment state, and `%inc` is a byte increment used + to compute `%addr_out`. +- **outputs:** + `%value` is the loaded vector, `%align_out` is the updated alignment state, + and `%addr_out` is the updated vector address. +- **constraints and limitations:** + `%addr_in` and `%addr_out` MUST have the same `!pto.vaddr` type. `%inc` MUST + be `i32`. The same `%addr_in` value MUST NOT seed another update chain. + +### `pto.init_align` + +- **syntax:** `%align = pto.init_align : !pto.align` +- **semantics:** Create a fresh store-side alignment state. + +### `pto.vastu` + +- **syntax:** `%align_out, %addr_out = pto.vastu %align_in, %addr_in, %value, %base, "POST_UPDATE" : !pto.align, !pto.vaddr, !pto.vreg, !pto.ptr -> !pto.align, !pto.vaddr` +- **semantics:** Unaligned vector-address store using incoming alignment and + address state. +- **inputs:** + `%align_in` is the incoming store alignment state, `%addr_in` is the input + vector address, `%value` is the vector to store, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` and `%addr_out` are the updated store state. +- **constraints and limitations:** + The mode MUST be `"POST_UPDATE"`. `%base` MUST be UB-backed. 
`%addr_in` and + `%addr_out` MUST have the same `!pto.vaddr` type. The same `%addr_in` value + MUST NOT seed another update chain. + +### `pto.vasta` + +- **syntax:** `pto.vasta %align, %base[%addr] : !pto.align, !pto.ptr, !pto.vaddr` +- **semantics:** Complete a vector-address unaligned store chain by consuming + the final alignment and address state. +- **constraints and limitations:** + `%base` MUST be UB-backed. `%align` should be produced by `pto.vastu` or a + compatible store-align producer. + +--- + +## Sharing Patterns + +### Valid normal sharing + +```mlir +%addr = pto.vag %stride : i32 -> !pto.vaddr +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%v = pto.vald %src[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> +pto.vast %v, %dst[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask +``` + +### Invalid update-chain fanout + +```mlir +%addr = pto.vag %stride : i32 -> !pto.vaddr +%load_align = pto.valda %src[%addr] + : !pto.ptr, !pto.vaddr -> !pto.align +%v, %next_load_align, %next_load_addr = + pto.valdu %src[%addr], %load_align, %inc + : !pto.ptr, !pto.vaddr, !pto.align, i32 + -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr +%store_align = pto.init_align : !pto.align + +// Invalid: %addr also seeds a store update chain. 
+%next_store_align, %next_store_addr = + pto.vastu %store_align, %addr, %v, %dst, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr + -> !pto.align, !pto.vaddr +``` + +### Valid independent update chains + +```mlir +%load_addr = pto.vag %stride : i32 -> !pto.vaddr +%store_addr = pto.vag %stride : i32 -> !pto.vaddr + +%load_align = pto.valda %src[%load_addr] + : !pto.ptr, !pto.vaddr -> !pto.align +%v, %next_load_align, %next_load_addr = + pto.valdu %src[%load_addr], %load_align, %inc + : !pto.ptr, !pto.vaddr, !pto.align, i32 + -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr + +%store_align = pto.init_align : !pto.align +%next_store_align, %next_store_addr = + pto.vastu %store_align, %store_addr, %v, %dst, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr + -> !pto.align, !pto.vaddr +pto.vasta %next_store_align, %dst[%next_store_addr] + : !pto.align, !pto.ptr, !pto.vaddr +``` + +--- + +## Complete Example + +```mlir +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddr_loop_copy(%src: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + %stride = arith.constant 4 : i32 + + pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %stride : i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %value = pto.vald %src[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value, %dst[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + } + } + + return + } +} +``` diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 253e36a22a..14d34d67bf 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -36,6 +36,8 @@ def PTO_B32MaskTypeConstraint : Type< "PTO low-level b32 mask type">; def PTO_AlignTypeConstraint : Type($_self)">, "PTO low-level 
align type">; +def PTO_VAddrTypeConstraint : Type($_self)">, + "PTO low-level vector_address offset type">; def PTO_BufferType : Type< CPred<"::llvm::isa<::mlir::pto::PtrType>($_self)">, @@ -1348,6 +1350,52 @@ def PTO_VldsOp : PTO_Op<"vlds", [ }]; } +def PTO_VagOp : PTO_Op<"vag"> { + let arguments = (ins Variadic:$strides); + let results = (outs PTO_VAddrTypeConstraint:$addr); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $strides attr-dict `:` type($strides) `->` type($addr) + }]; +} + +def PTO_ValdOp : PTO_Op<"vald", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_VAddrTypeConstraint:$addr, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $addr `]` attr-dict `:` type($source) `,` type($addr) `->` type($result) + }]; +} + +def PTO_Valdx2Op : PTO_Op<"valdx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_VAddrTypeConstraint:$addr, + StrAttr:$dist + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $addr `]` `,` $dist attr-dict `:` type($source) `,` type($addr) `->` type($low) `,` type($high) + }]; +} + def PTO_Vldsx2Op : PTO_Op<"vldsx2", [ DeclareOpInterfaceMethods ]> { @@ -1381,6 +1429,23 @@ def PTO_VldasOp : PTO_Op<"vldas", [ }]; } +def PTO_ValdaOp : PTO_Op<"valda", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_VAddrTypeConstraint:$addr + ); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $addr `]` attr-dict `:` type($source) `,` type($addr) `->` type($result) + }]; +} + def PTO_InitAlignOp : PTO_Op<"init_align", []> { let arguments = (ins); @@ -1461,6 +1526,29 @@ def PTO_VldusOp : PTO_Op<"vldus", [ }]; } +def PTO_ValduOp : PTO_Op<"valdu", [ + 
DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_VAddrTypeConstraint:$addr_in, + PTO_AlignTypeConstraint:$align_in, + I32:$inc + ); + + let results = (outs + PTO_VectorType:$result, + PTO_AlignTypeConstraint:$align_out, + PTO_VAddrTypeConstraint:$addr_out + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $addr_in `]` `,` $align_in `,` $inc attr-dict `:` type($source) `,` type($addr_in) `,` type($align_in) `,` type($inc) `->` type($result) `,` type($align_out) `,` type($addr_out) + }]; +} + def PTO_UvldOp : PTO_Op<"uvld", [ DeclareOpInterfaceMethods ]> { @@ -1758,6 +1846,23 @@ def PTO_PldsOp : PTO_Op<"plds", [ }]; } +def PTO_PaldOp : PTO_Op<"pald", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_VAddrTypeConstraint:$addr, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $addr `]` `,` $dist attr-dict `:` type($source) `,` type($addr) `->` type($result) + }]; +} + def PTO_PldiOp : PTO_Op<"pldi", [ DeclareOpInterfaceMethods ]> { @@ -2539,6 +2644,26 @@ def PTO_VstsOp : PTO_Op<"vsts", [ }]; } +def PTO_VastOp : PTO_Op<"vast", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + PTO_VAddrTypeConstraint:$addr, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $addr `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($addr) `,` type($mask) + }]; +} + def PTO_VscatterOp : PTO_Op<"vscatter", [ DeclareOpInterfaceMethods ]> { @@ -2577,6 +2702,25 @@ def PTO_PstsOp : PTO_Op<"psts", [ }]; } +def PTO_PastOp : PTO_Op<"past", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + 
PTO_VAddrTypeConstraint:$addr, + StrAttr:$dist + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $addr `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($addr) + }]; +} + def PTO_CopyUbufToGmOp : PTO_Op<"copy_ubuf_to_gm", [ DeclareOpInterfaceMethods ]> { @@ -3302,6 +3446,26 @@ def PTO_Vstsx2Op : PTO_Op<"vstsx2", [ }]; } +def PTO_Vastx2Op : PTO_Op<"vastx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$low, + PTO_VectorType:$high, + PTO_BufferLikeType:$destination, + PTO_VAddrTypeConstraint:$addr, + StrAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $low `,` $high `,` $destination `[` $addr `]` `,` $dist `,` $mask attr-dict `:` type($low) `,` type($high) `,` type($destination) `,` type($addr) `,` type($mask) + }]; +} + def PTO_VsldbOp : PTO_Op<"vsldb", [ DeclareOpInterfaceMethods ]> { @@ -3357,6 +3521,23 @@ def PTO_VstasOp : PTO_Op<"vstas", [ }]; } +def PTO_VastaOp : PTO_Op<"vasta", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination, + PTO_VAddrTypeConstraint:$addr + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $addr `]` attr-dict `:` type($value) `,` type($destination) `,` type($addr) + }]; +} + def PTO_VstarOp : PTO_Op<"vstar", [ DeclareOpInterfaceMethods ]> { @@ -3413,6 +3594,26 @@ def PTO_VstusOp : PTO_Op<"vstus", [ }]; } +def PTO_VastuOp : PTO_Op<"vastu", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_VAddrTypeConstraint:$addr_in, + PTO_VectorType:$value, + PTO_BufferLikeType:$base, + StrAttr:$mode + ); + let results = (outs PTO_AlignTypeConstraint:$align_out, + PTO_VAddrTypeConstraint:$addr_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` 
$addr_in `,` $value `,` $base `,` $mode attr-dict `:` type($align_in) `,` type($addr_in) `,` type($value) `,` type($base) `->` type($align_out) `,` type($addr_out) + }]; +} + def PTO_VsturOp : PTO_Op<"vstur", [ DeclareOpInterfaceMethods ]> { diff --git a/include/PTO/IR/VPTOTypeDefs.td b/include/PTO/IR/VPTOTypeDefs.td index ed62e7655d..ac98ceeb75 100644 --- a/include/PTO/IR/VPTOTypeDefs.td +++ b/include/PTO/IR/VPTOTypeDefs.td @@ -53,6 +53,26 @@ def MaskType : TypeDef { }]; } +def VAddrType : TypeDef { + let mnemonic = "vaddr"; + let summary = "A PTO low-level vector_address offset token"; + + let parameters = (ins + StringRefParameter<"address granularity view">:$granularity + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static bool isSupportedGranularity(::llvm::StringRef granularity); + + bool isB8() const { return getGranularity() == "b8"; } + bool isB16() const { return getGranularity() == "b16"; } + bool isB32() const { return getGranularity() == "b32"; } + }]; +} + def AlignType : TypeDef { let mnemonic = "align"; let summary = "A PTO low-level vector_align carrier"; diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 4d9ae3401d..02a533ef94 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -61,6 +61,13 @@ static std::string formatMaskType(StringRef granularity) { return storage; } +static std::string formatVAddrType(StringRef granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.vaddr<" << granularity << ">"; + return storage; +} + static LogicalResult verifyVRegTypeLike(Operation *op, Type type, StringRef roleDescription) { auto vecType = dyn_cast(type); @@ -92,6 +99,37 @@ static LogicalResult verifyMaskTypeWithGranularityLike(Operation *op, Type type, return success(); } +static LogicalResult verifyVAddrTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription << " must be 
!pto.vaddr<...>"; + return success(); +} + +static bool isVectorAddressUpdateSeedUse(OpOperand &use) { + Operation *owner = use.getOwner(); + if (auto op = dyn_cast(owner)) + return &use == &op.getAddrInMutable(); + if (auto op = dyn_cast(owner)) + return &use == &op.getAddrInMutable(); + return false; +} + +static LogicalResult verifySingleVectorAddressUpdateSeedUse(Operation *op, + Value addr, + StringRef role) { + unsigned updateSeedUses = 0; + for (OpOperand &use : addr.getUses()) { + if (isVectorAddressUpdateSeedUse(use)) + ++updateSeedUses; + } + if (updateSeedUses > 1) + return op->emitOpError() + << role + << " must not seed multiple vector-address update chains"; + return success(); +} + static LogicalResult verifyVPTOScalarAccessTypes(Operation *op, Type ptrTy, Type valueTy, StringRef opNameForDiag) { @@ -814,15 +852,15 @@ static bool isSupportedVtrcRoundMode(StringRef mode) { } static bool isStoreAlignProducer(Operation *op) { - return isa(op); + return isa(op); } static bool isStoreAlignSink(Operation *op) { - return isa(op); + return isa(op); } static bool isLoadAlignProducer(Operation *op) { - return isa(op); + return isa(op); } static scf::IfOp getEnclosingBranchIf(Operation *op) { @@ -887,6 +925,10 @@ static FailureOr resolveStoreAlignRootImpl( current = stateOp.getAlignIn(); continue; } + if (auto stateOp = dyn_cast(def)) { + current = stateOp.getAlignIn(); + continue; + } if (auto forOp = dyn_cast(def)) { auto result = dyn_cast(current); if (!result) @@ -987,6 +1029,11 @@ static LogicalResult verifyStoreAlignLinearUses(Value value, Operation *user) { branchUsers.push_back(owner); continue; } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + branchUsers.push_back(owner); + continue; + } if (auto forOp = dyn_cast(owner)) { unsigned firstInitArg = forOp.getNumControlOperands(); if (use.getOperandNumber() < firstInitArg) @@ -1111,10 +1158,16 @@ static FailureOr resolveLoadAlignRootImpl( if (Operation *def = 
current.getDefiningOp()) { if (isa(def)) return current; + if (isa(def)) + return current; if (auto stateOp = dyn_cast(def)) { current = stateOp.getAlign(); continue; } + if (auto stateOp = dyn_cast(def)) { + current = stateOp.getAlignIn(); + continue; + } if (auto forOp = dyn_cast(def)) { auto result = dyn_cast(current); if (!result) @@ -1188,6 +1241,11 @@ static LogicalResult verifyLoadAlignLinearUses(Value value, Operation *user) { branchUsers.push_back(owner); continue; } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + branchUsers.push_back(owner); + continue; + } if (auto forOp = dyn_cast(owner)) { unsigned firstInitArg = forOp.getNumControlOperands(); if (use.getOperandNumber() < firstInitArg) { @@ -3387,6 +3445,34 @@ MaskType::verify(function_ref emitError, return success(); } +bool VAddrType::isSupportedGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || + granularity == "b32"; +} + +Type VAddrType::parse(AsmParser &parser) { + auto loc = parser.getCurrentLocation(); + StringRef granularity; + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&granularity)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), granularity); +} + +void VAddrType::print(AsmPrinter &printer) const { + printer << "<" << getGranularity() << ">"; +} + +LogicalResult +VAddrType::verify(function_ref emitError, + StringRef granularity) { + if (!isSupportedGranularity(granularity)) + return emitError() << "'" << formatVAddrType(granularity) + << "' expected granularity to be one of b8, b16, b32"; + return success(); +} + void CopyGmToUbufOp::getEffects( SmallVectorImpl> &effects) { @@ -4628,6 +4714,59 @@ LogicalResult VldsOp::verify() { } return success(); } + +LogicalResult VagOp::verify() { + if (getStrides().empty() || getStrides().size() > 4) + return emitOpError("requires one to four i32 byte stride operands"); + for (Value stride : 
getStrides()) { + if (!stride.getType().isInteger(32)) + return emitOpError("requires all stride operands to be i32"); + } + auto forOp = (*this)->getParentOfType(); + if (!forOp) + return emitOpError("must be nested under scf.for"); + if (!forOp.getInductionVar().getType().isInteger(16)) + return emitOpError("requires enclosing scf.for induction variable to be i16"); + if (failed(verifyVAddrTypeLike(*this, getAddr().getType(), "result type"))) + return failure(); + return success(); +} + +void ValdOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult ValdOp::verify() { + if (failed(verifyVldsCommon(*this))) + return failure(); + if (failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type"))) + return failure(); + return success(); +} + +void Valdx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Valdx2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type")) || + failed(verifyVRegTypeLike(*this, getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high result type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high results to share one vector type"); + if (!isSupportedVldx2DistToken(getDist())) + return emitOpError("requires a supported x2 load distribution token"); + return success(); +} void VldasOp::getEffects( SmallVectorImpl> &effects) { @@ -4644,10 +4783,56 @@ LogicalResult VldasOp::verify() { return success(); } +void ValdaOp::getEffects( + SmallVectorImpl> + &effects) { + 
effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult ValdaOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type")) || + failed(verifyAlignTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + LogicalResult InitAlignOp::verify() { return verifyAlignTypeLike(*this, getResult().getType(), "result type"); } +void ValduOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult ValduOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyVAddrTypeLike(*this, getAddrIn().getType(), "addr_in type")) || + failed(verifyVAddrTypeLike(*this, getAddrOut().getType(), "addr_out type")) || + failed(verifyAlignTypeLike(*this, getAlignIn().getType(), "align_in type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getAddrIn().getType() != getAddrOut().getType()) + return emitOpError("requires addr_in and addr_out to have identical types"); + if (getAlignIn().getType() != getAlignOut().getType()) + return emitOpError("requires align_in and align_out to have identical types"); + if (!getInc().getType().isInteger(32)) + return emitOpError("requires inc to be i32"); + if (failed(verifySingleVectorAddressUpdateSeedUse(*this, getAddrIn(), + "addr_in"))) + return failure(); + return success(); +} + LogicalResult SprclrOp::verify() { if 
(!isSupportedSprToken(getSpr())) return emitOpError("requires spr to be \"AR\""); @@ -5081,6 +5266,25 @@ LogicalResult PldsOp::verify() { return success(); } +void PaldOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PaldOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + return success(); +} + void PldiOp::getEffects( SmallVectorImpl> &effects) { @@ -6194,6 +6398,21 @@ LogicalResult VstsOp::verify() { return emitOpError("requires updated base result to match base type"); return success(); } + +void VastOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VastOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type"))) + return failure(); + return success(); +} void Vstsx2Op::getEffects( SmallVectorImpl> &effects) { @@ -6220,6 +6439,31 @@ LogicalResult Vstsx2Op::verify() { return success(); } +void Vastx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLowMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getHighMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vastx2Op::verify() { + if (failed(verifyVRegTypeLike(*this, 
getLow().getType(), "low value type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high value type")) || + failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high values to share one vector type"); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!isSupportedVstsx2DistToken(getDist())) + return emitOpError("requires a supported x2 store distribution token"); + return success(); +} + void VscatterOp::getEffects( SmallVectorImpl> &effects) { @@ -6315,6 +6559,26 @@ LogicalResult PstsOp::verify() { return success(); } +void PastOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult PastOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + void VsstbOp::getEffects( SmallVectorImpl> &effects) { @@ -6357,6 +6621,24 @@ LogicalResult VstasOp::verify() { return success(); } +void VastaOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); 
+ effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VastaOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type")) || + failed(verifyVAddrTypeLike(*this, getAddr().getType(), "addr type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + void VstarOp::getEffects( SmallVectorImpl> &effects) { @@ -6425,6 +6707,37 @@ LogicalResult VstusOp::verify() { return success(); } +void VastuOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable()); +} + +LogicalResult VastuOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type")) || + failed(verifyVAddrTypeLike(*this, getAddrIn().getType(), "addr_in type")) || + failed(verifyVAddrTypeLike(*this, getAddrOut().getType(), "addr_out type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (getAddrIn().getType() != getAddrOut().getType()) + return emitOpError("requires addr_in and addr_out to have identical types"); + if (getAlignIn().getType() != getAlignOut().getType()) + return emitOpError("requires align_in and align_out to have identical types"); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + if (getMode() != "POST_UPDATE") + return emitOpError("requires mode to be \"POST_UPDATE\""); + if 
(failed(verifySingleVectorAddressUpdateSeedUse(*this, getAddrIn(), + "addr_in"))) + return failure(); + return success(); +} + void VsturOp::getEffects( SmallVectorImpl> &effects) { diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 2b881c6f6d..ef3e0e1f1e 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -157,6 +158,8 @@ static Type convertVPTOType(Type type, Builder &builder) { } if (isa(type)) return VectorType::get({256}, builder.getI1Type()); + if (isa(type)) + return VectorType::get({1}, builder.getI32Type()); if (isa(type)) return VectorType::get({32}, builder.getI8Type()); if (auto ptrType = dyn_cast(type)) { @@ -199,7 +202,8 @@ static unsigned getNaturalByteAlignment(Type type) { } static bool hasVPTOConvertibleType(Type type) { - return isa(type); + return isa(type); } static bool hasVPTOConvertibleType(TypeRange types) { @@ -557,6 +561,18 @@ static std::string getMemoryElementTypeFragment(Type type) { return getLowPrecisionElementFragment(type); } +static std::string getVectorAddressElementTypeFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return "i" + std::to_string(intType.getWidth()); + return {}; +} + static bool isLowpPayloadElementType(Type type) { return pto::isPTOFloat8Type(type) || pto::isPTOHiFloat8Type(type) || pto::isPTOFloat4PackedType(type); @@ -3001,6 +3017,61 @@ static FailureOr buildVldusCallee(MLIRContext *context, .getValue(); } +static StringRef buildVagCallee(MLIRContext *context) { 
+ return StringAttr::get(context, "llvm.hivm.vag.32").getValue(); +} + +static FailureOr +buildVectorAddressTypedCallee(MLIRContext *context, Type vectorType, + StringRef stem) { + std::string vec = + getVectorAddressElementTypeFragment(getElementTypeFromVectorLike(vectorType)); + auto lanes = getElementCountFromVectorLike(vectorType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildValdCallee(MLIRContext *context, + Type resultType) { + return buildVectorAddressTypedCallee(context, resultType, "vldx1"); +} + +static FailureOr buildValdx2Callee(MLIRContext *context, + Type resultType) { + return buildVectorAddressTypedCallee(context, resultType, "vldx2"); +} + +static FailureOr buildVastCallee(MLIRContext *context, + Type valueType) { + return buildVectorAddressTypedCallee(context, valueType, "vstx1"); +} + +static FailureOr buildVastx2Callee(MLIRContext *context, + Type valueType) { + return buildVectorAddressTypedCallee(context, valueType, "vstx2"); +} + +static FailureOr buildValduCallee(MLIRContext *context, + Type resultType) { + return buildVectorAddressTypedCallee(context, resultType, "vldu.v300"); +} + +static FailureOr buildVastuCallee(MLIRContext *context, + Type valueType) { + return buildVectorAddressTypedCallee(context, valueType, "vstu"); +} + +static StringRef buildValdaCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vlda").getValue(); +} + +static StringRef buildVastaCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsta").getValue(); +} + static FailureOr buildVcmpCallee(MLIRContext *context, Type inputType, StringRef cmpMode, bool isScalarCompare) { @@ -3353,6 +3424,14 @@ static StringRef buildPldsCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.plds.b8").getValue(); } +static StringRef buildPaldCallee(MLIRContext *context) { + 
return StringAttr::get(context, "llvm.hivm.pld.b8").getValue(); +} + +static StringRef buildPastCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pst.b8").getValue(); +} + static StringRef buildPnotCallee(MLIRContext *context) { return StringAttr::get(context, "llvm.hivm.pnot.z").getValue(); } @@ -6517,6 +6596,152 @@ class LowerVldsOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVagOpPattern final : public OpConversionPattern { +public: + explicit LowerVagOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VagOp op, pto::VagOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = this->getTypeConverter()->convertType(op.getAddr().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vag result type"); + + ValueRange strides = adaptor.getStrides(); + if (strides.empty() || strides.size() > 4) + return rewriter.notifyMatchFailure(op, "unexpected vag stride count"); + + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args(strides.begin(), strides.end()); + while (args.size() < 4) + args.push_back(zeroValue); + + StringRef calleeName = buildVagCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerValdOpPattern final : public OpConversionPattern { +public: + explicit LowerValdOpPattern(TypeConverter &typeConverter, MLIRContext *context, + 
LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::ValdOp op, pto::ValdOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type ptoResultType = op.getResult().getType(); + Type elementType = getElementTypeFromVectorLike(ptoResultType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (!elementType || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vald operands"); + + Type resultType = this->getTypeConverter()->convertType(ptoResultType); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vald result type"); + + Type callValueType = + getPayloadABIType(ptoResultType, resultType, rewriter.getContext()); + FailureOr calleeName = + buildValdCallee(op.getContext(), ptoResultType); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vald signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), adaptor.getAddr(), distValue, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAddr().getType(), + distValue.getType(), zeroValue.getType()}, + TypeRange{callValueType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{callValueType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + Value loaded = castFromPayloadABI(op.getLoc(), call.getResult(0), + ptoResultType, resultType, rewriter); + rewriter.replaceOp(op, loaded); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerValdx2OpPattern final : public OpConversionPattern { +public: + explicit LowerValdx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState 
&state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Valdx2Op op, pto::Valdx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = parseLoadX2DistImmediate(op.getDist(), elementType); + if (!elementType || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize valdx2 operands"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 2) + return rewriter.notifyMatchFailure(op, "failed to convert valdx2 result types"); + + Type lowCallType = + getPayloadABIType(op.getLow().getType(), resultTypes[0], + rewriter.getContext()); + Type highCallType = + getPayloadABIType(op.getHigh().getType(), resultTypes[1], + rewriter.getContext()); + SmallVector callResultTypes{lowCallType, highCallType}; + + FailureOr calleeName = + buildValdx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported valdx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), adaptor.getAddr(), distValue, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAddr().getType(), + distValue.getType(), zeroValue.getType()}, + callResultTypes); + auto call = + rewriter.create(op.getLoc(), *calleeName, callResultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + Value low = castFromPayloadABI(op.getLoc(), call.getResult(0), + op.getLow().getType(), resultTypes[0], + rewriter); + Value high = castFromPayloadABI(op.getLoc(), call.getResult(1), + op.getHigh().getType(), 
resultTypes[1], + rewriter); + rewriter.replaceOp(op, ValueRange{low, high}); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVldsx2OpPattern final : public OpConversionPattern { public: explicit LowerVldsx2OpPattern(TypeConverter &typeConverter, @@ -6687,6 +6912,39 @@ class LowerVldasOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerValdaOpPattern final : public OpConversionPattern { +public: + explicit LowerValdaOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::ValdaOp op, pto::ValdaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!sourceType || !resultType) + return rewriter.notifyMatchFailure( + op, "expected converted valda operand/result types"); + + StringRef calleeName = buildValdaCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAddr().getType(), + zeroValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getAddr(), zeroValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVldusOpPattern final : public OpConversionPattern { public: explicit LowerVldusOpPattern(TypeConverter &typeConverter, @@ -6734,6 +6992,56 @@ class LowerVldusOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerValduOpPattern final : public OpConversionPattern { +public: + explicit 
LowerValduOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::ValduOp op, pto::ValduOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + SmallVector resultTypes; + if (!sourceType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 3 || adaptor.getAlignIn().getType() != resultTypes[1] || + adaptor.getAddrIn().getType() != resultTypes[2]) + return rewriter.notifyMatchFailure( + op, "expected converted valdu operand/result types"); + + FailureOr calleeName = + buildValduCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported valdu signature"); + + Type callValueType = getPayloadABIType( + op.getResult().getType(), resultTypes[0], rewriter.getContext()); + SmallVector callResultTypes{callValueType, resultTypes[1], + resultTypes[2]}; + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), adaptor.getAddrIn(), + adaptor.getAlignIn(), adaptor.getInc(), zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAddrIn().getType(), + adaptor.getAlignIn().getType(), adaptor.getInc().getType(), + zeroValue.getType()}, + callResultTypes); + auto call = + rewriter.create(op.getLoc(), *calleeName, callResultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + Value loaded = castFromPayloadABI(op.getLoc(), call.getResult(0), + op.getResult().getType(), resultTypes[0], + rewriter); + rewriter.replaceOp(op, ValueRange{loaded, call.getResult(1), call.getResult(2)}); + return success(); + } + +private: + LoweringState &state; +}; + class LowerSprclrOpPattern final : public OpConversionPattern 
{ public: explicit LowerSprclrOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -6871,6 +7179,49 @@ class LowerVstsOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVastOpPattern final : public OpConversionPattern { +public: + explicit LowerVastOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VastOp op, pto::VastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = parseStoreDistImmediate(op.getDist().value_or(""), elementType); + if (!elementType || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vast operands"); + + FailureOr calleeName = + buildVastCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vast signature"); + + Value value = + castToPayloadABI(op.getLoc(), adaptor.getValue(), op.getValue().getType(), + rewriter); + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{value, adaptor.getDestination(), adaptor.getAddr(), + distValue, zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{value.getType(), adaptor.getDestination().getType(), + adaptor.getAddr().getType(), distValue.getType(), + zeroValue.getType(), adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVsstbOpPattern final : public OpConversionPattern { public: explicit 
LowerVsstbOpPattern(TypeConverter &typeConverter, @@ -6982,6 +7333,55 @@ class LowerVstsx2OpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVastx2OpPattern final : public OpConversionPattern { +public: + explicit LowerVastx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vastx2Op op, pto::Vastx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = parseStoreX2DistImmediate(op.getDist(), elementType); + if (!elementType || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vastx2 operands"); + + FailureOr calleeName = + buildVastx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vastx2 signature"); + + Value low = + castToPayloadABI(op.getLoc(), adaptor.getLow(), op.getLow().getType(), + rewriter); + Value high = + castToPayloadABI(op.getLoc(), adaptor.getHigh(), op.getHigh().getType(), + rewriter); + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{low, high, adaptor.getDestination(), + adaptor.getAddr(), distValue, zeroValue, + adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{low.getType(), high.getType(), + adaptor.getDestination().getType(), adaptor.getAddr().getType(), + distValue.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + 
class LowerPstuOpPattern final : public OpConversionPattern { public: explicit LowerPstuOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -7068,6 +7468,57 @@ class LowerVstusOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVastuOpPattern final : public OpConversionPattern { +public: + explicit LowerVastuOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VastuOp op, pto::VastuOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto postMode = parsePostModeImmediate(op.getMode()); + if (!postMode || *postMode != 1) + return rewriter.notifyMatchFailure(op, "unsupported vastu mode immediate"); + + SmallVector resultTypes; + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!baseType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 2 || adaptor.getAlignIn().getType() != resultTypes[0] || + adaptor.getAddrIn().getType() != resultTypes[1]) + return rewriter.notifyMatchFailure( + op, "unexpected converted vastu operand/result types"); + + FailureOr calleeName = + buildVastuCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vastu signature"); + + Value value = + castToPayloadABI(op.getLoc(), adaptor.getValue(), op.getValue().getType(), + rewriter); + Value modeValue = getI32Constant(rewriter, op.getLoc(), *postMode); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{value, adaptor.getBase(), adaptor.getAddrIn(), + adaptor.getAlignIn(), modeValue, zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{value.getType(), adaptor.getBase().getType(), + adaptor.getAddrIn().getType(), adaptor.getAlignIn().getType(), + modeValue.getType(), zeroValue.getType()}, + 
resultTypes); + auto call = + rewriter.create(op.getLoc(), *calleeName, resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVsturOpPattern final : public OpConversionPattern { public: explicit LowerVsturOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -7184,6 +7635,40 @@ class LowerVstasOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVastaOpPattern final : public OpConversionPattern { +public: + explicit LowerVastaOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VastaOp op, pto::VastaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = + dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType) + return rewriter.notifyMatchFailure( + op, "unexpected converted vasta operand types"); + + StringRef calleeName = buildVastaCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + adaptor.getAddr(), zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + adaptor.getAddr().getType(), zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + class LowerVgather2OpPattern final : public OpConversionPattern { public: @@ -7783,6 +8268,45 @@ class 
LowerPredicateStoreOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerPastOpPattern final : public OpConversionPattern { +public: + explicit LowerPastOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PastOp op, pto::PastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmDestType = + dyn_cast(adaptor.getDestination().getType()); + Type valueType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!llvmDestType || !valueType) + return rewriter.notifyMatchFailure( + op, "expected converted past operand types"); + + auto dist = parsePredicateStoreDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure(op, "unsupported past dist immediate"); + + StringRef calleeName = buildPastCallee(op.getContext()); + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + adaptor.getAddr(), distValue, zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{valueType, llvmDestType, adaptor.getAddr().getType(), + distValue.getType(), zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + template class LowerPredicateLoadOpPattern final : public OpConversionPattern { public: @@ -7835,6 +8359,47 @@ class LowerPredicateLoadOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerPaldOpPattern final : public OpConversionPattern { +public: + explicit LowerPaldOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : 
OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PaldOp op, pto::PaldOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getSource().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!llvmSourceType || !resultType) + return rewriter.notifyMatchFailure( + op, "expected converted pald operand/result types"); + + auto dist = parsePredicateLoadDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure(op, "unsupported pald dist immediate"); + + StringRef calleeName = buildPaldCallee(op.getContext()); + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), adaptor.getAddr(), distValue, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmSourceType, adaptor.getAddr().getType(), distValue.getType(), + zeroValue.getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, resultType, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + template class LowerSetLoopConfigOpPattern final : public OpConversionPattern { public: @@ -9920,14 +10485,18 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerRuntimeQueryOpPattern, LowerRuntimeQueryOpPattern, LowerRuntimeQueryOpPattern, - LowerVldsOpPattern, LowerVldsx2OpPattern, LowerVsldbOpPattern, - LowerVldasOpPattern, LowerInitAlignOpPattern, - LowerVldusOpPattern, LowerSprclrOpPattern, + LowerVldsOpPattern, LowerVagOpPattern, LowerValdOpPattern, + LowerValdx2OpPattern, LowerVldsx2OpPattern, LowerVsldbOpPattern, + LowerVldasOpPattern, LowerValdaOpPattern, + LowerInitAlignOpPattern, + LowerVldusOpPattern, 
LowerValduOpPattern, LowerSprclrOpPattern, LowerSprStoreOpPattern, LowerSprStoreOpPattern, - LowerVstsOpPattern, LowerVsstbOpPattern, - LowerVstsx2OpPattern, + LowerVstsOpPattern, LowerVastOpPattern, + LowerVsstbOpPattern, + LowerVstsx2OpPattern, LowerVastx2OpPattern, LowerVstarOpPattern, LowerVstasOpPattern, + LowerVastaOpPattern, LowerVgather2OpPattern, LowerVgather2BcOpPattern, LowerVgatherbOpPattern, LowerVscatterOpPattern, LowerVaxpyOpPattern, @@ -9937,9 +10506,12 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerVbitcastOpPattern, LowerPbitcastOpPattern, LowerPredicateLoadOpPattern, LowerPredicateLoadOpPattern, + LowerPaldOpPattern, LowerPredicateStoreOpPattern, LowerPredicateStoreOpPattern, - LowerPstuOpPattern, LowerVstusOpPattern, LowerVsturOpPattern, + LowerPastOpPattern, + LowerPstuOpPattern, LowerVstusOpPattern, LowerVastuOpPattern, + LowerVsturOpPattern, LowerInterCoreSyncOpPattern, LowerInterCoreSyncOpPattern, LowerCopyGmToCbufOpPattern, LowerLoadCbufToCaOpPattern, @@ -10015,14 +10587,20 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::StoreVfSimtInfoOp, pto::SetMovPadValOp, pto::SetQuantPreOp>(); target.addIllegalOp(); - target.addIllegalOp(); + pto::PldiOp, pto::PldsOp, pto::PaldOp, pto::PstiOp, + pto::PstsOp, pto::PastOp, + pto::PstuOp, pto::VstusOp, pto::VastuOp, + pto::VsturOp>(); target.addIllegalOp/dev/null 2>&1 || true; mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vector_address_ops(%src_f32: !pto.ptr, %dst_f32: !pto.ptr, + %src_i32: !pto.ptr, %dst_i32: !pto.ptr, + %pred_base: !pto.ptr) attributes {pto.kernel} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + 
pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %loaded = pto.vald %src_f32[%addr] {dist = "NORM"} : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %loaded, %dst_f32[%addr], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %pred = pto.pald %pred_base[%addr], "NORM" : !pto.ptr, !pto.vaddr -> !pto.mask + pto.past %pred, %pred_base[%addr], "NORM" : !pto.mask, !pto.ptr, !pto.vaddr + %low, %high = pto.valdx2 %src_i32[%addr], "DINTLV_B32" : !pto.ptr, !pto.vaddr -> !pto.vreg<64xi32>, !pto.vreg<64xi32> + pto.vastx2 %low, %high, %dst_i32[%addr], "INTLV_B32", %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.ptr, !pto.vaddr, !pto.mask + %load_align = pto.valda %src_f32[%addr] : !pto.ptr, !pto.vaddr -> !pto.align + %unaligned, %next_load_align, %next_load_addr = pto.valdu %src_f32[%addr], %load_align, %c256_i32 : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr + %store_align = pto.init_align : !pto.align + %store_addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %next_store_align, %next_store_addr = pto.vastu %store_align, %store_addr, %unaligned, %dst_f32, "POST_UPDATE" : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr -> !pto.align, !pto.vaddr + pto.vasta %next_store_align, %dst_f32[%next_store_addr] : !pto.align, !pto.ptr, !pto.vaddr + } + } + return + } +} + +// CHECK-LABEL: llvm.func @vector_address_ops_mix_aiv +// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32 +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: llvm.call @llvm.hivm.vag.32(%[[C4]], {{%[0-9]+}}, {{%[0-9]+}}, {{%[0-9]+}}) +// CHECK-SAME: : (i32, i32, i32, i32) -> vector<1xi32> +// CHECK: llvm.call @llvm.hivm.vldx1.v64f32 +// CHECK-SAME: (!llvm.ptr<6>, vector<1xi32>, i32, i32) -> vector<64xf32> +// CHECK: llvm.call @llvm.hivm.vstx1.v64f32 +// CHECK-SAME: (vector<64xf32>, !llvm.ptr<6>, 
vector<1xi32>, i32, i32, vector<256xi1>) -> () +// CHECK: llvm.call @llvm.hivm.pld.b8 +// CHECK-SAME: (!llvm.ptr<6>, vector<1xi32>, i32, i32) -> vector<256xi1> +// CHECK: llvm.call @llvm.hivm.pst.b8 +// CHECK-SAME: (vector<256xi1>, !llvm.ptr<6>, vector<1xi32>, i32, i32) -> () +// CHECK: llvm.call @llvm.hivm.vldx2.v64i32 +// CHECK-SAME: (!llvm.ptr<6>, vector<1xi32>, i32, i32) -> !llvm.struct<(vector<64xi32>, vector<64xi32>)> +// CHECK: llvm.call @llvm.hivm.vstx2.v64i32 +// CHECK-SAME: (vector<64xi32>, vector<64xi32>, !llvm.ptr<6>, vector<1xi32>, i32, i32, vector<256xi1>) -> () +// CHECK: llvm.call @llvm.hivm.vlda +// CHECK-SAME: (!llvm.ptr<6>, vector<1xi32>, i32) -> vector<32xi8> +// CHECK: llvm.call @llvm.hivm.vldu.v300.v64f32 +// CHECK-SAME: (!llvm.ptr<6>, vector<1xi32>, vector<32xi8>, i32, i32) -> !llvm.struct<(vector<64xf32>, vector<32xi8>, vector<1xi32>)> +// CHECK: llvm.call @llvm.hivm.vag.32(%[[C4]], {{%[0-9]+}}, {{%[0-9]+}}, {{%[0-9]+}}) +// CHECK-SAME: : (i32, i32, i32, i32) -> vector<1xi32> +// CHECK: llvm.call @llvm.hivm.vstu.v64f32 +// CHECK-SAME: (vector<64xf32>, !llvm.ptr<6>, vector<1xi32>, vector<32xi8>, i32, i32) -> !llvm.struct<(vector<32xi8>, vector<1xi32>)> +// CHECK: llvm.call @llvm.hivm.vsta +// CHECK-SAME: (vector<32xi8>, !llvm.ptr<6>, vector<1xi32>, i32) -> () diff --git a/test/lit/vpto/vector_address_update_seed_verify_invalid.pto b/test/lit/vpto/vector_address_update_seed_verify_invalid.pto new file mode 100644 index 0000000000..94ea4ff6bd --- /dev/null +++ b/test/lit/vpto/vector_address_update_seed_verify_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @bad_update_seed(%src: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %load_align = pto.valda %src[%addr] + : !pto.ptr, !pto.vaddr -> !pto.align + %value, %next_load_align, %next_load_addr = pto.valdu %src[%addr], %load_align, %c256_i32 + : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr + %store_align = pto.init_align : !pto.align + %next_store_align, %next_store_addr = pto.vastu %store_align, %addr, %value, %dst, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr -> !pto.align, !pto.vaddr + } + } + return + } +} + +// CHECK: addr_in must not seed multiple vector-address update chains diff --git a/test/lit/vpto/vector_address_vag_verify_invalid.pto b/test/lit/vpto/vector_address_vag_verify_invalid.pto new file mode 100644 index 0000000000..1d479188de --- /dev/null +++ b/test/lit/vpto/vector_address_vag_verify_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vag_requires_loop() attributes {pto.kernel} { + %c4_i32 = arith.constant 4 : i32 + pto.vecscope { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + } + return + } +} + +// CHECK: must be nested under scf.for diff --git a/test/lit/vpto/vector_address_vag_verify_invalid_counter.pto b/test/lit/vpto/vector_address_vag_verify_invalid_counter.pto new file mode 100644 index 0000000000..8404009dae --- /dev/null +++ b/test/lit/vpto/vector_address_vag_verify_invalid_counter.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @bad_vag_counter() attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4_i32 = arith.constant 4 : i32 + pto.vecscope { + scf.for %i = %c0 to %c2 step %c1 { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + } + } + return + } +} + +// CHECK: requires enclosing scf.for induction variable to be i16 diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/compare.py b/test/vpto/cases/micro-op/vector-address/vald-vast/compare.py new file mode 100644 index 0000000000..9c4c289c9a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/vald-vast/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/golden.py b/test/vpto/cases/micro-op/vector-address/vald-vast/golden.py new file mode 100644 index 0000000000..0d3a1d5f7d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/vald-vast/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 64 +SEED = 29 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-16.0, 16.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v1.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for VPTO vector-address vald/vast." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto b/test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto new file mode 100644 index 0000000000..0f0b741094 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vector_address_vald_vast(%src: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + %c4_i32 = arith.constant 4 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c256_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %value = pto.vald %ub_src[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value, %ub_dst[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.mte_ub_gm %ub_dst, %dst, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp b/test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp new file mode 100644 index 0000000000..22f96fdcd4 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vector_address_vald_vast(__gm__ float *src, __gm__ float *dst); + +void LaunchVectorAddressValdVast(float *src, float *dst, void *stream) { + vector_address_vald_vast<<<1, nullptr, stream>>>((__gm__ float *)src, + (__gm__ float *)dst); +} diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp b/test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp new file mode 100644 index 0000000000..2e8fd28e8e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei 
Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVectorAddressValdVast(float *src, float *dst, void *stream); + +int main() { + constexpr size_t kElements = 64; + constexpr size_t kBytes = kElements * sizeof(float); + + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t srcFileSize = kBytes; + size_t dstFileSize = kBytes; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), kBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), kBytes)); + 
ACL_CHECK(aclrtMalloc((void **)&srcDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + if (!ReadFile("./v1.bin", srcFileSize, srcHost, kBytes) || + srcFileSize != kBytes) { + std::fprintf(stderr, "[ERROR] failed to read v1.bin\n"); + rc = 1; + goto cleanup; + } + if (!ReadFile("./v2.bin", dstFileSize, dstHost, kBytes) || + dstFileSize != kBytes) { + std::fprintf(stderr, "[ERROR] failed to read v2.bin\n"); + rc = 1; + goto cleanup; + } + ACL_CHECK(aclrtMemcpy(srcDevice, kBytes, srcHost, kBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, kBytes, dstHost, kBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVectorAddressValdVast(srcDevice, dstDevice, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, kBytes, dstDevice, kBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, kBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/compare.py b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/compare.py new file mode 100644 index 0000000000..2edd0c0ccd --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path): + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + mismatch = np.flatnonzero(golden != output) + idx = int(mismatch[0]) + print( + f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} " + f"out={int(output[idx])} mismatch_count={mismatch.size}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py new file mode 100644 index 0000000000..3c8747961f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). 
+# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +TOTAL_BYTES = 4096 +SEED = 41 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + values = rng.uniform(-32.0, 32.0, size=(TOTAL_BYTES // 4,)).astype(np.float32) + data = values.view(np.uint8) + golden = np.zeros((TOTAL_BYTES,), dtype=np.uint8) + + golden[0:512] = data[0:512] + golden[1024:1280] = data[1024:1280] + # The unaligned store chain uses an independent vector address so the + # generated code does not need to copy one address register into another. + golden[2048:2304] = data[2052:2308] + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros((TOTAL_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for VPTO vector-address x2/predicate/unaligned case." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto new file mode 100644 index 0000000000..6cca419674 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vector_address_x2_predicate_unaligned(%src: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + + %ub_in_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_in_f32 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_f32 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_pred = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src, %ub_in_u8, %c0_i64, %c256_i64 + nburst(%c16_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst, %ub_out_u8, %c0_i64, %c256_i64 + nburst(%c16_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + 
pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + + %low, %high = pto.valdx2 %ub_in_f32[%addr], "DINTLV_B32" + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vastx2 %low, %high, %ub_out_f32[%addr], "INTLV_B32", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + + pto.past %mask, %ub_pred[%addr], "NORM" + : !pto.mask, !pto.ptr, !pto.vaddr + pto.mem_bar "VST_VLD" + %loaded_mask = pto.pald %ub_pred[%addr], "NORM" + : !pto.ptr, !pto.vaddr -> !pto.mask + %pred_src = pto.addptr %ub_in_f32, %c256 : !pto.ptr -> !pto.ptr + %pred_dst = pto.addptr %ub_out_f32, %c256 : !pto.ptr -> !pto.ptr + %pred_value = pto.vald %pred_src[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %pred_value, %pred_dst[%addr], %loaded_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + + %unaligned_base = pto.addptr %ub_in_f32, %c512 : !pto.ptr -> !pto.ptr + %unaligned_src = pto.addptr %unaligned_base, %c1 : !pto.ptr -> !pto.ptr + %unaligned_dst = pto.addptr %ub_out_f32, %c512 : !pto.ptr -> !pto.ptr + %load_align = pto.valda %unaligned_src[%addr] + : !pto.ptr, !pto.vaddr -> !pto.align + %unaligned, %next_load_align, %next_load_addr = pto.valdu %unaligned_src[%addr], %load_align, %c256_i32 + : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr + %store_align = pto.init_align : !pto.align + %store_addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %next_store_align, %next_store_addr = pto.vastu %store_align, %store_addr, %unaligned, %unaligned_dst, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr -> !pto.align, !pto.vaddr + pto.vasta %next_store_align, %unaligned_dst[%next_store_addr] + : !pto.align, !pto.ptr, !pto.vaddr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] 
+ pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.mte_ub_gm %ub_out_u8, %dst, %c256_i64 + nburst(%c16_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp new file mode 100644 index 0000000000..b5fa8db339 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vector_address_x2_predicate_unaligned(__gm__ uint8_t *src, __gm__ uint8_t *dst); + +void LaunchVectorAddressX2PredicateUnaligned(uint8_t *src, uint8_t *dst, + void *stream) { + vector_address_x2_predicate_unaligned<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp new file mode 100644 index 0000000000..6be66039c5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVectorAddressX2PredicateUnaligned(uint8_t *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kBytes = 4096; + + uint8_t *srcHost = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *dstDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t srcFileSize = kBytes; + size_t dstFileSize = kBytes; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), kBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), kBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + if (!ReadFile("./v1.bin", srcFileSize, srcHost, kBytes) || + srcFileSize != kBytes) { + std::fprintf(stderr, "[ERROR] failed to read v1.bin\n"); + rc = 1; + goto cleanup; + } + if (!ReadFile("./v2.bin", dstFileSize, dstHost, kBytes) || + dstFileSize != kBytes) { + std::fprintf(stderr, "[ERROR] failed to read v2.bin\n"); + rc = 1; + goto cleanup; + } + ACL_CHECK(aclrtMemcpy(srcDevice, kBytes, srcHost, kBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, 
kBytes, dstHost, kBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVectorAddressX2PredicateUnaligned(srcDevice, dstDevice, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, kBytes, dstDevice, kBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, kBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} From c96da94ddddf298b54cd9a3d07ff268d69e9604a Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sat, 27 Jun 2026 14:29:19 +0800 Subject: [PATCH 2/4] Consolidate vector address SIM coverage --- .../compare.py | 0 .../golden.py | 25 +-- .../multidim-vald-vast/kernel.pto | 165 ++++++++++++++++++ .../launch.cpp | 9 +- .../main.cpp | 39 +++-- .../vector-address/vald-vast/compare.py | 50 ------ .../vector-address/vald-vast/golden.py | 42 ----- .../vector-address/vald-vast/kernel.pto | 50 ------ .../vector-address/vald-vast/main.cpp | 99 ----------- .../x2-predicate-unaligned/kernel.pto | 92 ---------- .../x2-predicate-unaligned/launch.cpp | 35 ---- 11 files changed, 204 insertions(+), 402 deletions(-) rename test/vpto/cases/micro-op/vector-address/{x2-predicate-unaligned => multidim-vald-vast}/compare.py (100%) rename test/vpto/cases/micro-op/vector-address/{x2-predicate-unaligned => multidim-vald-vast}/golden.py (67%) create mode 100644 test/vpto/cases/micro-op/vector-address/multidim-vald-vast/kernel.pto rename test/vpto/cases/micro-op/vector-address/{vald-vast => multidim-vald-vast}/launch.cpp (82%) rename test/vpto/cases/micro-op/vector-address/{x2-predicate-unaligned => multidim-vald-vast}/main.cpp (73%) delete mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/compare.py delete mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/golden.py delete mode 100644 
test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto delete mode 100644 test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp delete mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto delete mode 100644 test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/compare.py b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/compare.py similarity index 100% rename from test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/compare.py rename to test/vpto/cases/micro-op/vector-address/multidim-vald-vast/compare.py diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/golden.py similarity index 67% rename from test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py rename to test/vpto/cases/micro-op/vector-address/multidim-vald-vast/golden.py index 3c8747961f..65c5a4dda4 100644 --- a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/golden.py +++ b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/golden.py @@ -13,31 +13,34 @@ import numpy as np -TOTAL_BYTES = 4096 -SEED = 41 +SOURCE_BYTES = 4096 +OUTPUT_BYTES = 8192 +SEED = 43 def generate(output_dir: Path, seed: int) -> None: rng = np.random.default_rng(seed) - values = rng.uniform(-32.0, 32.0, size=(TOTAL_BYTES // 4,)).astype(np.float32) + values = rng.uniform(-32.0, 32.0, size=(SOURCE_BYTES // 4,)).astype(np.float32) data = values.view(np.uint8) - golden = np.zeros((TOTAL_BYTES,), dtype=np.uint8) + golden = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) - golden[0:512] = data[0:512] - golden[1024:1280] = data[1024:1280] - # The unaligned store chain uses an independent vector address so the - # generated code does not need to copy one address register into another. 
- golden[2048:2304] = data[2052:2308] + golden[0:1024] = data[0:1024] + golden[1024:2048] = data[0:1024] + + x2_base = 2048 + golden[x2_base : x2_base + 512] = data[0:512] + golden[x2_base + 1024 : x2_base + 1280] = data[1024:1280] + golden[x2_base + 2048 : x2_base + 2304] = data[2052:2308] output_dir.mkdir(parents=True, exist_ok=True) data.tofile(output_dir / "v1.bin") - np.zeros((TOTAL_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") golden.tofile(output_dir / "golden_v2.bin") def main() -> None: parser = argparse.ArgumentParser( - description="Generate inputs/golden for VPTO vector-address x2/predicate/unaligned case." + description="Generate inputs/golden for VPTO combined vector-address memory case." ) parser.add_argument("--output-dir", type=Path, default=Path(".")) parser.add_argument("--seed", type=int, default=SEED) diff --git a/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/kernel.pto b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/kernel.pto new file mode 100644 index 0000000000..7397687595 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/kernel.pto @@ -0,0 +1,165 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vector_address_multidim_vald_vast(%src: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c24576_i64 = arith.constant 24576 : i64 + %c4_i32 = arith.constant 4 : i32 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c96_i32 = arith.constant 96 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + + %ub_in_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_in_f32 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_f32 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_pred = pto.castptr %c24576_i64 : i64 -> !pto.ptr + + %ub_src_1 = pto.addptr %ub_in_f32, %c64 : !pto.ptr -> !pto.ptr + %ub_src_2 = pto.addptr %ub_in_f32, %c128 : !pto.ptr -> !pto.ptr + %ub_src_3 = pto.addptr %ub_in_f32, %c192 : !pto.ptr -> !pto.ptr + + %ub_1d_1 = pto.addptr %ub_out_f32, %c64 : !pto.ptr -> !pto.ptr + %ub_1d_2 = pto.addptr %ub_out_f32, %c128 : !pto.ptr -> !pto.ptr + %ub_1d_3 = pto.addptr %ub_out_f32, %c192 : !pto.ptr -> !pto.ptr + + %ub_4d = pto.addptr %ub_out_f32, %c256 : !pto.ptr -> !pto.ptr + %ub_4d_1 = pto.addptr %ub_4d, %c64 : !pto.ptr -> !pto.ptr + %ub_4d_2 = pto.addptr %ub_4d, %c128 : !pto.ptr -> !pto.ptr + %ub_4d_3 = pto.addptr %ub_4d, %c192 : !pto.ptr -> !pto.ptr + + 
%ub_x2_u8 = pto.addptr %ub_out_u8, %c2048 : !pto.ptr -> !pto.ptr + %ub_x2_f32 = pto.addptr %ub_out_f32, %c512 : !pto.ptr -> !pto.ptr + + pto.mte_gm_ub %src, %ub_in_u8, %c0_i64, %c256_i64 + nburst(%c16_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst, %ub_out_u8, %c0_i64, %c256_i64 + nburst(%c32_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c32_i32 : i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %value_0 = pto.vald %ub_in_f32[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_0, %ub_out_f32[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %value_1 = pto.vald %ub_src_1[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_1, %ub_1d_1[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %value_2 = pto.vald %ub_src_2[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_2, %ub_1d_2[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %value_3 = pto.vald %ub_src_3[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_3, %ub_1d_3[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + } + + scf.for %k = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + scf.for %l = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + scf.for %m = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + scf.for %n = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c32_i32, %c64_i32, %c96_i32, %c128_i32 + : i32, i32, i32, i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %value_0 = pto.vald %ub_in_f32[%addr] {dist = "NORM"} + : 
!pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_0, %ub_4d[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %value_1 = pto.vald %ub_src_1[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_1, %ub_4d_1[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %value_2 = pto.vald %ub_src_2[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_2, %ub_4d_2[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + %value_3 = pto.vald %ub_src_3[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %value_3, %ub_4d_3[%addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + } + } + } + } + + scf.for %j = %c0_i16 to %c2_i16 step %c1_i16 : i16 { + %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + + %low, %high = pto.valdx2 %ub_in_f32[%addr], "DINTLV_B32" + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vastx2 %low, %high, %ub_x2_f32[%addr], "INTLV_B32", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + + pto.past %mask, %ub_pred[%addr], "NORM" + : !pto.mask, !pto.ptr, !pto.vaddr + pto.mem_bar "VST_VLD" + %loaded_mask = pto.pald %ub_pred[%addr], "NORM" + : !pto.ptr, !pto.vaddr -> !pto.mask + %pred_src = pto.addptr %ub_in_f32, %c256 : !pto.ptr -> !pto.ptr + %pred_dst = pto.addptr %ub_x2_f32, %c256 : !pto.ptr -> !pto.ptr + %pred_value = pto.vald %pred_src[%addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %pred_value, %pred_dst[%addr], %loaded_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + + %unaligned_base = pto.addptr %ub_in_f32, %c512 : !pto.ptr -> !pto.ptr + %unaligned_src = pto.addptr %unaligned_base, %c1 : !pto.ptr -> !pto.ptr + %unaligned_dst = pto.addptr %ub_x2_f32, 
%c512 : !pto.ptr -> !pto.ptr + %load_align = pto.valda %unaligned_src[%addr] + : !pto.ptr, !pto.vaddr -> !pto.align + %unaligned, %next_load_align, %next_load_addr = pto.valdu %unaligned_src[%addr], %load_align, %c256_i32 + : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr + %store_align = pto.init_align : !pto.align + %store_addr = pto.vag %c4_i32 : i32 -> !pto.vaddr + %next_store_align, %next_store_addr = pto.vastu %store_align, %store_addr, %unaligned, %unaligned_dst, "POST_UPDATE" + : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr -> !pto.align, !pto.vaddr + pto.vasta %next_store_align, %unaligned_dst[%next_store_addr] + : !pto.align, !pto.ptr, !pto.vaddr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.mte_ub_gm %ub_out_u8, %dst, %c256_i64 + nburst(%c32_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/launch.cpp similarity index 82% rename from test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp rename to test/vpto/cases/micro-op/vector-address/multidim-vald-vast/launch.cpp index 22f96fdcd4..cf9d132c59 100644 --- a/test/vpto/cases/micro-op/vector-address/vald-vast/launch.cpp +++ b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/launch.cpp @@ -39,9 +39,10 @@ struct MrgSortExecutedNumList { #endif extern "C" __global__ [aicore] void -vector_address_vald_vast(__gm__ float *src, __gm__ float *dst); +vector_address_multidim_vald_vast(__gm__ uint8_t *src, __gm__ uint8_t *dst); -void LaunchVectorAddressValdVast(float *src, float *dst, void *stream) { - vector_address_vald_vast<<<1, nullptr, stream>>>((__gm__ float *)src, - (__gm__ float *)dst); +void LaunchVectorAddressMultidimValdVast(uint8_t *src, uint8_t *dst, + void *stream) { + 
vector_address_multidim_vald_vast<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst); } diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/main.cpp similarity index 73% rename from test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp rename to test/vpto/cases/micro-op/vector-address/multidim-vald-vast/main.cpp index 6be66039c5..cf0032da09 100644 --- a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/main.cpp +++ b/test/vpto/cases/micro-op/vector-address/multidim-vald-vast/main.cpp @@ -9,9 +9,9 @@ #include "acl/acl.h" #include "test_common.h" +#include #include #include -#include using namespace PtoTestCommon; @@ -29,11 +29,12 @@ using namespace PtoTestCommon; } \ } while (0) -void LaunchVectorAddressX2PredicateUnaligned(uint8_t *src, uint8_t *dst, - void *stream); +void LaunchVectorAddressMultidimValdVast(uint8_t *src, uint8_t *dst, + void *stream); int main() { - constexpr size_t kBytes = 4096; + constexpr size_t kSrcBytes = 4096; + constexpr size_t kDstBytes = 8192; uint8_t *srcHost = nullptr; uint8_t *dstHost = nullptr; @@ -45,8 +46,8 @@ int main() { bool deviceSet = false; int deviceId = 0; aclrtStream stream = nullptr; - size_t srcFileSize = kBytes; - size_t dstFileSize = kBytes; + size_t srcFileSize = kSrcBytes; + size_t dstFileSize = kDstBytes; ACL_CHECK(aclInit(nullptr)); aclInited = true; @@ -56,34 +57,34 @@ int main() { deviceSet = true; ACL_CHECK(aclrtCreateStream(&stream)); - ACL_CHECK(aclrtMallocHost((void **)(&srcHost), kBytes)); - ACL_CHECK(aclrtMallocHost((void **)(&dstHost), kBytes)); - ACL_CHECK(aclrtMalloc((void **)&srcDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)&dstDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), kSrcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), kDstBytes)); + ACL_CHECK(aclrtMalloc((void 
**)&srcDevice, kSrcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, kDstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - if (!ReadFile("./v1.bin", srcFileSize, srcHost, kBytes) || - srcFileSize != kBytes) { + if (!ReadFile("./v1.bin", srcFileSize, srcHost, kSrcBytes) || + srcFileSize != kSrcBytes) { std::fprintf(stderr, "[ERROR] failed to read v1.bin\n"); rc = 1; goto cleanup; } - if (!ReadFile("./v2.bin", dstFileSize, dstHost, kBytes) || - dstFileSize != kBytes) { + if (!ReadFile("./v2.bin", dstFileSize, dstHost, kDstBytes) || + dstFileSize != kDstBytes) { std::fprintf(stderr, "[ERROR] failed to read v2.bin\n"); rc = 1; goto cleanup; } - ACL_CHECK(aclrtMemcpy(srcDevice, kBytes, srcHost, kBytes, + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcBytes, srcHost, kSrcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); - ACL_CHECK(aclrtMemcpy(dstDevice, kBytes, dstHost, kBytes, + ACL_CHECK(aclrtMemcpy(dstDevice, kDstBytes, dstHost, kDstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); - LaunchVectorAddressX2PredicateUnaligned(srcDevice, dstDevice, stream); + LaunchVectorAddressMultidimValdVast(srcDevice, dstDevice, stream); ACL_CHECK(aclrtSynchronizeStream(stream)); - ACL_CHECK(aclrtMemcpy(dstHost, kBytes, dstDevice, kBytes, + ACL_CHECK(aclrtMemcpy(dstHost, kDstBytes, dstDevice, kDstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile("./v2.bin", dstHost, kBytes); + WriteFile("./v2.bin", dstHost, kDstBytes); cleanup: aclrtFree(srcDevice); diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/compare.py b/test/vpto/cases/micro-op/vector-address/vald-vast/compare.py deleted file mode 100644 index 9c4c289c9a..0000000000 --- a/test/vpto/cases/micro-op/vector-address/vald-vast/compare.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). 
-# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import os -import sys - -import numpy as np - - -def compare_bin(golden_path, output_path, dtype, eps): - if not os.path.exists(golden_path): - print(f"[ERROR] Golden missing: {golden_path}") - return False - if not os.path.exists(output_path): - print(f"[ERROR] Output missing: {output_path}") - return False - - golden = np.fromfile(golden_path, dtype=dtype) - output = np.fromfile(output_path, dtype=dtype) - if golden.shape != output.shape: - print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") - return False - if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): - diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) - idx = int(np.argmax(diff)) - print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") - return False - return True - - -def main(): - strict = os.getenv("COMPARE_STRICT", "1") != "0" - ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-6) - if not ok: - if strict: - print("[ERROR] compare failed") - sys.exit(2) - print("[WARN] compare failed (non-gating)") - return - print("[INFO] compare passed") - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/golden.py b/test/vpto/cases/micro-op/vector-address/vald-vast/golden.py deleted file mode 100644 index 0d3a1d5f7d..0000000000 --- a/test/vpto/cases/micro-op/vector-address/vald-vast/golden.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
-# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import argparse -from pathlib import Path - -import numpy as np - - -ELEMENTS = 64 -SEED = 29 - - -def generate(output_dir: Path, seed: int) -> None: - rng = np.random.default_rng(seed) - v1 = rng.uniform(-16.0, 16.0, size=(ELEMENTS,)).astype(np.float32) - v2 = np.zeros((ELEMENTS,), dtype=np.float32) - - output_dir.mkdir(parents=True, exist_ok=True) - v1.tofile(output_dir / "v1.bin") - v2.tofile(output_dir / "v2.bin") - v1.tofile(output_dir / "golden_v2.bin") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Generate inputs/golden for VPTO vector-address vald/vast." - ) - parser.add_argument("--output-dir", type=Path, default=Path(".")) - parser.add_argument("--seed", type=int, default=SEED) - args = parser.parse_args() - generate(args.output_dir, args.seed) - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto b/test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto deleted file mode 100644 index 0f0b741094..0000000000 --- a/test/vpto/cases/micro-op/vector-address/vald-vast/kernel.pto +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. 
You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @vector_address_vald_vast(%src: !pto.ptr, - %dst: !pto.ptr) attributes {pto.kernel} { - %c0_i64 = arith.constant 0 : i64 - %c1_i64 = arith.constant 1 : i64 - %c256_i64 = arith.constant 256 : i64 - %c0_i16 = arith.constant 0 : i16 - %c1_i16 = arith.constant 1 : i16 - %c2_i16 = arith.constant 2 : i16 - %c4_i32 = arith.constant 4 : i32 - - %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr - %ub_dst = pto.castptr %c256_i64 : i64 -> !pto.ptr - - pto.mte_gm_ub %src, %ub_src, %c0_i64, %c256_i64 - nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - - pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] - pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] - - pto.vecscope { - scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { - %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr - %mask = pto.pset_b32 "PAT_ALL" : !pto.mask - %value = pto.vald %ub_src[%addr] {dist = "NORM"} - : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> - pto.vast %value, %ub_dst[%addr], %mask {dist = "NORM_B32"} - : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask - } - } - - pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.mte_ub_gm %ub_dst, %dst, %c256_i64 - nburst(%c1_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64 - pto.barrier #pto.pipe - return - } -} diff --git a/test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp b/test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp deleted file mode 100644 index 2e8fd28e8e..0000000000 --- 
a/test/vpto/cases/micro-op/vector-address/vald-vast/main.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include "acl/acl.h" -#include "test_common.h" - -#include -#include - -using namespace PtoTestCommon; - -#define ACL_CHECK(expr) \ - do { \ - const aclError _ret = (expr); \ - if (_ret != ACL_SUCCESS) { \ - std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ - (int)_ret, __FILE__, __LINE__); \ - const char *_recent = aclGetRecentErrMsg(); \ - if (_recent != nullptr && _recent[0] != '\0') \ - std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ - rc = 1; \ - goto cleanup; \ - } \ - } while (0) - -void LaunchVectorAddressValdVast(float *src, float *dst, void *stream); - -int main() { - constexpr size_t kElements = 64; - constexpr size_t kBytes = kElements * sizeof(float); - - float *srcHost = nullptr; - float *dstHost = nullptr; - float *srcDevice = nullptr; - float *dstDevice = nullptr; - - int rc = 0; - bool aclInited = false; - bool deviceSet = false; - int deviceId = 0; - aclrtStream stream = nullptr; - size_t srcFileSize = kBytes; - size_t dstFileSize = kBytes; - - ACL_CHECK(aclInit(nullptr)); - aclInited = true; - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) - deviceId = std::atoi(envDevice); - ACL_CHECK(aclrtSetDevice(deviceId)); - deviceSet = true; - ACL_CHECK(aclrtCreateStream(&stream)); - - 
ACL_CHECK(aclrtMallocHost((void **)(&srcHost), kBytes)); - ACL_CHECK(aclrtMallocHost((void **)(&dstHost), kBytes)); - ACL_CHECK(aclrtMalloc((void **)&srcDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)&dstDevice, kBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - - if (!ReadFile("./v1.bin", srcFileSize, srcHost, kBytes) || - srcFileSize != kBytes) { - std::fprintf(stderr, "[ERROR] failed to read v1.bin\n"); - rc = 1; - goto cleanup; - } - if (!ReadFile("./v2.bin", dstFileSize, dstHost, kBytes) || - dstFileSize != kBytes) { - std::fprintf(stderr, "[ERROR] failed to read v2.bin\n"); - rc = 1; - goto cleanup; - } - ACL_CHECK(aclrtMemcpy(srcDevice, kBytes, srcHost, kBytes, - ACL_MEMCPY_HOST_TO_DEVICE)); - ACL_CHECK(aclrtMemcpy(dstDevice, kBytes, dstHost, kBytes, - ACL_MEMCPY_HOST_TO_DEVICE)); - - LaunchVectorAddressValdVast(srcDevice, dstDevice, stream); - - ACL_CHECK(aclrtSynchronizeStream(stream)); - ACL_CHECK(aclrtMemcpy(dstHost, kBytes, dstDevice, kBytes, - ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile("./v2.bin", dstHost, kBytes); - -cleanup: - aclrtFree(srcDevice); - aclrtFree(dstDevice); - aclrtFreeHost(srcHost); - aclrtFreeHost(dstHost); - if (stream != nullptr) - aclrtDestroyStream(stream); - if (deviceSet) - aclrtResetDevice(deviceId); - if (aclInited) - aclFinalize(); - return rc; -} diff --git a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto deleted file mode 100644 index 6cca419674..0000000000 --- a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/kernel.pto +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. 
-// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @vector_address_x2_predicate_unaligned(%src: !pto.ptr, - %dst: !pto.ptr) attributes {pto.kernel} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c256 = arith.constant 256 : index - %c512 = arith.constant 512 : index - %c0_i16 = arith.constant 0 : i16 - %c1_i16 = arith.constant 1 : i16 - %c2_i16 = arith.constant 2 : i16 - %c0_i64 = arith.constant 0 : i64 - %c1_i64 = arith.constant 1 : i64 - %c16_i64 = arith.constant 16 : i64 - %c256_i64 = arith.constant 256 : i64 - %c4096_i64 = arith.constant 4096 : i64 - %c8192_i64 = arith.constant 8192 : i64 - %c4_i32 = arith.constant 4 : i32 - %c256_i32 = arith.constant 256 : i32 - - %ub_in_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr - %ub_out_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr - %ub_in_f32 = pto.castptr %c0_i64 : i64 -> !pto.ptr - %ub_out_f32 = pto.castptr %c4096_i64 : i64 -> !pto.ptr - %ub_pred = pto.castptr %c8192_i64 : i64 -> !pto.ptr - - pto.mte_gm_ub %src, %ub_in_u8, %c0_i64, %c256_i64 - nburst(%c16_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - pto.mte_gm_ub %dst, %ub_out_u8, %c0_i64, %c256_i64 - nburst(%c16_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - - pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] - pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] - - pto.vecscope { - scf.for %i = %c0_i16 to %c2_i16 step %c1_i16 : i16 { - %addr = pto.vag %c4_i32 : i32 -> !pto.vaddr - %mask = pto.pset_b32 "PAT_ALL" : !pto.mask - - %low, %high = pto.valdx2 %ub_in_f32[%addr], "DINTLV_B32" - : !pto.ptr, !pto.vaddr -> 
!pto.vreg<64xf32>, !pto.vreg<64xf32> - pto.vastx2 %low, %high, %ub_out_f32[%addr], "INTLV_B32", %mask - : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask - - pto.past %mask, %ub_pred[%addr], "NORM" - : !pto.mask, !pto.ptr, !pto.vaddr - pto.mem_bar "VST_VLD" - %loaded_mask = pto.pald %ub_pred[%addr], "NORM" - : !pto.ptr, !pto.vaddr -> !pto.mask - %pred_src = pto.addptr %ub_in_f32, %c256 : !pto.ptr -> !pto.ptr - %pred_dst = pto.addptr %ub_out_f32, %c256 : !pto.ptr -> !pto.ptr - %pred_value = pto.vald %pred_src[%addr] {dist = "NORM"} - : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> - pto.vast %pred_value, %pred_dst[%addr], %loaded_mask {dist = "NORM_B32"} - : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask - - %unaligned_base = pto.addptr %ub_in_f32, %c512 : !pto.ptr -> !pto.ptr - %unaligned_src = pto.addptr %unaligned_base, %c1 : !pto.ptr -> !pto.ptr - %unaligned_dst = pto.addptr %ub_out_f32, %c512 : !pto.ptr -> !pto.ptr - %load_align = pto.valda %unaligned_src[%addr] - : !pto.ptr, !pto.vaddr -> !pto.align - %unaligned, %next_load_align, %next_load_addr = pto.valdu %unaligned_src[%addr], %load_align, %c256_i32 - : !pto.ptr, !pto.vaddr, !pto.align, i32 -> !pto.vreg<64xf32>, !pto.align, !pto.vaddr - %store_align = pto.init_align : !pto.align - %store_addr = pto.vag %c4_i32 : i32 -> !pto.vaddr - %next_store_align, %next_store_addr = pto.vastu %store_align, %store_addr, %unaligned, %unaligned_dst, "POST_UPDATE" - : !pto.align, !pto.vaddr, !pto.vreg<64xf32>, !pto.ptr -> !pto.align, !pto.vaddr - pto.vasta %next_store_align, %unaligned_dst[%next_store_addr] - : !pto.align, !pto.ptr, !pto.vaddr - } - } - - pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - - pto.mte_ub_gm %ub_out_u8, %dst, %c256_i64 - nburst(%c16_i64, %c256_i64, %c256_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64 - pto.barrier #pto.pipe - return - } -} diff --git 
a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp b/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp deleted file mode 100644 index b5fa8db339..0000000000 --- a/test/vpto/cases/micro-op/vector-address/x2-predicate-unaligned/launch.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#ifndef __VEC_SCOPE__ -#define __VEC_SCOPE__ -#endif - -#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) -typedef struct { unsigned char v; } hifloat8_t; -typedef struct { unsigned char v; } float8_e4m3_t; -typedef struct { unsigned char v; } float8_e5m2_t; -typedef struct { unsigned char v; } float8_e8m0_t; -typedef struct { unsigned char v; } float4_e1m2x2_t; -typedef struct { unsigned char v; } float4_e2m1x2_t; -#endif - -#include - -#ifndef __CPU_SIM -#include "acl/acl.h" -#endif - -extern "C" __global__ [aicore] void -vector_address_x2_predicate_unaligned(__gm__ uint8_t *src, __gm__ uint8_t *dst); - -void LaunchVectorAddressX2PredicateUnaligned(uint8_t *src, uint8_t *dst, - void *stream) { - vector_address_x2_predicate_unaligned<<<1, nullptr, stream>>>( - (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst); -} From 97fd98f41f033d99f927c0764316765896d77bcd Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Sat, 27 Jun 2026 16:52:55 +0800 Subject: [PATCH 3/4] Clarify vaddr stride ordering --- 
docs/isa/micro-isa/18-vaddr-loop-memory.md | 40 +++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/docs/isa/micro-isa/18-vaddr-loop-memory.md b/docs/isa/micro-isa/18-vaddr-loop-memory.md index d7b2c292c0..dc1a75bf86 100644 --- a/docs/isa/micro-isa/18-vaddr-loop-memory.md +++ b/docs/isa/micro-isa/18-vaddr-loop-memory.md @@ -70,6 +70,11 @@ scf.for %i = %c0 to %c2 step %c1 { `pto.vag` does not create or outline a loop. The surrounding loop must already exist in the PTO IR. +For nested loops, the active loop stack is determined at the lexical position of +the `pto.vag` op. The immediately enclosing `scf.for` is the innermost active +loop for that `pto.vag`; each enclosing `scf.for` outside it is the next outer +active loop. + --- ## Address Generation @@ -82,13 +87,31 @@ exist in the PTO IR. - **syntax:** `%addr = pto.vag %s0, %s1, %s2, %s3 : i32, i32, i32, i32 -> !pto.vaddr` - **semantics:** Create a vector-address offset value for the surrounding loop. - **inputs:** - `%s0` ... `%s3` are byte strides for active loop dimensions. + `%s0` ... `%s3` are byte strides for active loop dimensions. The operands are + ordered from inner loop to outer loop: + `%s0` applies to the immediately enclosing loop, `%s1` applies to the next + outer loop, `%s2` applies to the next outer loop after that, and `%s3` applies + to the fourth active loop. - **outputs:** `%addr` is a `!pto.vaddr` offset token. - **constraints and limitations:** `pto.vag` takes one to four `i32` byte-stride operands. It MUST be nested in an `i16` `scf.for`. The result granularity MUST be `b8`, `b16`, or `b32`. +For a `pto.vag` nested under active loop counters `i0, i1, i2, i3`, listed from +inner to outer at the `pto.vag` location, the logical offset is: + +```text +addr = i0 * s0 + + i1 * s1 + + i2 * s2 + + i3 * s3 +``` + +If fewer than four stride operands are present, only the corresponding innermost +active loop counters participate in the address. 
For example, `pto.vag %s0, +%s1` uses the immediately enclosing loop and its next outer loop. + **Example:** ```mlir @@ -96,6 +119,21 @@ exist in the PTO IR. %addr = pto.vag %stride : i32 -> !pto.vaddr ``` +Four nested loops: + +```mlir +scf.for %k = %c0_i16 to %k_bound step %c1_i16 : i16 { + scf.for %l = %c0_i16 to %l_bound step %c1_i16 : i16 { + scf.for %m = %c0_i16 to %m_bound step %c1_i16 : i16 { + scf.for %n = %c0_i16 to %n_bound step %c1_i16 : i16 { + %addr = pto.vag %n_stride, %m_stride, %l_stride, %k_stride + : i32, i32, i32, i32 -> !pto.vaddr + } + } + } +} +``` + --- ## Vector-Address Vector Loads From fb23300d22c710b6a25f2b9733de7730f3d70be5 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Tue, 30 Jun 2026 20:03:40 +0800 Subject: [PATCH 4/4] Add structured vaddr loop op --- docs/isa/micro-isa/18-vaddr-loop-memory.md | 62 +++++++++- include/PTO/IR/VPTOOps.td | 23 ++++ lib/PTO/IR/VPTO.cpp | 117 ++++++++++++++++++ lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 97 ++++++++++++++- ...ctor_address_vaddr_loop_verify_invalid.pto | 24 ++++ .../vector_address_vaddr_loop_vpto_llvm.pto | 62 ++++++++++ .../vector_address_vaddr_verify_invalid.pto | 24 ++++ 7 files changed, 406 insertions(+), 3 deletions(-) create mode 100644 test/lit/vpto/vector_address_vaddr_loop_verify_invalid.pto create mode 100644 test/lit/vpto/vector_address_vaddr_loop_vpto_llvm.pto create mode 100644 test/lit/vpto/vector_address_vaddr_verify_invalid.pto diff --git a/docs/isa/micro-isa/18-vaddr-loop-memory.md b/docs/isa/micro-isa/18-vaddr-loop-memory.md index dc1a75bf86..d8f875086e 100644 --- a/docs/isa/micro-isa/18-vaddr-loop-memory.md +++ b/docs/isa/micro-isa/18-vaddr-loop-memory.md @@ -9,7 +9,8 @@ progression that is tied to the loop counter rather than to scalar pointer arithmetic. This chapter documents the PTO surface contract for `!pto.vaddr` and the -`pto.vag` / `pto.vald` / `pto.vast` / `pto.pald` / `pto.past` families. 
+`pto.vaddr_loop` / `pto.vag` / `pto.vald` / `pto.vast` / `pto.pald` / +`pto.past` families. --- @@ -77,6 +78,65 @@ active loop. --- +## Structured Vector-Address Loops + +### `pto.vaddr_loop` + +- **syntax:** `pto.vaddr_loop bounds(%ub0[, %ub1[, %ub2[, %ub3]]]) vaddr(%addr = strides(%s0[, %s1[, %s2[, %s3]]]) : !pto.vaddr[, ...]) { ... }` +- **semantics:** Execute the body in one to four canonical nested + vector-address loops. Each vector-address value declared by the `vaddr(...)` + header is available by name inside the loop body. +- **inputs:** + `%ub0` ... `%ub3` are `i16` upper bounds ordered from outer loop to inner + loop. Each `strides(...)` list contains `i32` byte strides listed in the same + order as the bounds: outer loop to inner loop. This is the reverse of `pto.vag`, whose stride operands are ordered from inner loop to outer loop. +- **outputs:** + Each `%addr` declared in the header is a `!pto.vaddr` offset token scoped + to the loop body. +- **constraints and limitations:** + `pto.vaddr_loop` takes one to four `i16` upper bound operands. Each + `strides(...)` list MUST have exactly the same number of operands as + `bounds(...)`. The loop lower bound is `0` and the step is `1`. Use a zero + stride for a loop dimension that should not contribute to an address. + +For a four-dimensional loop: + +```mlir +pto.vaddr_loop bounds(%K, %L, %M, %N) + vaddr(%addr = strides(%k_stride, %l_stride, %m_stride, %n_stride) + : !pto.vaddr<b32>) { +} +``` + +the logical offset, where `k`, `l`, `m`, and `n` are the loop counters from outer to inner, is: + +```text +addr = k * k_stride + + l * l_stride + + m * m_stride + + n * n_stride +``` + +Multiple vector addresses may be declared in the same loop.
Each one carries its +own stride list: + +```mlir +pto.vaddr_loop bounds(%K, %L, %M, %N) + vaddr(%src_addr = strides(%src_k, %src_l, %src_m, %src_n) + : !pto.vaddr<b32>, + %dst_addr = strides(%dst_k, %dst_l, %dst_m, %dst_n) + : !pto.vaddr<b32>) { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask<b32> + + %value = pto.vald %src[%src_addr] {dist = "NORM"} + : !pto.ptr<f32, ub>, !pto.vaddr<b32> -> !pto.vreg<64xf32> + pto.vast %value, %dst[%dst_addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr<f32, ub>, !pto.vaddr<b32>, !pto.mask<b32> +} +``` + +--- + ## Address Generation ### `pto.vag` diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td index 14d34d67bf..85f48f52ff 100644 --- a/include/PTO/IR/VPTOOps.td +++ b/include/PTO/IR/VPTOOps.td @@ -1361,6 +1361,29 @@ def PTO_VagOp : PTO_Op<"vag"> { }]; } +def PTO_VAddrLoopOp : PTO_Op<"vaddr_loop", [ + SingleBlock, + NoTerminator, + AttrSizedOperandSegments + ]> { + let summary = "Structured vector-address loop"; + let description = [{ + `pto.vaddr_loop` represents one to four canonical nested vector-address + loops. The op carries only upper bounds. Each vector-address value declared + in the header has one byte stride per loop dimension and is made available + as a body region argument.
+ }]; + + let arguments = (ins + Variadic:$bounds, + Variadic:$strides + ); + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def PTO_ValdOp : PTO_Op<"vald", [ DeclareOpInterfaceMethods ]> { diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 02a533ef94..565f759c34 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -4732,6 +4732,123 @@ LogicalResult VagOp::verify() { return success(); } +ParseResult VAddrLoopOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector bounds; + SmallVector strides; + SmallVector regionArgs; + + if (parser.parseKeyword("bounds") || parser.parseLParen() || + parser.parseOperandList(bounds) || parser.parseRParen()) + return failure(); + if (bounds.empty() || bounds.size() > 4) + return parser.emitError(parser.getCurrentLocation(), + "requires one to four upper bound operands"); + + if (parser.parseKeyword("vaddr") || parser.parseLParen()) + return failure(); + + if (succeeded(parser.parseOptionalRParen())) + return parser.emitError(parser.getCurrentLocation(), + "requires at least one vaddr declaration"); + + do { + OpAsmParser::Argument arg; + SmallVector group; + Type addrType; + if (parser.parseOperand(arg.ssaName) || parser.parseEqual() || + parser.parseKeyword("strides") || parser.parseLParen() || + parser.parseOperandList(group) || parser.parseRParen() || + parser.parseColonType(addrType)) + return failure(); + if (group.size() != bounds.size()) { + return parser.emitError(parser.getCurrentLocation()) + << "requires each vaddr stride list to contain " << bounds.size() + << " operands"; + } + arg.type = addrType; + regionArgs.push_back(arg); + llvm::append_range(strides, group); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) + return failure(); + + Type i16Type = parser.getBuilder().getIntegerType(16); + Type i32Type = parser.getBuilder().getIntegerType(32); + SmallVector boundTypes(bounds.size(), 
i16Type); + SmallVector strideTypes(strides.size(), i32Type); + if (parser.resolveOperands(bounds, boundTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(strides, strideTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{static_cast(bounds.size()), + static_cast(strides.size())}, + segments.begin()); + + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +} + +void VAddrLoopOp::print(OpAsmPrinter &printer) { + printer << " bounds("; + printer.printOperands(getBounds()); + printer << ") vaddr("; + + unsigned rank = getBounds().size(); + ValueRange strides = getStrides(); + Block &bodyBlock = getBody().front(); + for (auto [index, arg] : llvm::enumerate(bodyBlock.getArguments())) { + if (index != 0) + printer << ", "; + printer << arg << " = strides("; + printer.printOperands(strides.slice(index * rank, rank)); + printer << ") : " << arg.getType(); + } + printer << ") "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +LogicalResult VAddrLoopOp::verify() { + if (getBounds().empty() || getBounds().size() > 4) + return emitOpError("requires one to four i16 upper bound operands"); + for (Value bound : getBounds()) { + if (!bound.getType().isInteger(16)) + return emitOpError("requires all upper bound operands to be i16"); + } + for (Value stride : getStrides()) { + if (!stride.getType().isInteger(32)) + return emitOpError("requires all stride operands to be i32"); + } + + Block &bodyBlock = getBody().front(); + if (bodyBlock.getNumArguments() == 0) + return emitOpError("requires at least one vaddr region argument"); + + size_t rank = 
getBounds().size(); + size_t expectedStrides = rank * bodyBlock.getNumArguments(); + if (getStrides().size() != expectedStrides) { + return emitOpError() + << "requires stride operand count to equal rank (" << rank + << ") times vaddr count (" << bodyBlock.getNumArguments() << ")"; + } + for (BlockArgument arg : bodyBlock.getArguments()) { + if (failed(verifyVAddrTypeLike(*this, arg.getType(), "body argument type"))) + return failure(); + } + return success(); +} + void ValdOp::getEffects( SmallVectorImpl> &effects) { diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index ef3e0e1f1e..17dbbbf659 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -29,6 +29,7 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -282,6 +283,11 @@ static Value getI32Constant(OpBuilder &builder, Location loc, uint64_t value) { .getResult(); } +static Value getI16Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI16IntegerAttr(value)) + .getResult(); +} + [[maybe_unused]] static Value getI1Constant(OpBuilder &builder, Location loc, bool value) { return builder @@ -6634,6 +6640,91 @@ class LowerVagOpPattern final : public OpConversionPattern { LoweringState &state; }; +class LowerVAddrLoopOpPattern final + : public OpConversionPattern { +public: + explicit LowerVAddrLoopOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VAddrLoopOp op, pto::VAddrLoopOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange bounds = adaptor.getBounds(); + if (bounds.empty() 
|| bounds.size() > 4) + return rewriter.notifyMatchFailure(op, "unexpected vaddr_loop rank"); + + Location loc = op.getLoc(); + Value lower = getI16Constant(rewriter, loc, 0); + Value step = getI16Constant(rewriter, loc, 1); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + SmallVector vaddrInsertionAnchors; + vaddrInsertionAnchors.reserve(bounds.size()); + for (Value bound : bounds) { + auto forOp = rewriter.create(loc, lower, bound, step); + if (!vaddrInsertionAnchors.empty()) + vaddrInsertionAnchors.back() = forOp; + vaddrInsertionAnchors.push_back(forOp.getBody()->getTerminator()); + rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); + } + + IRMapping mapping; + ValueRange strides = adaptor.getStrides(); + Block &bodyBlock = op.getBody().front(); + size_t rank = bounds.size(); + if (strides.size() != rank * bodyBlock.getNumArguments()) + return rewriter.notifyMatchFailure(op, "unexpected vaddr stride count"); + + StringRef calleeName = buildVagCallee(op.getContext()); + for (auto [index, arg] : llvm::enumerate(bodyBlock.getArguments())) { + Type resultType = this->getTypeConverter()->convertType(arg.getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vaddr type"); + + ValueRange group = strides.slice(index * rank, rank); + unsigned activeDepth = 1; + for (auto [strideIndex, stride] : llvm::enumerate(group)) { + if (!matchPattern(stride, m_Zero())) + activeDepth = strideIndex + 1; + } + + rewriter.setInsertionPoint(vaddrInsertionAnchors[activeDepth - 1]); + SmallVector args; + args.reserve(4); + for (Value stride : llvm::reverse(group.take_front(activeDepth))) + args.push_back(stride); + if (args.size() < 4) { + Value zeroValue = getI32Constant(rewriter, loc, 0); + while (args.size() < 4) + args.push_back(zeroValue); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + 
TypeRange{resultType}); + auto call = rewriter.create(loc, calleeName, + TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + mapping.map(arg, call.getResult(0)); + } + + rewriter.setInsertionPoint(vaddrInsertionAnchors.back()); + for (Operation &nested : bodyBlock.getOperations()) + rewriter.clone(nested, mapping); + + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + class LowerValdOpPattern final : public OpConversionPattern { public: explicit LowerValdOpPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -10485,7 +10576,8 @@ static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, LowerRuntimeQueryOpPattern, LowerRuntimeQueryOpPattern, LowerRuntimeQueryOpPattern, - LowerVldsOpPattern, LowerVagOpPattern, LowerValdOpPattern, + LowerVldsOpPattern, LowerVagOpPattern, LowerVAddrLoopOpPattern, + LowerValdOpPattern, LowerValdx2OpPattern, LowerVldsx2OpPattern, LowerVsldbOpPattern, LowerVldasOpPattern, LowerValdaOpPattern, LowerInitAlignOpPattern, @@ -10587,7 +10679,8 @@ static void configureVPTOOpLoweringTarget(ConversionTarget &target, pto::StoreVfSimtInfoOp, pto::SetMovPadValOp, pto::SetQuantPreOp>(); target.addIllegalOp(); - target.addIllegalOp&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddr_stride_count_matches_rank() attributes {pto.kernel} { + %c2_i16 = arith.constant 2 : i16 + %c4_i32 = arith.constant 4 : i32 + pto.vecscope { + pto.vaddr_loop bounds(%c2_i16, %c2_i16) + vaddr(%addr = strides(%c4_i32) : !pto.vaddr) { + } + } + return + } +} + +// CHECK: requires each vaddr stride list to contain 2 operands diff --git a/test/lit/vpto/vector_address_vaddr_loop_vpto_llvm.pto b/test/lit/vpto/vector_address_vaddr_loop_vpto_llvm.pto new file mode 100644 index 0000000000..d3e3f54eca --- /dev/null +++ b/test/lit/vpto/vector_address_vaddr_loop_vpto_llvm.pto @@ -0,0 +1,62 @@ +// 
Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ( source /usr/local/Ascend/cann/set_env.sh >/dev/null 2>&1 || true; ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddr_loop_ops(%src_f32: !pto.ptr, %dst_f32: !pto.ptr) attributes {pto.kernel} { + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c2_i16 = arith.constant 2 : i16 + %c3_i16 = arith.constant 3 : i16 + %c4_i16 = arith.constant 4 : i16 + %c5_i16 = arith.constant 5 : i16 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c96_i32 = arith.constant 96 : i32 + %c128_i32 = arith.constant 128 : i32 + pto.vecscope { + pto.vaddr_loop bounds(%c2_i16, %c3_i16, %c4_i16, %c5_i16) + vaddr(%outer_addr = strides(%c128_i32, %c96_i32, %c64_i32, %c0_i32) + : !pto.vaddr, + %inner_addr = strides(%c0_i32, %c96_i32, %c64_i32, %c32_i32) + : !pto.vaddr) { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %loaded = pto.vald %src_f32[%outer_addr] {dist = "NORM"} + : !pto.ptr, !pto.vaddr -> !pto.vreg<64xf32> + pto.vast %loaded, 
%dst_f32[%inner_addr], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.vaddr, !pto.mask + } + } + return + } +} + +// ROUNDTRIP-LABEL: func.func @vaddr_loop_ops( +// ROUNDTRIP-NOT: ^bb0 +// ROUNDTRIP: pto.vaddr_loop bounds(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vaddr([[OUTER:%[^ ]+]] = strides(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : !pto.vaddr, [[INNER:%[^ ]+]] = strides(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : !pto.vaddr) { +// ROUNDTRIP: pto.vald %{{.*}}[[OUTER]] +// ROUNDTRIP: pto.vast %{{.*}}, %{{.*}}[[INNER]] + +// CHECK-LABEL: llvm.func @vaddr_loop_ops_mix_aiv +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i32) : i32 +// CHECK-DAG: %[[C64:.*]] = llvm.mlir.constant(64 : i32) : i32 +// CHECK-DAG: %[[C96:.*]] = llvm.mlir.constant(96 : i32) : i32 +// CHECK-DAG: %[[C128:.*]] = llvm.mlir.constant(128 : i32) : i32 +// CHECK: ^bb{{[0-9]+}}: +// CHECK: %[[PAD0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[OUTER_ADDR:.*]] = llvm.call @llvm.hivm.vag.32(%[[C64]], %[[C96]], %[[C128]], %[[PAD0]]) +// CHECK-SAME: : (i32, i32, i32, i32) -> vector<1xi32> +// CHECK-NEXT: llvm.br +// CHECK: %[[INNER_ADDR:.*]] = llvm.call @llvm.hivm.vag.32(%[[C32]], %[[C64]], %[[C96]], %[[C0]]) +// CHECK-SAME: : (i32, i32, i32, i32) -> vector<1xi32> +// CHECK: llvm.call @llvm.hivm.vldx1.v64f32(%{{.*}}, %[[OUTER_ADDR]], +// CHECK: llvm.call @llvm.hivm.vstx1.v64f32(%{{.*}}, %{{.*}}, %[[INNER_ADDR]], diff --git a/test/lit/vpto/vector_address_vaddr_verify_invalid.pto b/test/lit/vpto/vector_address_vaddr_verify_invalid.pto new file mode 100644 index 0000000000..94577e7501 --- /dev/null +++ b/test/lit/vpto/vector_address_vaddr_verify_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddr_loop_requires_vaddr_type() attributes {pto.kernel} { + %c2_i16 = arith.constant 2 : i16 + %c4_i32 = arith.constant 4 : i32 + pto.vecscope { + pto.vaddr_loop bounds(%c2_i16) + vaddr(%addr = strides(%c4_i32) : i32) { + } + } + return + } +} + +// CHECK: body argument type must be !pto.vaddr<...>